class SpeechSeparator:
    """Run a trained speech-separation model on mixed audio.

    The input passed to the separator must be 2-D ``(batch, time)`` because
    :meth:`decode` unpacks ``inputs.shape`` into two values.

    Examples:
        >>> import soundfile
        >>> speech_separator = SpeechSeparator("ss_config.yml", "ss.pt")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> separated_wavs = speech_separator(audio[None, :])

    """

    def __init__(
        self,
        ss_infer_config: Union[Path, str] = None,
        ss_model_file: Union[Path, str] = None,
        device: str = "cpu",
        batch_size: int = 1,
        dtype: str = "float32",
        **kwargs,
    ):
        """Build the separation model and switch it to eval mode.

        Args:
            ss_infer_config: Path of the inference configuration file.
            ss_model_file: Path of the model parameter file.
            device: Device string handed to the model builder and used for inputs.
            batch_size: Stored on the instance; not used by :meth:`decode`.
            dtype: Name of a ``torch`` dtype (e.g. ``"float32"``) the model is cast to.
            **kwargs: Accepted and ignored.
        """

        # 1. Build ss model
        ss_model, ss_infer_args = build_model_from_file(
            ss_infer_config, ss_model_file, None, device, task_name="ss"
        )

        logging.info("ss_model: {}".format(ss_model))
        logging.info("ss_infer_args: {}".format(ss_infer_args))

        ss_model.to(dtype=getattr(torch, dtype)).eval()

        self.ss_model = ss_model
        self.ss_infer_args = ss_infer_args
        self.device = device
        self.dtype = dtype
        self.batch_size = batch_size

    @staticmethod
    def _peak_normalize(out, nsamples):
        """Trim every waveform to ``nsamples`` and scale all by their joint peak.

        The peak is measured over the untrimmed waveforms. When the peak is
        zero (all outputs silent) the waveforms are returned unscaled instead
        of being divided by zero.
        """
        max_abs = 0
        for wav in out:
            if len(wav):
                max_abs = max(max_abs, np.max(np.abs(wav)))
        trimmed = [wav[:nsamples] for wav in out]
        if max_abs > 0:
            return [wav / max_abs for wav in trimmed]
        return trimmed

    def decode(self, model, args, inputs, nsamples):
        """Separate a mixture into ``args.num_spks`` waveforms.

        Args:
            model: Callable applied to a 2-D tensor; its result is indexed as
                ``result[spk][0, :]``, so only the first batch item is kept.
            args: Object providing ``sample_rate``, ``decode_window``,
                ``one_time_decode_length`` and ``num_spks``.
            inputs: 2-D array of mixture samples, ``(batch, time)``.
            nsamples: Number of samples to keep per speaker; ``None`` keeps
                the padded length.

        Returns:
            List of ``args.num_spks`` 1-D numpy arrays, trimmed to ``nsamples``
            and divided by the largest absolute sample over all speakers.
        """
        window = args.sample_rate * args.decode_window  # decoding window length
        stride = int(window * 0.75)  # decoding stride if segmentation is used
        with torch.no_grad():
            _, t = inputs.shape
            # segment decoding is used for very long sequences
            decode_do_segment = t > window * args.one_time_decode_length

            if t < window:
                padding = window - t
            elif t < window + stride:
                padding = window + stride - t
            elif (t - window) % stride != 0:
                # Equals window + (t - window) % stride, i.e. at least one
                # full window of zeros, so the tail that segment decoding
                # discards from the last window lies inside the padding.
                padding = t - (t - window) // stride * stride
            else:
                padding = 0
            if padding > 0:
                inputs = np.concatenate(
                    [inputs, np.zeros((inputs.shape[0], padding))], 1
                )

            inputs = torch.from_numpy(np.float32(inputs))
            inputs = to_device(inputs, device=self.device)
            # The model was cast to ``self.dtype`` in __init__; feed it the
            # same dtype (a no-op for the default float32).
            inputs = inputs.to(dtype=getattr(torch, self.dtype))
            _, t = inputs.shape

            if decode_do_segment:
                outputs = np.zeros((args.num_spks, t))
                # samples dropped at each inner window edge
                give_up_length = (window - stride) // 2
                current_idx = 0
                while current_idx + window <= t:
                    tmp_input = inputs[:, current_idx:current_idx + window]
                    tmp_out_list = model(tmp_input)
                    for spk in range(args.num_spks):
                        segment = tmp_out_list[spk][0, :].cpu().numpy()
                        if current_idx == 0:
                            # first window keeps its leading edge
                            outputs[spk, :window - give_up_length] = \
                                segment[:-give_up_length]
                        else:
                            start = current_idx + give_up_length
                            stop = current_idx + window - give_up_length
                            outputs[spk, start:stop] = \
                                segment[give_up_length:-give_up_length]
                    current_idx += stride
                out = [outputs[spk, :] for spk in range(args.num_spks)]
            else:
                out_list = model(inputs)
                out = [
                    out_list[spk][0, :].cpu().numpy()
                    for spk in range(args.num_spks)
                ]

        return self._peak_normalize(out, nsamples)

    @torch.no_grad()
    def __call__(
        self,
        speech: Union[torch.Tensor, np.ndarray],
        speech_lengths: Union[torch.Tensor, np.ndarray] = None,
    ) -> List[np.ndarray]:
        """Inference

        Args:
            speech: Input speech data, 2-D ``(batch, time)``
            speech_lengths: Number of samples to keep in each separated output
        Returns:
            speech list: list of separated speech data (numpy arrays)

        """

        out = self.decode(self.ss_model, self.ss_infer_args, speech, speech_lengths)

        return out
