add mossformer code

This commit is contained in:
hnluo 2023-08-10 12:38:55 +08:00
parent 30f0c7ff29
commit bce7248763
11 changed files with 1437 additions and 0 deletions

127
funasr/bin/ss_infer.py Normal file
View File

@ -0,0 +1,127 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import logging
from pathlib import Path
from typing import List
from typing import Union
import numpy as np
import torch
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.torch_utils.device_funcs import to_device
class SpeechSeparator:
"""SpeechSeparator class
Examples:
>>> import soundfile
>>> speech_separator = MossFormer("ss_config.yml", "ss.pt")
>>> audio, rate = soundfile.read("speech.wav")
>>> separated_wavs = speech_separator(audio)
"""
def __init__(
self,
ss_infer_config: Union[Path, str] = None,
ss_model_file: Union[Path, str] = None,
device: str = "cpu",
batch_size: int = 1,
dtype: str = "float32",
**kwargs,
):
# 1. Build ss model
ss_model, ss_infer_args = build_model_from_file(
ss_infer_config, ss_model_file, None, device, task_name="ss"
)
logging.info("ss_model: {}".format(ss_model))
logging.info("ss_infer_args: {}".format(ss_infer_args))
ss_model.to(dtype=getattr(torch, dtype)).eval()
self.ss_model = ss_model
self.ss_infer_args = ss_infer_args
self.device = device
self.dtype = dtype
self.batch_size = batch_size
def decode(self, model, args, inputs, nsamples):
decode_do_segment = False
with torch.no_grad():
out = []
window = args.sample_rate * args.decode_window # decoding window length
stride = int(window*0.75) # decoding stride if segmentation is used
b, t = inputs.shape
if t > window * args.one_time_decode_length:
decode_do_segment = True # set segment decoding to true for very long sequence
if t < window:
inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], window-t))], 1)
elif t < window + stride:
padding = window + stride - t
inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], 1)
else:
if (t - window) % stride != 0:
padding = t - (t-window)//stride * stride
inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], 1)
inputs = torch.from_numpy(np.float32(inputs))
inputs = to_device(inputs, device=self.device)
b, t = inputs.shape
if decode_do_segment:
outputs = np.zeros((args.num_spks, t))
give_up_length = (window - stride)//2
current_idx = 0
while current_idx + window <= t:
tmp_input = inputs[:, current_idx:current_idx+window]
tmp_out_list = model(tmp_input,)
for spk in range(args.num_spks):
tmp_out_list[spk] = tmp_out_list[spk][0, :].cpu().numpy()
if current_idx == 0:
outputs[spk, current_idx:current_idx+window-give_up_length] = \
tmp_out_list[spk][:-give_up_length]
else:
outputs[spk, current_idx+give_up_length:current_idx+window-give_up_length] = \
tmp_out_list[spk][give_up_length:-give_up_length]
current_idx += stride
for spk in range(args.num_spks):
out.append(outputs[spk, :])
else:
out_list = model(inputs)
for spk in range(args.num_spks):
out.append(out_list[spk][0, :].cpu().numpy())
max_abs = 0
for spk in range(args.num_spks):
if max_abs < max(abs(out[spk])):
max_abs = max(abs(out[spk]))
for spk in range(args.num_spks):
out[spk] = out[spk][:nsamples]
out[spk] = out[spk]/max_abs
return out
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
) -> List[torch.Tensor]:
"""Inference
Args:
speech: Input speech data
Returns:
speech list: list of speech data
"""
out = self.decode(self.ss_model, self.ss_infer_args, speech, speech_lengths)
return out

View File