def inference_ss(
    batch_size: int,
    ngpu: int,
    log_level: Union[int, str],
    ss_infer_config: Optional[str],
    ss_model_file: Optional[str],
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    num_workers: int = 1,
    num_spks: int = 2,
    sample_rate: int = 8000,
    param_dict: dict = None,
    **kwargs,
):
    """Build a speech separator and return a function that runs it on data.

    Args:
        batch_size: Must not exceed 1; batch decoding is not implemented.
        ngpu: CUDA is used when this is >= 1 and CUDA is available.
        log_level: Level passed to ``logging.basicConfig``.
        ss_infer_config: Inference configuration file of the separator.
        ss_model_file: Model parameter file of the separator.
        output_dir: Default directory for the separated wav files.
        dtype: Data type name passed to the separator and the data loader.
        seed: Random seed.
        num_workers: Number of data-loader workers.
        num_spks: Number of separated signals written per utterance.
        sample_rate: Sample rate used when writing the wav files.
        param_dict: Unused.
        **kwargs: ``ncpu`` (default 1) sets the torch thread count; other
            keys are ignored.

    Returns:
        ``_forward``, a callable that separates every utterance of a data
        source and writes one wav file per speaker.

    Raises:
        NotImplementedError: If ``batch_size`` is larger than 1.
    """
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
        batch_size = 1

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech separator
    speech_separator_kwargs = dict(
        ss_infer_config=ss_infer_config,
        ss_model_file=ss_model_file,
        device=device,
        dtype=dtype,
    )
    logging.info("speech_separator_kwargs: {}".format(speech_separator_kwargs))
    speech_separator = SpeechSeparator(**speech_separator_kwargs)

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
        param_dict: dict = None
    ):
        """Separate all utterances and write ``<key>_s<k>.wav`` files.

        Returns the separation result of the last utterance, or ``None`` when
        the data source yields nothing.
        """
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = build_streaming_iterator(
            task_name="ss",
            preprocess_args=None,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=batch_size,
            num_workers=num_workers,
        )

        # 4. Start for-loop
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is None:
            raise ValueError(
                "an output directory is required: pass output_dir or output_dir_v2"
            )
        # Created without going through a shell, so any path spelling is safe.
        os.makedirs(output_path, exist_ok=True)

        ss_results = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"

            # do speech separation
            logging.info('decoding: {}'.format(keys[0]))
            ss_results = speech_separator(**batch)

            # Strip a trailing ".wav" (if any) so each speaker always gets its
            # own file name with a ".wav" extension.
            key = keys[0]
            stem = key[:-len('.wav')] if key.endswith('.wav') else key
            for spk in range(num_spks):
                wav_path = os.path.join(output_path, f"{stem}_s{spk + 1}.wav")
                sf.write(wav_path, ss_results[spk], sample_rate)
        torch.cuda.empty_cache()
        return ss_results

    return _forward


def inference_launch(mode, **kwargs):
    """Return the inference pipeline for ``mode``.

    Only ``"mossformer"`` is known; any other mode is logged and yields ``None``.
    """
    if mode == "mossformer":
        return inference_ss(**kwargs)
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None


def get_parser():
    """Create the command-line parser for speech-separation decoding."""
    parser = config_argparse.ArgumentParser(
        description="Speech Separator Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=1,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--njob",
        type=int,
        default=1,
        help="The number of jobs for each gpu",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        default="2",
        help="The visible gpus",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=True,
        action="append",
    )

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--ss_infer_config",
        type=str,
        help="SS infer configuration",
    )
    group.add_argument(
        "--ss_model_file",
        type=str,
        help="SS model parameter file",
    )
    group.add_argument(
        "--ss_train_config",
        type=str,
        help="SS training configuration",
    )

    group = parser.add_argument_group("The inference configuration related")
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )

    parser.add_argument(
        '--num-spks', dest='num_spks', type=int, default=2)

    parser.add_argument(
        '--one-time-decode-length', dest='one_time_decode_length', type=int,
        default=60, help='the max length (second) for one-time decoding')

    parser.add_argument(
        '--decode-window', dest='decode_window', type=int,
        default=1, help='segmental decoding window length (second)')

    # Integer default, consistent with type=int and the other options.
    parser.add_argument(
        '--sample-rate', dest='sample_rate', type=int, default=8000)
    return parser


def main(cmd=None):
    """Command-line entry point: parse arguments, pick a GPU and decode.

    The job id is read from the suffix after the last ``.`` of
    ``--output_dir``; when that suffix is not a number, job 1 is assumed
    instead of failing.

    Raises:
        ValueError: If ``--mode`` does not name a known decoding mode.
    """
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    parser.add_argument(
        "--mode",
        type=str,
        default="mossformer",
        help="The decoding mode",
    )
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop("config", None)

    # set logging messages
    logging.basicConfig(
        level=args.log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("Decoding args: {}".format(kwargs))

    # gpu setting
    if args.ngpu > 0:
        suffix = args.output_dir.split(".")[-1]
        if suffix.isdigit():
            jobid = int(suffix)
        else:
            logging.warning(
                "output_dir has no numeric job suffix; assuming job id 1"
            )
            jobid = 1
        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid

    inference_pipeline = inference_launch(**kwargs)
    if inference_pipeline is None:
        raise ValueError("Unknown decoding mode: {}".format(args.mode))
    return inference_pipeline(kwargs["data_path_and_name_and_type"])


if __name__ == "__main__":
    main()
class MossFormerDecoder(nn.ConvTranspose1d):
    """A decoder layer that consists of ConvTranspose1d.

    Arguments
    ---------
    kernel_size : int
        Length of filters.
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.


    Example
    ---------
    >>> x = torch.randn(2, 100, 1000)
    >>> decoder = MossFormerDecoder(kernel_size=4, in_channels=100, out_channels=1)
    >>> h = decoder(x)
    >>> h.shape
    torch.Size([2, 1003])
    """

    def __init__(self, *args, **kwargs):
        super(MossFormerDecoder, self).__init__(*args, **kwargs)

    def forward(self, x):
        """Return the decoded output.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor with dimensionality [B, N, L].
            where, B = Batchsize,
                   N = number of filters
                   L = time points
            A 2-D tensor gets a singleton dimension inserted at index 1
            before the transposed convolution.

        Raises
        ------
        RuntimeError
            If ``x`` is neither 2-D nor 3-D.
        """

        if x.dim() not in [2, 3]:
            # Instances have no ``__name__``; the class name lives on the type.
            raise RuntimeError(
                "{} accept 2/3D tensor as input".format(type(self).__name__)
            )
        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))

        # When everything but time is singleton, drop only dim 1 so the
        # result stays 2-D; otherwise drop every singleton dimension.
        if torch.squeeze(x).dim() == 1:
            x = torch.squeeze(x, dim=1)
        else:
            x = torch.squeeze(x)
        return x