@ -0,0 +1,253 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
from typing import Optional
from typing import Union
import numpy as np
import torch
import soundfile as sf
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2triple_str
from funasr.bin.ss_infer import SpeechSeparator
def inference_ss(
batch_size: int,
ngpu: int,
log_level: Union[int, str],
ss_infer_config: Optional[str],
ss_model_file: Optional[str],
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
num_workers: int = 1,
num_spks: int = 2,
sample_rate: int = 8000,
param_dict: dict = None,
**kwargs,
):
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
batch_size = 1
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech separator
speech_separator_kwargs = dict(
ss_infer_config=ss_infer_config,
ss_model_file=ss_model_file,
device=device,
dtype=dtype,
)
logging.info("speech_separator_kwargs: {}".format(speech_separator_kwargs))
speech_separator = SpeechSeparator(**speech_separator_kwargs)
def _forward(
data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = build_streaming_iterator(
task_name="ss",
preprocess_args=None,
data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
fs=fs,
batch_size=batch_size,
num_workers=num_workers,
)
# 4 .Start for-loop
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if not os.path.exists(output_path):
cmd = 'mkdir -p ' + output_path
os.system(cmd)
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# do speech separation
logging.info('decoding: {}'.format(keys[0]))
ss_results = speech_separator(**batch)
for spk in range(num_spks):
sf.write(os.path.join(output_path, keys[0].replace('.wav', '_s'+str(spk+1)+'.wav')), ss_results[spk], sample_rate)
torch.cuda.empty_cache()
return ss_results
return _forward
def inference_launch(mode, **kwargs):
if mode == "mossformer":
return inference_ss(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
def get_parser():
parser = config_argparse.ArgumentParser(
description="Speech Separator Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=1,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--njob",
type=int,
default=1,
help="The number of jobs for each gpu",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="2",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=True,
action="append",
)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--ss_infer_config",
type=str,
help="SS infer configuration",
)
group.add_argument(
"--ss_model_file",
type=str,
help="SS model parameter file",
)
group.add_argument(
"--ss_train_config",
type=str,
help="SS training configuration",
)
group = parser.add_argument_group("The inference configuration related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
parser.add_argument(
'--num-spks', dest='num_spks', type=int, default=2)
parser.add_argument(
'--one-time-decode-length', dest='one_time_decode_length', type=int,
default=60, help='the max length (second) for one-time decoding')
parser.add_argument(
'--decode-window', dest='decode_window', type=int,
default=1, help='segmental decoding window length (second)')
parser.add_argument(
'--sample-rate', dest='sample_rate', type=int, default='8000')
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
parser.add_argument(
"--mode",
type=str,
default="mossformer",
help="The decoding mode",
)
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
# set logging messages
logging.basicConfig(
level=args.log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.info("Decoding args: {}".format(kwargs))
# gpu setting
if args.ngpu > 0:
jobid = int(args.output_dir.split(".")[-1])
gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
inference_pipeline = inference_launch(**kwargs)
return inference_pipeline(kwargs["data_path_and_name_and_type"])
if __name__ == "__main__":
main()

View File

@ -5,6 +5,7 @@ from funasr.build_utils.build_pretrain_model import build_pretrain_model
from funasr.build_utils.build_punc_model import build_punc_model
from funasr.build_utils.build_sv_model import build_sv_model
from funasr.build_utils.build_vad_model import build_vad_model
from funasr.build_utils.build_ss_model import build_ss_model
def build_model(args):
@ -22,6 +23,8 @@ def build_model(args):
model = build_diar_model(args)
elif args.task_name == "sv":
model = build_sv_model(args)
elif args.task_name == "ss":
model = build_ss_model(args)
else:
raise NotImplementedError("Not supported task: {}".format(args.task_name))

View File

@ -11,6 +11,18 @@ from funasr.build_utils.build_model import build_model
from funasr.models.base_model import FunASRModel
def load_checkpoint(checkpoint_path, use_cuda=1):
if use_cuda:
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = torch.load(
checkpoint_path, map_location=lambda storage, loc: storage)
return checkpoint
def reload_ss_for_eval(model, checkpoint_path, use_cuda=False):
checkpoint = load_checkpoint(checkpoint_path, use_cuda)
model.load_state_dict(checkpoint['model'], strict=False)
def build_model_from_file(
config_file: Union[Path, str] = None,
model_file: Union[Path, str] = None,
@ -70,6 +82,9 @@ def build_model_from_file(
model.load_state_dict(model_dict)
else:
model_dict = torch.load(model_file, map_location=device)
if task_name == 'ss':
reload_ss_for_eval(model, model_file, use_cuda=True)
logging.info("model is loaded from path: {}".format(model_file))
if task_name == "diar" and mode == "sond":
model_dict = fileter_model_dict(model_dict, model.state_dict())
if task_name == "vad":

View File

@ -0,0 +1,15 @@
from funasr.models.e2e_ss import MossFormer
def build_ss_model(args):
model = MossFormer(
in_channels=args.encoder_embedding_dim,
out_channels=args.mossformer_sequence_dim,
num_blocks=args.num_mossformer_layer,
kernel_size=args.encoder_kernel_size,
norm=args.norm,
num_spks=args.num_spks,
skip_around_intra=args.skip_around_intra,
use_global_pos_enc=args.use_global_pos_enc,
max_length=args.max_length)
return model

View File

@ -0,0 +1,53 @@
import torch
import torch.nn as nn
class MossFormerDecoder(nn.ConvTranspose1d):
"""A decoder layer that consists of ConvTranspose1d.
Arguments
---------
kernel_size : int
Length of filters.
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
Example
---------
>>> x = torch.randn(2, 100, 1000)
>>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1)
>>> h = decoder(x)
>>> h.shape
torch.Size([2, 1003])
"""
def __init__(self, *args, **kwargs):
super(MossFormerDecoder, self).__init__(*args, **kwargs)
def forward(self, x):
"""Return the decoded output.
Arguments
---------
x : torch.Tensor
Input tensor with dimensionality [B, N, L].
where, B = Batchsize,
N = number of filters
L = time points
"""
if x.dim() not in [2, 3]:
raise RuntimeError(
"{} accept 3/4D tensor as input".format(self.__name__)
)
x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
if torch.squeeze(x).dim() == 1:
x = torch.squeeze(x, dim=1)
else:
x = torch.squeeze(x)
return x

95
funasr/models/e2e_ss.py Normal file
View File

@ -0,0 +1,95 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from funasr.models.base_model import FunASRModel
from funasr.models.encoder.mossformer_encoder import MossFormerEncoder, MossFormer_MaskNet
from funasr.models.decoder.mossformer_decoder import MossFormerDecoder
class MossFormer(FunASRModel):
"""The MossFormer model for separating input mixed speech into different speaker's speech.
Arguments
---------
in_channels : int
Number of channels at the output of the encoder.
out_channels : int
Number of channels that would be inputted to the intra and inter blocks.
num_blocks : int
Number of layers of Dual Computation Block.
norm : str
Normalization type.
num_spks : int
Number of sources (speakers).
skip_around_intra : bool
Skip connection around intra.
use_global_pos_enc : bool
Global positional encodings.
max_length : int
Maximum sequence length.
kernel_size: int
Encoder and decoder kernel size
"""
def __init__(
self,
in_channels=512,
out_channels=512,
num_blocks=24,
kernel_size=16,
norm="ln",
num_spks=2,
skip_around_intra=True,
use_global_pos_enc=True,
max_length=20000,
):
super(MossFormer, self).__init__()
self.num_spks = num_spks
# Encoding
self.enc = MossFormerEncoder(kernel_size=kernel_size, out_channels=in_channels, in_channels=1)
##Compute Mask
self.mask_net = MossFormer_MaskNet(
in_channels=in_channels,
out_channels=out_channels,
num_blocks=num_blocks,
norm=norm,
num_spks=num_spks,
skip_around_intra=skip_around_intra,
use_global_pos_enc=use_global_pos_enc,
max_length=max_length,
)
self.dec = MossFormerDecoder(
in_channels=out_channels,
out_channels=1,
kernel_size=kernel_size,
stride = kernel_size//2,
bias=False
)
def forward(self, input):
x = self.enc(input)
mask = self.mask_net(x)
x = torch.stack([x] * self.num_spks)
sep_x = x * mask
# Decoding
est_source = torch.cat(
[
self.dec(sep_x[i]).unsqueeze(-1)
for i in range(self.num_spks)
],
dim=-1,
)
T_origin = input.size(1)
T_est = est_source.size(1)
if T_origin > T_est:
est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
else:
est_source = est_source[:, :T_origin, :]
out = []
for spk in range(self.num_spks):
out.append(est_source[:,:,spk])
return out

View File

@ -0,0 +1,417 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from rotary_embedding_torch import RotaryEmbedding
from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
from funasr.modules.embedding import ScaledSinuEmbedding
from funasr.modules.mossformer import FLASH_ShareA_FFConvM
def select_norm(norm, dim, shape):
"""Just a wrapper to select the normalization type.
"""
if norm == "gln":
return GlobalLayerNorm(dim, shape, elementwise_affine=True)
if norm == "cln":
return CumulativeLayerNorm(dim, elementwise_affine=True)
if norm == "ln":
return nn.GroupNorm(1, dim, eps=1e-8)
else:
return nn.BatchNorm1d(dim)
class MossformerBlock(nn.Module):
def __init__(
self,
*,
dim,
depth,
group_size = 256,
query_key_dim = 128,
expansion_factor = 4.,
causal = False,
attn_dropout = 0.1,
norm_type = 'scalenorm',
shift_tokens = True
):
super().__init__()
assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'
if norm_type == 'scalenorm':
norm_klass = ScaleNorm
elif norm_type == 'layernorm':
norm_klass = nn.LayerNorm
self.group_size = group_size
rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))
# max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
self.layers = nn.ModuleList([FLASH_ShareA_FFConvM(dim = dim, group_size = group_size, query_key_dim = query_key_dim, expansion_factor = expansion_factor, causal = causal, dropout = attn_dropout, rotary_pos_emb = rotary_pos_emb, norm_klass = norm_klass, shift_tokens = shift_tokens) for _ in range(depth)])
def forward(
self,
x,
*,
mask = None
):
ii = 0
for flash in self.layers:
x = flash(x, mask = mask)
ii = ii + 1
return x
class MossFormer_MaskNet(nn.Module):
"""The MossFormer module for computing output masks.
Arguments
---------
in_channels : int
Number of channels at the output of the encoder.
out_channels : int
Number of channels that would be inputted to the intra and inter blocks.
num_blocks : int
Number of layers of Dual Computation Block.
norm : str
Normalization type.
num_spks : int
Number of sources (speakers).
skip_around_intra : bool
Skip connection around intra.
use_global_pos_enc : bool
Global positional encodings.
max_length : int
Maximum sequence length.
Example
---------
>>> mossformer_block = MossFormerM(1, 64, 8)
>>> mossformer_masknet = MossFormer_MaskNet(64, 64, intra_block, num_spks=2)
>>> x = torch.randn(10, 64, 2000)
>>> x = mossformer_masknet(x)
>>> x.shape
torch.Size([2, 10, 64, 2000])
"""
def __init__(
self,
in_channels,
out_channels,
num_blocks=24,
norm="ln",
num_spks=2,
skip_around_intra=True,
use_global_pos_enc=True,
max_length=20000,
):
super(MossFormer_MaskNet, self).__init__()
self.num_spks = num_spks
self.num_blocks = num_blocks
self.norm = select_norm(norm, in_channels, 3)
self.conv1d_encoder = nn.Conv1d(in_channels, out_channels, 1, bias=False)
self.use_global_pos_enc = use_global_pos_enc
if self.use_global_pos_enc:
self.pos_enc = ScaledSinuEmbedding(out_channels)
self.mdl = Computation_Block(
num_blocks,
out_channels,
norm,
skip_around_intra=skip_around_intra,
)
self.conv1d_out = nn.Conv1d(
out_channels, out_channels * num_spks, kernel_size=1
)
self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False)
self.prelu = nn.PReLU()
self.activation = nn.ReLU()
# gated output layer
self.output = nn.Sequential(
nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
)
self.output_gate = nn.Sequential(
nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
)
def forward(self, x):
"""Returns the output tensor.
Arguments
---------
x : torch.Tensor
Input tensor of dimension [B, N, S].
Returns
-------
out : torch.Tensor
Output tensor of dimension [spks, B, N, S]
where, spks = Number of speakers
B = Batchsize,
N = number of filters
S = the number of time frames
"""
# before each line we indicate the shape after executing the line
# [B, N, L]
x = self.norm(x)
# [B, N, L]
x = self.conv1d_encoder(x)
if self.use_global_pos_enc:
#x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
# x.size(1) ** 0.5)
base = x
x = x.transpose(1, -1)
emb = self.pos_enc(x)
emb = emb.transpose(0, -1)
#print('base: {}, emb: {}'.format(base.shape, emb.shape))
x = base + emb
# [B, N, S]
#for i in range(self.num_modules):
# x = self.dual_mdl[i](x)
x = self.mdl(x)
x = self.prelu(x)
# [B, N*spks, S]
x = self.conv1d_out(x)
B, _, S = x.shape
# [B*spks, N, S]
x = x.view(B * self.num_spks, -1, S)
# [B*spks, N, S]
x = self.output(x) * self.output_gate(x)
# [B*spks, N, S]
x = self.conv1_decoder(x)
# [B, spks, N, S]
_, N, L = x.shape
x = x.view(B, self.num_spks, N, L)
x = self.activation(x)
# [spks, B, N, S]
x = x.transpose(0, 1)
return x
class MossFormerEncoder(nn.Module):
"""Convolutional Encoder Layer.
Arguments
---------
kernel_size : int
Length of filters.
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
Example
-------
>>> x = torch.randn(2, 1000)
>>> encoder = Encoder(kernel_size=4, out_channels=64)
>>> h = encoder(x)
>>> h.shape
torch.Size([2, 64, 499])
"""
def __init__(self, kernel_size=2, out_channels=64, in_channels=1):
super(MossFormerEncoder, self).__init__()
self.conv1d = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=kernel_size // 2,
groups=1,
bias=False,
)
self.in_channels = in_channels
def forward(self, x):
"""Return the encoded output.
Arguments
---------
x : torch.Tensor
Input tensor with dimensionality [B, L].
Return
------
x : torch.Tensor
Encoded tensor with dimensionality [B, N, T_out].
where B = Batchsize
L = Number of timepoints
N = Number of filters
T_out = Number of timepoints at the output of the encoder
"""
# B x L -> B x 1 x L
if self.in_channels == 1:
x = torch.unsqueeze(x, dim=1)
# B x 1 x L -> B x N x T_out
x = self.conv1d(x)
x = F.relu(x)
return x
class MossFormerM(nn.Module):
"""This class implements the transformer encoder.
Arguments
---------
num_blocks : int
Number of mossformer blocks to include.
d_model : int
The dimension of the input embedding.
attn_dropout : float
Dropout for the self-attention (Optional).
group_size: int
the chunk size
query_key_dim: int
the attention vector dimension
expansion_factor: int
the expansion factor for the linear projection in conv module
causal: bool
true for causal / false for non causal
Example
-------
>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> net = TransformerEncoder_MossFormerM(num_blocks=8, d_model=512)
>>> output, _ = net(x)
>>> output.shape
torch.Size([8, 60, 512])
"""
def __init__(
self,
num_blocks,
d_model=None,
causal=False,
group_size = 256,
query_key_dim = 128,
expansion_factor = 4.,
attn_dropout = 0.1
):
super().__init__()
self.mossformerM = MossformerBlock(
dim=d_model,
depth=num_blocks,
group_size=group_size,
query_key_dim=query_key_dim,
expansion_factor=expansion_factor,
causal=causal,
attn_dropout=attn_dropout
)
self.norm = nn.LayerNorm(d_model, eps=1e-6)
def forward(
self,
src,
):
"""
Arguments
----------
src : torch.Tensor
Tensor shape [B, L, N],
where, B = Batchsize,
L = time points
N = number of filters
The sequence to the encoder layer (required).
src_mask : tensor
The mask for the src sequence (optional).
src_key_padding_mask : tensor
The mask for the src keys per batch (optional).
"""
output = self.mossformerM(src)
output = self.norm(output)
return output
class Computation_Block(nn.Module):
"""Computation block for dual-path processing.
Arguments
---------
out_channels : int
Dimensionality of inter/intra model.
norm : str
Normalization type.
skip_around_intra : bool
Skip connection around the intra layer.
Example
---------
>>> comp_block = Computation_Block(64)
>>> x = torch.randn(10, 64, 100)
>>> x = comp_block(x)
>>> x.shape
torch.Size([10, 64, 100])
"""
def __init__(
self,
num_blocks,
out_channels,
norm="ln",
skip_around_intra=True,
):
super(Computation_Block, self).__init__()
##MossFormer2M: MossFormer with recurrence
#self.intra_mdl = MossFormer2M(num_blocks=num_blocks, d_model=out_channels)
##MossFormerM: the orignal MossFormer
self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
self.skip_around_intra = skip_around_intra
# Norm
self.norm = norm
if norm is not None:
self.intra_norm = select_norm(norm, out_channels, 3)
def forward(self, x):
"""Returns the output tensor.
Arguments
---------
x : torch.Tensor
Input tensor of dimension [B, N, S].
Return
---------
out: torch.Tensor
Output tensor of dimension [B, N, S].
where, B = Batchsize,
N = number of filters
S = sequence time index
"""
B, N, S = x.shape
# intra RNN
# [B, S, N]
intra = x.permute(0, 2, 1).contiguous() #.view(B, S, N)
intra = self.intra_mdl(intra)
# [B, N, S]
intra = intra.permute(0, 2, 1).contiguous()
if self.norm is not None:
intra = self.intra_norm(intra)
# [B, N, S]
if self.skip_around_intra:
intra = intra + x
out = intra
return out

View File

@ -9,6 +9,7 @@
import math
import torch
import torch.nn.functional as F
from torch import einsum
def _pre_hook(
state_dict,
@ -510,3 +511,19 @@ class StreamingRelPositionalEncoding(torch.nn.Module):
pos_enc = self.dropout(pos_enc)
return pos_enc
class ScaledSinuEmbedding(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.scale = torch.nn.Parameter(torch.ones(1,))
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
def forward(self, x):
n, device = x.shape[1], x.device
t = torch.arange(n, device = device).type_as(self.inv_freq)
sinu = einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
return emb * self.scale

View File

@ -7,6 +7,7 @@
"""Layer normalization module."""
import torch
import torch.nn as nn
class LayerNorm(torch.nn.LayerNorm):
@ -40,3 +41,137 @@ class LayerNorm(torch.nn.LayerNorm):
.forward(x.transpose(self.dim, -1))
.transpose(self.dim, -1)
)
class GlobalLayerNorm(nn.Module):
"""Calculate Global Layer Normalization.
Arguments
---------
dim : (int or list or torch.Size)
Input shape from an expected input of size.
eps : float
A value added to the denominator for numerical stability.
elementwise_affine : bool
A boolean value that when set to True,
this module has learnable per-element affine parameters
initialized to ones (for weights) and zeros (for biases).
Example
-------
>>> x = torch.randn(5, 10, 20)
>>> GLN = GlobalLayerNorm(10, 3)
>>> x_norm = GLN(x)
"""
def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
super(GlobalLayerNorm, self).__init__()
self.dim = dim
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
if shape == 3:
self.weight = nn.Parameter(torch.ones(self.dim, 1))
self.bias = nn.Parameter(torch.zeros(self.dim, 1))
if shape == 4:
self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def forward(self, x):
"""Returns the normalized tensor.
Arguments
---------
x : torch.Tensor
Tensor of size [N, C, K, S] or [N, C, L].
"""
# x = N x C x K x S or N x C x L
# N x 1 x 1
# cln: mean,var N x 1 x K x S
# gln: mean,var N x 1 x 1
if x.dim() == 3:
mean = torch.mean(x, (1, 2), keepdim=True)
var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
if self.elementwise_affine:
x = (
self.weight * (x - mean) / torch.sqrt(var + self.eps)
+ self.bias
)
else:
x = (x - mean) / torch.sqrt(var + self.eps)
if x.dim() == 4:
mean = torch.mean(x, (1, 2, 3), keepdim=True)
var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
if self.elementwise_affine:
x = (
self.weight * (x - mean) / torch.sqrt(var + self.eps)
+ self.bias
)
else:
x = (x - mean) / torch.sqrt(var + self.eps)
return x
class CumulativeLayerNorm(nn.LayerNorm):
"""Calculate Cumulative Layer Normalization.
Arguments
---------
dim : int
Dimension that you want to normalize.
elementwise_affine : True
Learnable per-element affine parameters.
Example
-------
>>> x = torch.randn(5, 10, 20)
>>> CLN = CumulativeLayerNorm(10)
>>> x_norm = CLN(x)
"""
def __init__(self, dim, elementwise_affine=True):
super(CumulativeLayerNorm, self).__init__(
dim, elementwise_affine=elementwise_affine, eps=1e-8
)
def forward(self, x):
"""Returns the normalized tensor.
Arguments
---------
x : torch.Tensor
Tensor size [N, C, K, S] or [N, C, L]
"""
# x: N x C x K x S or N x C x L
# N x K x S x C
if x.dim() == 4:
x = x.permute(0, 2, 3, 1).contiguous()
# N x K x S x C == only channel norm
x = super().forward(x)
# N x C x K x S
x = x.permute(0, 3, 1, 2).contiguous()
if x.dim() == 3:
x = torch.transpose(x, 1, 2)
# N x L x C == only channel norm
x = super().forward(x)
# N x C x L
x = torch.transpose(x, 1, 2)
return x
class ScaleNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.scale = dim ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
def forward(self, x):
norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
return x / norm.clamp(min = self.eps) * self.g

View File

@ -0,0 +1,307 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange
def identity(t, *args, **kwargs):
return t
def append_dims(x, num_dims):
if num_dims <= 0:
return x
return x.view(*x.shape, *((1,) * num_dims))
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def padding_to_multiple_of(n, mult):
remainder = n % mult
if remainder == 0:
return 0
return mult - remainder
class Transpose(nn.Module):
""" Wrapper class of torch.transpose() for Sequential module. """
def __init__(self, shape: tuple):
super(Transpose, self).__init__()
self.shape = shape
def forward(self, x):
return x.transpose(*self.shape)
class DepthwiseConv1d(nn.Module):
"""
When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
this operation is termed in literature as depthwise convolution.
Args:
in_channels (int): Number of channels in the input
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
bias (bool, optional): If True, adds a learnable bias to the output. Default: True
Inputs: inputs
- **inputs** (batch, in_channels, time): Tensor containing input vector
Returns: outputs
- **outputs** (batch, out_channels, time): Tensor produces by depthwise 1-D convolution.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
bias: bool = False,
) -> None:
super(DepthwiseConv1d, self).__init__()
assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
self.conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
groups=in_channels,
stride=stride,
padding=padding,
bias=bias,
)
def forward(self, inputs):
return self.conv(inputs)
class ConvModule(nn.Module):
"""
Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
to aid training deep models.
Args:
in_channels (int): Number of channels in the input
kernel_size (int or tuple, optional): Size of the convolving kernel Default: 31
dropout_p (float, optional): probability of dropout
Inputs: inputs
inputs (batch, time, dim): Tensor contains input sequences
Outputs: outputs
outputs (batch, time, dim): Tensor produces by conformer convolution module.
"""
def __init__(
self,
in_channels: int,
kernel_size: int = 17,
expansion_factor: int = 2,
dropout_p: float = 0.1,
) -> None:
super(ConvModule, self).__init__()
assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
self.sequential = nn.Sequential(
Transpose(shape=(1, 2)),
DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
)
def forward(self, inputs):
return inputs + self.sequential(inputs).transpose(1, 2)
class OffsetScale(nn.Module):
def __init__(self, dim, heads = 1):
super().__init__()
self.gamma = nn.Parameter(torch.ones(heads, dim))
self.beta = nn.Parameter(torch.zeros(heads, dim))
nn.init.normal_(self.gamma, std = 0.02)
def forward(self, x):
out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
return out.unbind(dim = -2)
class FFConvM(nn.Module):
def __init__(
self,
dim_in,
dim_out,
norm_klass = nn.LayerNorm,
dropout = 0.1
):
super().__init__()
self.mdl = nn.Sequential(
norm_klass(dim_in),
nn.Linear(dim_in, dim_out),
nn.SiLU(),
ConvModule(dim_out),
nn.Dropout(dropout)
)
def forward(
self,
x,
):
output = self.mdl(x)
return output
class FLASH_ShareA_FFConvM(nn.Module):
def __init__(
self,
*,
dim,
group_size = 256,
query_key_dim = 128,
expansion_factor = 1.,
causal = False,
dropout = 0.1,
rotary_pos_emb = None,
norm_klass = nn.LayerNorm,
shift_tokens = True
):
super().__init__()
hidden_dim = int(dim * expansion_factor)
self.group_size = group_size
self.causal = causal
self.shift_tokens = shift_tokens
# positional embeddings
self.rotary_pos_emb = rotary_pos_emb
# norm
self.dropout = nn.Dropout(dropout)
# projections
self.to_hidden = FFConvM(
dim_in = dim,
dim_out = hidden_dim,
norm_klass = norm_klass,
dropout = dropout,
)
self.to_qk = FFConvM(
dim_in = dim,
dim_out = query_key_dim,
norm_klass = norm_klass,
dropout = dropout,
)
self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)
self.to_out = FFConvM(
dim_in = dim*2,
dim_out = dim,
norm_klass = norm_klass,
dropout = dropout,
)
self.gateActivate=nn.Sigmoid()
def forward(
self,
x,
*,
mask = None
):
"""
b - batch
n - sequence length (within groups)
g - group dimension
d - feature dimension (keys)
e - feature dimension (values)
i - sequence dimension (source)
j - sequence dimension (target)
"""
normed_x = x
# do token shift - a great, costless trick from an independent AI researcher in Shenzhen
residual = x
if self.shift_tokens:
x_shift, x_pass = normed_x.chunk(2, dim = -1)
x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
normed_x = torch.cat((x_shift, x_pass), dim = -1)
# initial projections
v, u = self.to_hidden(normed_x).chunk(2, dim = -1)
qk = self.to_qk(normed_x)
# offset and scale
quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u)
out = (att_u*v ) * self.gateActivate(att_v*u)
x = x + self.to_out(out)
return x
def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask = None):
b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
if exists(mask):
lin_mask = rearrange(mask, '... -> ... 1')
lin_k = lin_k.masked_fill(~lin_mask, 0.)
# rotate queries and keys
if exists(self.rotary_pos_emb):
quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))
# padding for groups
padding = padding_to_multiple_of(n, g)
if padding > 0:
quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v, u))
mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
mask = F.pad(mask, (0, padding), value = False)
# group along sequence
quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v, u))
if exists(mask):
mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)
# calculate quadratic attention output
sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
attn = F.relu(sim) ** 2
attn = self.dropout(attn)
if exists(mask):
attn = attn.masked_fill(~mask, 0.)
if self.causal:
causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
attn = attn.masked_fill(causal_mask, 0.)
quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)
# calculate linear attention output
if self.causal:
lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
# exclusive cumulative sum along group dimension
lin_kv = lin_kv.cumsum(dim = 1)
lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)
lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
# exclusive cumulative sum along group dimension
lin_ku = lin_ku.cumsum(dim = 1)
lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value = 0.)
lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
else:
lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)
# fold back groups into full sequence, and excise out padding
return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v+lin_out_v, quad_out_u+lin_out_u))