+ kernel_size: int + Encoder and decoder kernel size + """ + + def __init__( + self, + in_channels=512, + out_channels=512, + num_blocks=24, + kernel_size=16, + norm="ln", + num_spks=2, + skip_around_intra=True, + use_global_pos_enc=True, + max_length=20000, + ): + super(MossFormer, self).__init__() + self.num_spks = num_spks + # Encoding + self.enc = MossFormerEncoder(kernel_size=kernel_size, out_channels=in_channels, in_channels=1) + + ##Compute Mask + self.mask_net = MossFormer_MaskNet( + in_channels=in_channels, + out_channels=out_channels, + num_blocks=num_blocks, + norm=norm, + num_spks=num_spks, + skip_around_intra=skip_around_intra, + use_global_pos_enc=use_global_pos_enc, + max_length=max_length, + ) + self.dec = MossFormerDecoder( + in_channels=out_channels, + out_channels=1, + kernel_size=kernel_size, + stride = kernel_size//2, + bias=False + ) + def forward(self, input): + x = self.enc(input) + mask = self.mask_net(x) + x = torch.stack([x] * self.num_spks) + sep_x = x * mask + + # Decoding + est_source = torch.cat( + [ + self.dec(sep_x[i]).unsqueeze(-1) + for i in range(self.num_spks) + ], + dim=-1, + ) + T_origin = input.size(1) + T_est = est_source.size(1) + if T_origin > T_est: + est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est)) + else: + est_source = est_source[:, :T_origin, :] + + out = [] + for spk in range(self.num_spks): + out.append(est_source[:,:,spk]) + return out diff --git a/funasr/models/encoder/mossformer_encoder.py b/funasr/models/encoder/mossformer_encoder.py new file mode 100644 index 000000000..54d80ca7a --- /dev/null +++ b/funasr/models/encoder/mossformer_encoder.py @@ -0,0 +1,417 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from rotary_embedding_torch import RotaryEmbedding +from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm +from funasr.modules.embedding import ScaledSinuEmbedding +from funasr.modules.mossformer import FLASH_ShareA_FFConvM + + +def 
class MossformerBlock(nn.Module):
    """A stack of ``depth`` FLASH_ShareA_FFConvM layers applied in sequence.

    All layers are built with the same hyper-parameters and share a single
    rotary positional embedding instance.

    Arguments
    ---------
    dim : int
        Feature dimension passed to every layer.
    depth : int
        Number of layers.
    group_size : int
        Group size passed to every layer.
    query_key_dim : int
        Query/key dimension passed to every layer.
    expansion_factor : float
        Expansion factor passed to every layer.
    causal : bool
        Passed to every layer.
    attn_dropout : float
        Dropout passed to every layer.
    norm_type : str
        ``'scalenorm'`` or ``'layernorm'``.
    shift_tokens : bool
        Passed to every layer.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 4.,
        causal = False,
        attn_dropout = 0.1,
        norm_type = 'scalenorm',
        shift_tokens = True
    ):
        super().__init__()
        assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'

        if norm_type == 'scalenorm':
            norm_klass = ScaleNorm
        elif norm_type == 'layernorm':
            norm_klass = nn.LayerNorm

        self.group_size = group_size

        rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))
        # max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
        self.layers = nn.ModuleList([
            FLASH_ShareA_FFConvM(
                dim = dim,
                group_size = group_size,
                query_key_dim = query_key_dim,
                expansion_factor = expansion_factor,
                causal = causal,
                dropout = attn_dropout,
                rotary_pos_emb = rotary_pos_emb,
                norm_klass = norm_klass,
                shift_tokens = shift_tokens,
            )
            for _ in range(depth)
        ])

    def forward(
        self,
        x,
        *,
        mask = None
    ):
        """Feed ``x`` through every layer in order, passing ``mask`` to each."""
        for flash in self.layers:
            x = flash(x, mask = mask)
        return x
+ use_global_pos_enc : bool + Global positional encodings. + max_length : int + Maximum sequence length. + + Example + --------- + >>> mossformer_block = MossFormerM(1, 64, 8) + >>> mossformer_masknet = MossFormer_MaskNet(64, 64, intra_block, num_spks=2) + >>> x = torch.randn(10, 64, 2000) + >>> x = mossformer_masknet(x) + >>> x.shape + torch.Size([2, 10, 64, 2000]) + """ + + def __init__( + self, + in_channels, + out_channels, + num_blocks=24, + norm="ln", + num_spks=2, + skip_around_intra=True, + use_global_pos_enc=True, + max_length=20000, + ): + super(MossFormer_MaskNet, self).__init__() + self.num_spks = num_spks + self.num_blocks = num_blocks + self.norm = select_norm(norm, in_channels, 3) + self.conv1d_encoder = nn.Conv1d(in_channels, out_channels, 1, bias=False) + self.use_global_pos_enc = use_global_pos_enc + + if self.use_global_pos_enc: + self.pos_enc = ScaledSinuEmbedding(out_channels) + + self.mdl = Computation_Block( + num_blocks, + out_channels, + norm, + skip_around_intra=skip_around_intra, + ) + + self.conv1d_out = nn.Conv1d( + out_channels, out_channels * num_spks, kernel_size=1 + ) + self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False) + self.prelu = nn.PReLU() + self.activation = nn.ReLU() + # gated output layer + self.output = nn.Sequential( + nn.Conv1d(out_channels, out_channels, 1), nn.Tanh() + ) + self.output_gate = nn.Sequential( + nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid() + ) + + def forward(self, x): + """Returns the output tensor. + + Arguments + --------- + x : torch.Tensor + Input tensor of dimension [B, N, S]. 
class MossFormerEncoder(nn.Module):
    """Convolutional encoder: one bias-free Conv1d followed by ReLU.

    Arguments
    ---------
    kernel_size : int
        Length of filters. The convolution stride is ``kernel_size // 2``.
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.

    Example
    -------
    >>> x = torch.randn(2, 1000)
    >>> encoder = MossFormerEncoder(kernel_size=4, out_channels=64)
    >>> h = encoder(x)
    >>> h.shape
    torch.Size([2, 64, 499])
    """

    def __init__(self, kernel_size=2, out_channels=64, in_channels=1):
        super(MossFormerEncoder, self).__init__()
        self.in_channels = in_channels
        hop = kernel_size // 2
        self.conv1d = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=hop,
            groups=1,
            bias=False,
        )

    def forward(self, x):
        """Return the encoded output.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor with dimensionality [B, L].
        Return
        ------
        x : torch.Tensor
            Encoded tensor with dimensionality [B, N, T_out].

        where B = Batchsize
              L = Number of timepoints
              N = Number of filters
              T_out = Number of timepoints at the output of the encoder
        """
        # A single-channel encoder receives [B, L]; add the channel axis.
        if self.in_channels == 1:
            x = x.unsqueeze(1)
        # [B, C, L] -> [B, N, T_out], then clamp negatives to zero.
        return F.relu(self.conv1d(x))
class Computation_Block(nn.Module):
    """Wrap a MossFormerM stack with optional normalization and skip connection.

    Arguments
    ---------
    num_blocks : int
        Number of blocks given to the MossFormerM model.
    out_channels : int
        Model dimension given to the MossFormerM model and the norm.
    norm : str
        Normalization type; ``None`` disables the normalization.
    skip_around_intra : bool
        Add the block input to the block output.

    Example
    ---------
    >>> comp_block = Computation_Block(num_blocks=2, out_channels=64)
    >>> x = torch.randn(10, 64, 100)
    >>> x = comp_block(x)
    >>> x.shape
    torch.Size([10, 64, 100])
    """

    def __init__(
        self,
        num_blocks,
        out_channels,
        norm="ln",
        skip_around_intra=True,
    ):
        super(Computation_Block, self).__init__()

        # MossFormerM: the original MossFormer
        self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
        self.skip_around_intra = skip_around_intra

        # Norm
        self.norm = norm
        if norm is not None:
            self.intra_norm = select_norm(norm, out_channels, 3)

    def forward(self, x):
        """Returns the output tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of dimension [B, N, S].

        Return
        ---------
        out: torch.Tensor
            Output tensor of dimension [B, N, S].
            where, B = Batchsize,
               N = number of filters
               S = sequence time index
        """
        # Unpacking insists on a 3-D input, as before.
        _batch, _filters, _steps = x.shape

        # [B, N, S] -> [B, S, N] for the sequence model, then back again.
        sequence_major = x.transpose(1, 2).contiguous()
        processed = self.intra_mdl(sequence_major)
        result = processed.transpose(1, 2).contiguous()

        if self.norm is not None:
            result = self.intra_norm(result)

        if self.skip_around_intra:
            result = result + x

        return result
class CumulativeLayerNorm(nn.LayerNorm):
    """Layer normalization over the channel axis of channel-first tensors.

    Arguments
    ---------
    dim : int
        Dimension that you want to normalize.
    elementwise_affine : True
        Learnable per-element affine parameters.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    """

    def __init__(self, dim, elementwise_affine=True):
        super(CumulativeLayerNorm, self).__init__(
            dim, elementwise_affine=elementwise_affine, eps=1e-8
        )

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor size [N, C, K, S] or [N, C, L]. Tensors of any other rank
            are returned unchanged.
        """
        rank = x.dim()

        if rank == 4:
            # N x C x K x S -> N x K x S x C, normalize C, and back again.
            channels_last = x.permute(0, 2, 3, 1).contiguous()
            normed = super().forward(channels_last)
            return normed.permute(0, 3, 1, 2).contiguous()

        if rank == 3:
            # N x C x L -> N x L x C, normalize C, and back again.
            normed = super().forward(x.transpose(1, 2))
            return normed.transpose(1, 2)

        return x
""" + def __init__(self, shape: tuple): + super(Transpose, self).__init__() + self.shape = shape + + def forward(self, x): + return x.transpose(*self.shape) + + +class DepthwiseConv1d(nn.Module): + """ + When groups == in_channels and out_channels == K * in_channels, where K is a positive integer, + this operation is termed in literature as depthwise convolution. + Args: + in_channels (int): Number of channels in the input + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + bias (bool, optional): If True, adds a learnable bias to the output. Default: True + Inputs: inputs + - **inputs** (batch, in_channels, time): Tensor containing input vector + Returns: outputs + - **outputs** (batch, out_channels, time): Tensor produces by depthwise 1-D convolution. + """ + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + bias: bool = False, + ) -> None: + super(DepthwiseConv1d, self).__init__() + assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels" + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + groups=in_channels, + stride=stride, + padding=padding, + bias=bias, + ) + + def forward(self, inputs): + return self.conv(inputs) + + +class ConvModule(nn.Module): + """ + Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU). + This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution + to aid training deep models. 
class OffsetScale(nn.Module):
    """Produce ``heads`` scaled-and-shifted copies of the input features.

    Each head owns a scale vector (``gamma``, drawn from a normal
    distribution with std 0.02) and an offset vector (``beta``, zeros), both
    of length ``dim``. ``forward`` returns a tuple with one tensor per head,
    each shaped like the input.
    """

    def __init__(self, dim, heads = 1):
        super().__init__()
        param_shape = (heads, dim)
        self.gamma = nn.Parameter(torch.ones(*param_shape))
        self.beta = nn.Parameter(torch.zeros(*param_shape))
        nn.init.normal_(self.gamma, std = 0.02)

    def forward(self, x):
        # Broadcast the last axis of x against every head's scale vector.
        per_head = torch.einsum('... d, h d -> ... h d', x, self.gamma)
        shifted = per_head + self.beta
        # Split the head axis (second to last) into separate tensors.
        return torch.unbind(shifted, dim = -2)
class FFConvM(nn.Module):
    """Feed-forward projection followed by a convolution module.

    Applies, in order: ``norm_klass(dim_in)``, ``nn.Linear(dim_in, dim_out)``,
    ``nn.SiLU``, ``ConvModule(dim_out)`` and ``nn.Dropout(dropout)``.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        norm_klass = nn.LayerNorm,
        dropout = 0.1
    ):
        super().__init__()
        self.mdl = nn.Sequential(
            norm_klass(dim_in),
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            ConvModule(dim_out),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
    ):
        return self.mdl(x)


class FLASH_ShareA_FFConvM(nn.Module):
    """Gated attention layer combining grouped quadratic and linear attention.

    The input is projected to two value streams (``v`` and ``u``) and to a
    shared query/key tensor from which four scaled/offset variants are
    derived. One attention map is shared by both value streams; the two
    attended streams gate each other and are projected back to ``dim`` with
    a residual connection.

    Args:
        dim: feature size of the input and output.
        group_size: number of sequence positions per attention group.
        query_key_dim: feature size of the shared query/key projection.
        expansion_factor: ``int(dim * expansion_factor)`` is the width of the
            hidden projection, which is split in half into ``v`` and ``u``.
        causal: if True, attention only looks at current/past positions.
        dropout: dropout probability for the attention map and projections.
        rotary_pos_emb: optional object exposing ``rotate_queries_or_keys``.
        norm_klass: normalization class used inside every ``FFConvM``.
        shift_tokens: if True, half of the features are shifted by one
            position along the sequence before the projections.
    """

    def __init__(
        self,
        *,
        dim,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 1.,
        causal = False,
        dropout = 0.1,
        rotary_pos_emb = None,
        norm_klass = nn.LayerNorm,
        shift_tokens = True
    ):
        super().__init__()
        hidden_dim = int(dim * expansion_factor)
        self.group_size = group_size
        self.causal = causal
        self.shift_tokens = shift_tokens

        # positional embeddings
        self.rotary_pos_emb = rotary_pos_emb
        # dropout applied to the quadratic attention map
        self.dropout = nn.Dropout(dropout)

        # projections
        self.to_hidden = FFConvM(
            dim_in = dim,
            dim_out = hidden_dim,
            norm_klass = norm_klass,
            dropout = dropout,
        )
        self.to_qk = FFConvM(
            dim_in = dim,
            dim_out = query_key_dim,
            norm_klass = norm_klass,
            dropout = dropout,
        )

        self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)

        # v and u are the two halves of the hidden projection, so the gated
        # product fed to to_out has hidden_dim // 2 features. This equals the
        # former hard-coded dim * 2 when expansion_factor == 4 and keeps the
        # layer usable for other expansion factors.
        self.to_out = FFConvM(
            dim_in = hidden_dim // 2,
            dim_out = dim,
            norm_klass = norm_klass,
            dropout = dropout,
        )

        self.gateActivate = nn.Sigmoid()

    def forward(
        self,
        x,
        *,
        mask = None
    ):
        """Apply the gated attention layer.

        Args:
            x: tensor whose last two axes are (sequence, feature).
            mask: optional boolean tensor of shape (batch, sequence); positions
                set to False are excluded as attention sources.

        Returns:
            Tensor shaped like ``x`` (input plus the projected gated output).
        """
        normed_x = x

        if self.shift_tokens:
            # Token shift: delay the first half of the features by one step
            # along the sequence axis (zero enters at the first position).
            x_shift, x_pass = normed_x.chunk(2, dim = -1)
            x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
            normed_x = torch.cat((x_shift, x_pass), dim = -1)

        # initial projections
        v, u = self.to_hidden(normed_x).chunk(2, dim = -1)
        qk = self.to_qk(normed_x)

        # offset and scale
        quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)

        # The mask must be handed on explicitly; cal_attention defaults it to None.
        att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u, mask = mask)

        # Each attended stream gates the other value stream.
        out = (att_u * v) * self.gateActivate(att_v * u)
        return x + self.to_out(out)

    def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask = None):
        """Compute the shared attention output for both value streams.

        Index legend used in the einsum/rearrange patterns:
            b - batch
            n - sequence length (within groups)
            g - group dimension
            d - feature dimension (keys)
            e - feature dimension (values)
            i - sequence dimension (source)
            j - sequence dimension (target)

        Args:
            x: layer input; only its batch size, sequence length and device
                are read.
            quad_q, quad_k: queries/keys for the per-group quadratic attention.
            lin_q, lin_k: queries/keys for the linear attention.
            v, u: the two value streams, (batch, sequence, feature).
            mask: optional boolean (batch, sequence) tensor, False = masked out.

        Returns:
            An iterator yielding two tensors (attended ``v``, attended ``u``),
            each cropped back to the original sequence length.
        """
        b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size

        if exists(mask):
            # Masked keys must not contribute to the linear attention sums.
            lin_mask = rearrange(mask, '... -> ... 1')
            lin_k = lin_k.masked_fill(~lin_mask, 0.)

        # rotate queries and keys
        if exists(self.rotary_pos_emb):
            quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))

        # padding for groups
        padding = padding_to_multiple_of(n, g)

        if padding > 0:
            quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v, u))

            # Padded positions are always masked, even when no mask was given.
            mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
            mask = F.pad(mask, (0, padding), value = False)

        # group along sequence
        quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v, u))

        if exists(mask):
            mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)

        # calculate quadratic attention output (squared ReLU, within each group)
        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g

        attn = F.relu(sim) ** 2
        attn = self.dropout(attn)

        if exists(mask):
            attn = attn.masked_fill(~mask, 0.)

        if self.causal:
            causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
            attn = attn.masked_fill(causal_mask, 0.)

        quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
        quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)

        # calculate linear attention output
        if self.causal:
            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
            # exclusive cumulative sum along group dimension
            lin_kv = lin_kv.cumsum(dim = 1)
            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)
            lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)

            lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
            # exclusive cumulative sum along group dimension
            lin_ku = lin_ku.cumsum(dim = 1)
            lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value = 0.)
            lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
        else:
            # Non-causal: one key/value summary over all groups, scaled by the
            # unpadded sequence length.
            lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
            lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)

            lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
            lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)

        # fold back groups into full sequence, and excise out padding
        return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v+lin_out_v, quad_out_u+lin_out_u))