mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
add mossformer code
This commit is contained in:
parent
30f0c7ff29
commit
bce7248763
127
funasr/bin/ss_infer.py
Normal file
127
funasr/bin/ss_infer.py
Normal file
@ -0,0 +1,127 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- encoding: utf-8 -*-
|
||||
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from funasr.build_utils.build_model_from_file import build_model_from_file
|
||||
from funasr.torch_utils.device_funcs import to_device
|
||||
|
||||
|
||||
class SpeechSeparator:
|
||||
"""SpeechSeparator class
|
||||
|
||||
Examples:
|
||||
>>> import soundfile
|
||||
>>> speech_separator = MossFormer("ss_config.yml", "ss.pt")
|
||||
>>> audio, rate = soundfile.read("speech.wav")
|
||||
>>> separated_wavs = speech_separator(audio)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ss_infer_config: Union[Path, str] = None,
|
||||
ss_model_file: Union[Path, str] = None,
|
||||
device: str = "cpu",
|
||||
batch_size: int = 1,
|
||||
dtype: str = "float32",
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
# 1. Build ss model
|
||||
ss_model, ss_infer_args = build_model_from_file(
|
||||
ss_infer_config, ss_model_file, None, device, task_name="ss"
|
||||
)
|
||||
|
||||
logging.info("ss_model: {}".format(ss_model))
|
||||
logging.info("ss_infer_args: {}".format(ss_infer_args))
|
||||
|
||||
ss_model.to(dtype=getattr(torch, dtype)).eval()
|
||||
|
||||
self.ss_model = ss_model
|
||||
self.ss_infer_args = ss_infer_args
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.batch_size = batch_size
|
||||
|
||||
def decode(self, model, args, inputs, nsamples):
|
||||
decode_do_segment = False
|
||||
with torch.no_grad():
|
||||
out = []
|
||||
window = args.sample_rate * args.decode_window # decoding window length
|
||||
stride = int(window*0.75) # decoding stride if segmentation is used
|
||||
b, t = inputs.shape
|
||||
if t > window * args.one_time_decode_length:
|
||||
decode_do_segment = True # set segment decoding to true for very long sequence
|
||||
|
||||
if t < window:
|
||||
inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], window-t))], 1)
|
||||
elif t < window + stride:
|
||||
padding = window + stride - t
|
||||
inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], 1)
|
||||
else:
|
||||
if (t - window) % stride != 0:
|
||||
padding = t - (t-window)//stride * stride
|
||||
inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], 1)
|
||||
inputs = torch.from_numpy(np.float32(inputs))
|
||||
inputs = to_device(inputs, device=self.device)
|
||||
b, t = inputs.shape
|
||||
if decode_do_segment:
|
||||
outputs = np.zeros((args.num_spks, t))
|
||||
give_up_length = (window - stride)//2
|
||||
current_idx = 0
|
||||
while current_idx + window <= t:
|
||||
tmp_input = inputs[:, current_idx:current_idx+window]
|
||||
tmp_out_list = model(tmp_input,)
|
||||
for spk in range(args.num_spks):
|
||||
tmp_out_list[spk] = tmp_out_list[spk][0, :].cpu().numpy()
|
||||
if current_idx == 0:
|
||||
outputs[spk, current_idx:current_idx+window-give_up_length] = \
|
||||
tmp_out_list[spk][:-give_up_length]
|
||||
else:
|
||||
outputs[spk, current_idx+give_up_length:current_idx+window-give_up_length] = \
|
||||
tmp_out_list[spk][give_up_length:-give_up_length]
|
||||
current_idx += stride
|
||||
for spk in range(args.num_spks):
|
||||
out.append(outputs[spk, :])
|
||||
else:
|
||||
out_list = model(inputs)
|
||||
for spk in range(args.num_spks):
|
||||
out.append(out_list[spk][0, :].cpu().numpy())
|
||||
|
||||
max_abs = 0
|
||||
for spk in range(args.num_spks):
|
||||
if max_abs < max(abs(out[spk])):
|
||||
max_abs = max(abs(out[spk]))
|
||||
for spk in range(args.num_spks):
|
||||
out[spk] = out[spk][:nsamples]
|
||||
out[spk] = out[spk]/max_abs
|
||||
|
||||
return out
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
|
||||
) -> List[torch.Tensor]:
|
||||
"""Inference
|
||||
|
||||
Args:
|
||||
speech: Input speech data
|
||||
Returns:
|
||||
speech list: list of speech data
|
||||
|
||||
"""
|
||||
|
||||
out = self.decode(self.ss_model, self.ss_infer_args, speech, speech_lengths)
|
||||
|
||||
return out
|
||||
|
||||
253
funasr/bin/ss_inference_launch.py
Normal file
253
funasr/bin/ss_inference_launch.py
Normal file
@ -0,0 +1,253 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- encoding: utf-8 -*-
|
||||
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import soundfile as sf
|
||||
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
|
||||
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
|
||||
from funasr.utils import config_argparse
|
||||
from funasr.utils.cli_utils import get_commandline_args
|
||||
from funasr.utils.types import str2triple_str
|
||||
from funasr.bin.ss_infer import SpeechSeparator
|
||||
|
||||
|
||||
def inference_ss(
|
||||
batch_size: int,
|
||||
ngpu: int,
|
||||
log_level: Union[int, str],
|
||||
ss_infer_config: Optional[str],
|
||||
ss_model_file: Optional[str],
|
||||
output_dir: Optional[str] = None,
|
||||
dtype: str = "float32",
|
||||
seed: int = 0,
|
||||
num_workers: int = 1,
|
||||
num_spks: int = 2,
|
||||
sample_rate: int = 8000,
|
||||
param_dict: dict = None,
|
||||
**kwargs,
|
||||
):
|
||||
ncpu = kwargs.get("ncpu", 1)
|
||||
torch.set_num_threads(ncpu)
|
||||
if batch_size > 1:
|
||||
raise NotImplementedError("batch decoding is not implemented")
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||
)
|
||||
|
||||
if ngpu >= 1 and torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
batch_size = 1
|
||||
|
||||
# 1. Set random-seed
|
||||
set_all_random_seed(seed)
|
||||
|
||||
# 2. Build speech separator
|
||||
speech_separator_kwargs = dict(
|
||||
ss_infer_config=ss_infer_config,
|
||||
ss_model_file=ss_model_file,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
logging.info("speech_separator_kwargs: {}".format(speech_separator_kwargs))
|
||||
speech_separator = SpeechSeparator(**speech_separator_kwargs)
|
||||
|
||||
def _forward(
|
||||
data_path_and_name_and_type,
|
||||
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
|
||||
output_dir_v2: Optional[str] = None,
|
||||
fs: dict = None,
|
||||
param_dict: dict = None
|
||||
):
|
||||
# 3. Build data-iterator
|
||||
if data_path_and_name_and_type is None and raw_inputs is not None:
|
||||
if isinstance(raw_inputs, torch.Tensor):
|
||||
raw_inputs = raw_inputs.numpy()
|
||||
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
|
||||
loader = build_streaming_iterator(
|
||||
task_name="ss",
|
||||
preprocess_args=None,
|
||||
data_path_and_name_and_type=data_path_and_name_and_type,
|
||||
dtype=dtype,
|
||||
fs=fs,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
|
||||
# 4 .Start for-loop
|
||||
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
|
||||
if not os.path.exists(output_path):
|
||||
cmd = 'mkdir -p ' + output_path
|
||||
os.system(cmd)
|
||||
|
||||
for keys, batch in loader:
|
||||
assert isinstance(batch, dict), type(batch)
|
||||
assert all(isinstance(s, str) for s in keys), keys
|
||||
_bs = len(next(iter(batch.values())))
|
||||
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
|
||||
|
||||
# do speech separation
|
||||
logging.info('decoding: {}'.format(keys[0]))
|
||||
ss_results = speech_separator(**batch)
|
||||
|
||||
for spk in range(num_spks):
|
||||
sf.write(os.path.join(output_path, keys[0].replace('.wav', '_s'+str(spk+1)+'.wav')), ss_results[spk], sample_rate)
|
||||
torch.cuda.empty_cache()
|
||||
return ss_results
|
||||
|
||||
return _forward
|
||||
|
||||
|
||||
def inference_launch(mode, **kwargs):
|
||||
if mode == "mossformer":
|
||||
return inference_ss(**kwargs)
|
||||
else:
|
||||
logging.info("Unknown decoding mode: {}".format(mode))
|
||||
return None
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = config_argparse.ArgumentParser(
|
||||
description="Speech Separator Decoding",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
# Note(kamo): Use '_' instead of '-' as separator.
|
||||
# '-' is confusing if written in yaml.
|
||||
parser.add_argument(
|
||||
"--log_level",
|
||||
type=lambda x: x.upper(),
|
||||
default="INFO",
|
||||
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
|
||||
help="The verbose level of logging",
|
||||
)
|
||||
|
||||
parser.add_argument("--output_dir", type=str, required=True)
|
||||
parser.add_argument(
|
||||
"--ngpu",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of gpus. 0 indicates CPU mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--njob",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of jobs for each gpu",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpuid_list",
|
||||
type=str,
|
||||
default="2",
|
||||
help="The visible gpus",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
default="float32",
|
||||
choices=["float16", "float32", "float64"],
|
||||
help="Data type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_workers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of workers used for DataLoader",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("Input data related")
|
||||
group.add_argument(
|
||||
"--data_path_and_name_and_type",
|
||||
type=str2triple_str,
|
||||
required=True,
|
||||
action="append",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("The model configuration related")
|
||||
group.add_argument(
|
||||
"--ss_infer_config",
|
||||
type=str,
|
||||
help="SS infer configuration",
|
||||
)
|
||||
group.add_argument(
|
||||
"--ss_model_file",
|
||||
type=str,
|
||||
help="SS model parameter file",
|
||||
)
|
||||
group.add_argument(
|
||||
"--ss_train_config",
|
||||
type=str,
|
||||
help="SS training configuration",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("The inference configuration related")
|
||||
group.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The batch size for inference",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--num-spks', dest='num_spks', type=int, default=2)
|
||||
|
||||
parser.add_argument(
|
||||
'--one-time-decode-length', dest='one_time_decode_length', type=int,
|
||||
default=60, help='the max length (second) for one-time decoding')
|
||||
|
||||
parser.add_argument(
|
||||
'--decode-window', dest='decode_window', type=int,
|
||||
default=1, help='segmental decoding window length (second)')
|
||||
|
||||
parser.add_argument(
|
||||
'--sample-rate', dest='sample_rate', type=int, default='8000')
|
||||
return parser
|
||||
|
||||
|
||||
def main(cmd=None):
|
||||
print(get_commandline_args(), file=sys.stderr)
|
||||
parser = get_parser()
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
default="mossformer",
|
||||
help="The decoding mode",
|
||||
)
|
||||
args = parser.parse_args(cmd)
|
||||
kwargs = vars(args)
|
||||
kwargs.pop("config", None)
|
||||
|
||||
# set logging messages
|
||||
logging.basicConfig(
|
||||
level=args.log_level,
|
||||
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||
)
|
||||
logging.info("Decoding args: {}".format(kwargs))
|
||||
|
||||
# gpu setting
|
||||
if args.ngpu > 0:
|
||||
jobid = int(args.output_dir.split(".")[-1])
|
||||
gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
|
||||
|
||||
inference_pipeline = inference_launch(**kwargs)
|
||||
return inference_pipeline(kwargs["data_path_and_name_and_type"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -5,6 +5,7 @@ from funasr.build_utils.build_pretrain_model import build_pretrain_model
|
||||
from funasr.build_utils.build_punc_model import build_punc_model
|
||||
from funasr.build_utils.build_sv_model import build_sv_model
|
||||
from funasr.build_utils.build_vad_model import build_vad_model
|
||||
from funasr.build_utils.build_ss_model import build_ss_model
|
||||
|
||||
|
||||
def build_model(args):
|
||||
@ -22,6 +23,8 @@ def build_model(args):
|
||||
model = build_diar_model(args)
|
||||
elif args.task_name == "sv":
|
||||
model = build_sv_model(args)
|
||||
elif args.task_name == "ss":
|
||||
model = build_ss_model(args)
|
||||
else:
|
||||
raise NotImplementedError("Not supported task: {}".format(args.task_name))
|
||||
|
||||
|
||||
@ -11,6 +11,18 @@ from funasr.build_utils.build_model import build_model
|
||||
from funasr.models.base_model import FunASRModel
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint_path, use_cuda=1):
|
||||
if use_cuda:
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
else:
|
||||
checkpoint = torch.load(
|
||||
checkpoint_path, map_location=lambda storage, loc: storage)
|
||||
return checkpoint
|
||||
|
||||
def reload_ss_for_eval(model, checkpoint_path, use_cuda=False):
|
||||
checkpoint = load_checkpoint(checkpoint_path, use_cuda)
|
||||
model.load_state_dict(checkpoint['model'], strict=False)
|
||||
|
||||
def build_model_from_file(
|
||||
config_file: Union[Path, str] = None,
|
||||
model_file: Union[Path, str] = None,
|
||||
@ -70,6 +82,9 @@ def build_model_from_file(
|
||||
model.load_state_dict(model_dict)
|
||||
else:
|
||||
model_dict = torch.load(model_file, map_location=device)
|
||||
if task_name == 'ss':
|
||||
reload_ss_for_eval(model, model_file, use_cuda=True)
|
||||
logging.info("model is loaded from path: {}".format(model_file))
|
||||
if task_name == "diar" and mode == "sond":
|
||||
model_dict = fileter_model_dict(model_dict, model.state_dict())
|
||||
if task_name == "vad":
|
||||
|
||||
15
funasr/build_utils/build_ss_model.py
Normal file
15
funasr/build_utils/build_ss_model.py
Normal file
@ -0,0 +1,15 @@
|
||||
from funasr.models.e2e_ss import MossFormer
|
||||
|
||||
def build_ss_model(args):
|
||||
model = MossFormer(
|
||||
in_channels=args.encoder_embedding_dim,
|
||||
out_channels=args.mossformer_sequence_dim,
|
||||
num_blocks=args.num_mossformer_layer,
|
||||
kernel_size=args.encoder_kernel_size,
|
||||
norm=args.norm,
|
||||
num_spks=args.num_spks,
|
||||
skip_around_intra=args.skip_around_intra,
|
||||
use_global_pos_enc=args.use_global_pos_enc,
|
||||
max_length=args.max_length)
|
||||
|
||||
return model
|
||||
53
funasr/models/decoder/mossformer_decoder.py
Normal file
53
funasr/models/decoder/mossformer_decoder.py
Normal file
@ -0,0 +1,53 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class MossFormerDecoder(nn.ConvTranspose1d):
|
||||
"""A decoder layer that consists of ConvTranspose1d.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
kernel_size : int
|
||||
Length of filters.
|
||||
in_channels : int
|
||||
Number of input channels.
|
||||
out_channels : int
|
||||
Number of output channels.
|
||||
|
||||
|
||||
Example
|
||||
---------
|
||||
>>> x = torch.randn(2, 100, 1000)
|
||||
>>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1)
|
||||
>>> h = decoder(x)
|
||||
>>> h.shape
|
||||
torch.Size([2, 1003])
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MossFormerDecoder, self).__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
"""Return the decoded output.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
x : torch.Tensor
|
||||
Input tensor with dimensionality [B, N, L].
|
||||
where, B = Batchsize,
|
||||
N = number of filters
|
||||
L = time points
|
||||
"""
|
||||
|
||||
if x.dim() not in [2, 3]:
|
||||
raise RuntimeError(
|
||||
"{} accept 3/4D tensor as input".format(self.__name__)
|
||||
)
|
||||
x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
|
||||
|
||||
if torch.squeeze(x).dim() == 1:
|
||||
x = torch.squeeze(x, dim=1)
|
||||
else:
|
||||
x = torch.squeeze(x)
|
||||
return x
|
||||
|
||||
95
funasr/models/e2e_ss.py
Normal file
95
funasr/models/e2e_ss.py
Normal file
@ -0,0 +1,95 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import copy
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.models.encoder.mossformer_encoder import MossFormerEncoder, MossFormer_MaskNet
|
||||
from funasr.models.decoder.mossformer_decoder import MossFormerDecoder
|
||||
|
||||
|
||||
class MossFormer(FunASRModel):
|
||||
"""The MossFormer model for separating input mixed speech into different speaker's speech.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
in_channels : int
|
||||
Number of channels at the output of the encoder.
|
||||
out_channels : int
|
||||
Number of channels that would be inputted to the intra and inter blocks.
|
||||
num_blocks : int
|
||||
Number of layers of Dual Computation Block.
|
||||
norm : str
|
||||
Normalization type.
|
||||
num_spks : int
|
||||
Number of sources (speakers).
|
||||
skip_around_intra : bool
|
||||
Skip connection around intra.
|
||||
use_global_pos_enc : bool
|
||||
Global positional encodings.
|
||||
max_length : int
|
||||
Maximum sequence length.
|
||||
kernel_size: int
|
||||
Encoder and decoder kernel size
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=512,
|
||||
out_channels=512,
|
||||
num_blocks=24,
|
||||
kernel_size=16,
|
||||
norm="ln",
|
||||
num_spks=2,
|
||||
skip_around_intra=True,
|
||||
use_global_pos_enc=True,
|
||||
max_length=20000,
|
||||
):
|
||||
super(MossFormer, self).__init__()
|
||||
self.num_spks = num_spks
|
||||
# Encoding
|
||||
self.enc = MossFormerEncoder(kernel_size=kernel_size, out_channels=in_channels, in_channels=1)
|
||||
|
||||
##Compute Mask
|
||||
self.mask_net = MossFormer_MaskNet(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
num_blocks=num_blocks,
|
||||
norm=norm,
|
||||
num_spks=num_spks,
|
||||
skip_around_intra=skip_around_intra,
|
||||
use_global_pos_enc=use_global_pos_enc,
|
||||
max_length=max_length,
|
||||
)
|
||||
self.dec = MossFormerDecoder(
|
||||
in_channels=out_channels,
|
||||
out_channels=1,
|
||||
kernel_size=kernel_size,
|
||||
stride = kernel_size//2,
|
||||
bias=False
|
||||
)
|
||||
def forward(self, input):
|
||||
x = self.enc(input)
|
||||
mask = self.mask_net(x)
|
||||
x = torch.stack([x] * self.num_spks)
|
||||
sep_x = x * mask
|
||||
|
||||
# Decoding
|
||||
est_source = torch.cat(
|
||||
[
|
||||
self.dec(sep_x[i]).unsqueeze(-1)
|
||||
for i in range(self.num_spks)
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
T_origin = input.size(1)
|
||||
T_est = est_source.size(1)
|
||||
if T_origin > T_est:
|
||||
est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
|
||||
else:
|
||||
est_source = est_source[:, :T_origin, :]
|
||||
|
||||
out = []
|
||||
for spk in range(self.num_spks):
|
||||
out.append(est_source[:,:,spk])
|
||||
return out
|
||||
417
funasr/models/encoder/mossformer_encoder.py
Normal file
417
funasr/models/encoder/mossformer_encoder.py
Normal file
@ -0,0 +1,417 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from rotary_embedding_torch import RotaryEmbedding
|
||||
from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
|
||||
from funasr.modules.embedding import ScaledSinuEmbedding
|
||||
from funasr.modules.mossformer import FLASH_ShareA_FFConvM
|
||||
|
||||
|
||||
def select_norm(norm, dim, shape):
|
||||
"""Just a wrapper to select the normalization type.
|
||||
"""
|
||||
|
||||
if norm == "gln":
|
||||
return GlobalLayerNorm(dim, shape, elementwise_affine=True)
|
||||
if norm == "cln":
|
||||
return CumulativeLayerNorm(dim, elementwise_affine=True)
|
||||
if norm == "ln":
|
||||
return nn.GroupNorm(1, dim, eps=1e-8)
|
||||
else:
|
||||
return nn.BatchNorm1d(dim)
|
||||
|
||||
|
||||
class MossformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
depth,
|
||||
group_size = 256,
|
||||
query_key_dim = 128,
|
||||
expansion_factor = 4.,
|
||||
causal = False,
|
||||
attn_dropout = 0.1,
|
||||
norm_type = 'scalenorm',
|
||||
shift_tokens = True
|
||||
):
|
||||
super().__init__()
|
||||
assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'
|
||||
|
||||
if norm_type == 'scalenorm':
|
||||
norm_klass = ScaleNorm
|
||||
elif norm_type == 'layernorm':
|
||||
norm_klass = nn.LayerNorm
|
||||
|
||||
self.group_size = group_size
|
||||
|
||||
rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))
|
||||
# max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
|
||||
self.layers = nn.ModuleList([FLASH_ShareA_FFConvM(dim = dim, group_size = group_size, query_key_dim = query_key_dim, expansion_factor = expansion_factor, causal = causal, dropout = attn_dropout, rotary_pos_emb = rotary_pos_emb, norm_klass = norm_klass, shift_tokens = shift_tokens) for _ in range(depth)])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
*,
|
||||
mask = None
|
||||
):
|
||||
ii = 0
|
||||
for flash in self.layers:
|
||||
x = flash(x, mask = mask)
|
||||
ii = ii + 1
|
||||
return x
|
||||
|
||||
|
||||
class MossFormer_MaskNet(nn.Module):
|
||||
"""The MossFormer module for computing output masks.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
in_channels : int
|
||||
Number of channels at the output of the encoder.
|
||||
out_channels : int
|
||||
Number of channels that would be inputted to the intra and inter blocks.
|
||||
num_blocks : int
|
||||
Number of layers of Dual Computation Block.
|
||||
norm : str
|
||||
Normalization type.
|
||||
num_spks : int
|
||||
Number of sources (speakers).
|
||||
skip_around_intra : bool
|
||||
Skip connection around intra.
|
||||
use_global_pos_enc : bool
|
||||
Global positional encodings.
|
||||
max_length : int
|
||||
Maximum sequence length.
|
||||
|
||||
Example
|
||||
---------
|
||||
>>> mossformer_block = MossFormerM(1, 64, 8)
|
||||
>>> mossformer_masknet = MossFormer_MaskNet(64, 64, intra_block, num_spks=2)
|
||||
>>> x = torch.randn(10, 64, 2000)
|
||||
>>> x = mossformer_masknet(x)
|
||||
>>> x.shape
|
||||
torch.Size([2, 10, 64, 2000])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_blocks=24,
|
||||
norm="ln",
|
||||
num_spks=2,
|
||||
skip_around_intra=True,
|
||||
use_global_pos_enc=True,
|
||||
max_length=20000,
|
||||
):
|
||||
super(MossFormer_MaskNet, self).__init__()
|
||||
self.num_spks = num_spks
|
||||
self.num_blocks = num_blocks
|
||||
self.norm = select_norm(norm, in_channels, 3)
|
||||
self.conv1d_encoder = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
||||
self.use_global_pos_enc = use_global_pos_enc
|
||||
|
||||
if self.use_global_pos_enc:
|
||||
self.pos_enc = ScaledSinuEmbedding(out_channels)
|
||||
|
||||
self.mdl = Computation_Block(
|
||||
num_blocks,
|
||||
out_channels,
|
||||
norm,
|
||||
skip_around_intra=skip_around_intra,
|
||||
)
|
||||
|
||||
self.conv1d_out = nn.Conv1d(
|
||||
out_channels, out_channels * num_spks, kernel_size=1
|
||||
)
|
||||
self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False)
|
||||
self.prelu = nn.PReLU()
|
||||
self.activation = nn.ReLU()
|
||||
# gated output layer
|
||||
self.output = nn.Sequential(
|
||||
nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
|
||||
)
|
||||
self.output_gate = nn.Sequential(
|
||||
nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""Returns the output tensor.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
x : torch.Tensor
|
||||
Input tensor of dimension [B, N, S].
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : torch.Tensor
|
||||
Output tensor of dimension [spks, B, N, S]
|
||||
where, spks = Number of speakers
|
||||
B = Batchsize,
|
||||
N = number of filters
|
||||
S = the number of time frames
|
||||
"""
|
||||
|
||||
# before each line we indicate the shape after executing the line
|
||||
|
||||
# [B, N, L]
|
||||
x = self.norm(x)
|
||||
|
||||
# [B, N, L]
|
||||
x = self.conv1d_encoder(x)
|
||||
if self.use_global_pos_enc:
|
||||
#x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
|
||||
# x.size(1) ** 0.5)
|
||||
base = x
|
||||
x = x.transpose(1, -1)
|
||||
emb = self.pos_enc(x)
|
||||
emb = emb.transpose(0, -1)
|
||||
#print('base: {}, emb: {}'.format(base.shape, emb.shape))
|
||||
x = base + emb
|
||||
|
||||
|
||||
# [B, N, S]
|
||||
#for i in range(self.num_modules):
|
||||
# x = self.dual_mdl[i](x)
|
||||
x = self.mdl(x)
|
||||
x = self.prelu(x)
|
||||
|
||||
# [B, N*spks, S]
|
||||
x = self.conv1d_out(x)
|
||||
B, _, S = x.shape
|
||||
|
||||
# [B*spks, N, S]
|
||||
x = x.view(B * self.num_spks, -1, S)
|
||||
|
||||
# [B*spks, N, S]
|
||||
x = self.output(x) * self.output_gate(x)
|
||||
|
||||
# [B*spks, N, S]
|
||||
x = self.conv1_decoder(x)
|
||||
|
||||
# [B, spks, N, S]
|
||||
_, N, L = x.shape
|
||||
x = x.view(B, self.num_spks, N, L)
|
||||
x = self.activation(x)
|
||||
|
||||
# [spks, B, N, S]
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class MossFormerEncoder(nn.Module):
|
||||
"""Convolutional Encoder Layer.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
kernel_size : int
|
||||
Length of filters.
|
||||
in_channels : int
|
||||
Number of input channels.
|
||||
out_channels : int
|
||||
Number of output channels.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> x = torch.randn(2, 1000)
|
||||
>>> encoder = Encoder(kernel_size=4, out_channels=64)
|
||||
>>> h = encoder(x)
|
||||
>>> h.shape
|
||||
torch.Size([2, 64, 499])
|
||||
"""
|
||||
|
||||
def __init__(self, kernel_size=2, out_channels=64, in_channels=1):
|
||||
super(MossFormerEncoder, self).__init__()
|
||||
self.conv1d = nn.Conv1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=kernel_size // 2,
|
||||
groups=1,
|
||||
bias=False,
|
||||
)
|
||||
self.in_channels = in_channels
|
||||
|
||||
def forward(self, x):
|
||||
"""Return the encoded output.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
x : torch.Tensor
|
||||
Input tensor with dimensionality [B, L].
|
||||
Return
|
||||
------
|
||||
x : torch.Tensor
|
||||
Encoded tensor with dimensionality [B, N, T_out].
|
||||
|
||||
where B = Batchsize
|
||||
L = Number of timepoints
|
||||
N = Number of filters
|
||||
T_out = Number of timepoints at the output of the encoder
|
||||
"""
|
||||
# B x L -> B x 1 x L
|
||||
if self.in_channels == 1:
|
||||
x = torch.unsqueeze(x, dim=1)
|
||||
# B x 1 x L -> B x N x T_out
|
||||
x = self.conv1d(x)
|
||||
x = F.relu(x)
|
||||
|
||||
return x
|
||||
|
||||
class MossFormerM(nn.Module):
|
||||
"""This class implements the transformer encoder.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
num_blocks : int
|
||||
Number of mossformer blocks to include.
|
||||
d_model : int
|
||||
The dimension of the input embedding.
|
||||
attn_dropout : float
|
||||
Dropout for the self-attention (Optional).
|
||||
group_size: int
|
||||
the chunk size
|
||||
query_key_dim: int
|
||||
the attention vector dimension
|
||||
expansion_factor: int
|
||||
the expansion factor for the linear projection in conv module
|
||||
causal: bool
|
||||
true for causal / false for non causal
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> import torch
|
||||
>>> x = torch.rand((8, 60, 512))
|
||||
>>> net = TransformerEncoder_MossFormerM(num_blocks=8, d_model=512)
|
||||
>>> output, _ = net(x)
|
||||
>>> output.shape
|
||||
torch.Size([8, 60, 512])
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
num_blocks,
|
||||
d_model=None,
|
||||
causal=False,
|
||||
group_size = 256,
|
||||
query_key_dim = 128,
|
||||
expansion_factor = 4.,
|
||||
attn_dropout = 0.1
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.mossformerM = MossformerBlock(
|
||||
dim=d_model,
|
||||
depth=num_blocks,
|
||||
group_size=group_size,
|
||||
query_key_dim=query_key_dim,
|
||||
expansion_factor=expansion_factor,
|
||||
causal=causal,
|
||||
attn_dropout=attn_dropout
|
||||
)
|
||||
self.norm = nn.LayerNorm(d_model, eps=1e-6)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
):
|
||||
"""
|
||||
Arguments
|
||||
----------
|
||||
src : torch.Tensor
|
||||
Tensor shape [B, L, N],
|
||||
where, B = Batchsize,
|
||||
L = time points
|
||||
N = number of filters
|
||||
The sequence to the encoder layer (required).
|
||||
src_mask : tensor
|
||||
The mask for the src sequence (optional).
|
||||
src_key_padding_mask : tensor
|
||||
The mask for the src keys per batch (optional).
|
||||
"""
|
||||
output = self.mossformerM(src)
|
||||
output = self.norm(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class Computation_Block(nn.Module):
|
||||
"""Computation block for dual-path processing.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
out_channels : int
|
||||
Dimensionality of inter/intra model.
|
||||
norm : str
|
||||
Normalization type.
|
||||
skip_around_intra : bool
|
||||
Skip connection around the intra layer.
|
||||
|
||||
Example
|
||||
---------
|
||||
>>> comp_block = Computation_Block(64)
|
||||
>>> x = torch.randn(10, 64, 100)
|
||||
>>> x = comp_block(x)
|
||||
>>> x.shape
|
||||
torch.Size([10, 64, 100])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_blocks,
|
||||
out_channels,
|
||||
norm="ln",
|
||||
skip_around_intra=True,
|
||||
):
|
||||
super(Computation_Block, self).__init__()
|
||||
|
||||
##MossFormer2M: MossFormer with recurrence
|
||||
#self.intra_mdl = MossFormer2M(num_blocks=num_blocks, d_model=out_channels)
|
||||
##MossFormerM: the orignal MossFormer
|
||||
self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
|
||||
self.skip_around_intra = skip_around_intra
|
||||
|
||||
# Norm
|
||||
self.norm = norm
|
||||
if norm is not None:
|
||||
self.intra_norm = select_norm(norm, out_channels, 3)
|
||||
|
||||
def forward(self, x):
|
||||
"""Returns the output tensor.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
x : torch.Tensor
|
||||
Input tensor of dimension [B, N, S].
|
||||
|
||||
|
||||
Return
|
||||
---------
|
||||
out: torch.Tensor
|
||||
Output tensor of dimension [B, N, S].
|
||||
where, B = Batchsize,
|
||||
N = number of filters
|
||||
S = sequence time index
|
||||
"""
|
||||
B, N, S = x.shape
|
||||
# intra RNN
|
||||
# [B, S, N]
|
||||
intra = x.permute(0, 2, 1).contiguous() #.view(B, S, N)
|
||||
|
||||
intra = self.intra_mdl(intra)
|
||||
|
||||
# [B, N, S]
|
||||
intra = intra.permute(0, 2, 1).contiguous()
|
||||
if self.norm is not None:
|
||||
intra = self.intra_norm(intra)
|
||||
|
||||
# [B, N, S]
|
||||
if self.skip_around_intra:
|
||||
intra = intra + x
|
||||
|
||||
out = intra
|
||||
return out
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import einsum
|
||||
|
||||
def _pre_hook(
|
||||
state_dict,
|
||||
@ -510,3 +511,19 @@ class StreamingRelPositionalEncoding(torch.nn.Module):
|
||||
pos_enc = self.dropout(pos_enc)
|
||||
|
||||
return pos_enc
|
||||
|
||||
|
||||
class ScaledSinuEmbedding(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.scale = torch.nn.Parameter(torch.ones(1,))
|
||||
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer('inv_freq', inv_freq)
|
||||
|
||||
def forward(self, x):
|
||||
n, device = x.shape[1], x.device
|
||||
t = torch.arange(n, device = device).type_as(self.inv_freq)
|
||||
sinu = einsum('i , j -> i j', t, self.inv_freq)
|
||||
emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
|
||||
return emb * self.scale
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
"""Layer normalization module."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
@ -40,3 +41,137 @@ class LayerNorm(torch.nn.LayerNorm):
|
||||
.forward(x.transpose(self.dim, -1))
|
||||
.transpose(self.dim, -1)
|
||||
)
|
||||
|
||||
|
||||
class GlobalLayerNorm(nn.Module):
|
||||
"""Calculate Global Layer Normalization.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
dim : (int or list or torch.Size)
|
||||
Input shape from an expected input of size.
|
||||
eps : float
|
||||
A value added to the denominator for numerical stability.
|
||||
elementwise_affine : bool
|
||||
A boolean value that when set to True,
|
||||
this module has learnable per-element affine parameters
|
||||
initialized to ones (for weights) and zeros (for biases).
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> x = torch.randn(5, 10, 20)
|
||||
>>> GLN = GlobalLayerNorm(10, 3)
|
||||
>>> x_norm = GLN(x)
|
||||
"""
|
||||
|
||||
def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
|
||||
super(GlobalLayerNorm, self).__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
|
||||
if self.elementwise_affine:
|
||||
if shape == 3:
|
||||
self.weight = nn.Parameter(torch.ones(self.dim, 1))
|
||||
self.bias = nn.Parameter(torch.zeros(self.dim, 1))
|
||||
if shape == 4:
|
||||
self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
|
||||
self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, x):
|
||||
"""Returns the normalized tensor.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
x : torch.Tensor
|
||||
Tensor of size [N, C, K, S] or [N, C, L].
|
||||
"""
|
||||
# x = N x C x K x S or N x C x L
|
||||
# N x 1 x 1
|
||||
# cln: mean,var N x 1 x K x S
|
||||
# gln: mean,var N x 1 x 1
|
||||
if x.dim() == 3:
|
||||
mean = torch.mean(x, (1, 2), keepdim=True)
|
||||
var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
|
||||
if self.elementwise_affine:
|
||||
x = (
|
||||
self.weight * (x - mean) / torch.sqrt(var + self.eps)
|
||||
+ self.bias
|
||||
)
|
||||
else:
|
||||
x = (x - mean) / torch.sqrt(var + self.eps)
|
||||
|
||||
if x.dim() == 4:
|
||||
mean = torch.mean(x, (1, 2, 3), keepdim=True)
|
||||
var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
|
||||
if self.elementwise_affine:
|
||||
x = (
|
||||
self.weight * (x - mean) / torch.sqrt(var + self.eps)
|
||||
+ self.bias
|
||||
)
|
||||
else:
|
||||
x = (x - mean) / torch.sqrt(var + self.eps)
|
||||
return x
|
||||
|
||||
|
||||
class CumulativeLayerNorm(nn.LayerNorm):
|
||||
"""Calculate Cumulative Layer Normalization.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
dim : int
|
||||
Dimension that you want to normalize.
|
||||
elementwise_affine : True
|
||||
Learnable per-element affine parameters.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> x = torch.randn(5, 10, 20)
|
||||
>>> CLN = CumulativeLayerNorm(10)
|
||||
>>> x_norm = CLN(x)
|
||||
"""
|
||||
|
||||
def __init__(self, dim, elementwise_affine=True):
|
||||
super(CumulativeLayerNorm, self).__init__(
|
||||
dim, elementwise_affine=elementwise_affine, eps=1e-8
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""Returns the normalized tensor.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
x : torch.Tensor
|
||||
Tensor size [N, C, K, S] or [N, C, L]
|
||||
"""
|
||||
# x: N x C x K x S or N x C x L
|
||||
# N x K x S x C
|
||||
if x.dim() == 4:
|
||||
x = x.permute(0, 2, 3, 1).contiguous()
|
||||
# N x K x S x C == only channel norm
|
||||
x = super().forward(x)
|
||||
# N x C x K x S
|
||||
x = x.permute(0, 3, 1, 2).contiguous()
|
||||
if x.dim() == 3:
|
||||
x = torch.transpose(x, 1, 2)
|
||||
# N x L x C == only channel norm
|
||||
x = super().forward(x)
|
||||
# N x C x L
|
||||
x = torch.transpose(x, 1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class ScaleNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.scale = dim ** -0.5
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1))
|
||||
|
||||
def forward(self, x):
|
||||
norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
|
||||
return x / norm.clamp(min = self.eps) * self.g
|
||||
|
||||
|
||||
307
funasr/modules/mossformer.py
Normal file
307
funasr/modules/mossformer.py
Normal file
@ -0,0 +1,307 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def identity(t, *args, **kwargs):
|
||||
return t
|
||||
|
||||
def append_dims(x, num_dims):
|
||||
if num_dims <= 0:
|
||||
return x
|
||||
return x.view(*x.shape, *((1,) * num_dims))
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def padding_to_multiple_of(n, mult):
|
||||
remainder = n % mult
|
||||
if remainder == 0:
|
||||
return 0
|
||||
return mult - remainder
|
||||
|
||||
|
||||
class Transpose(nn.Module):
|
||||
""" Wrapper class of torch.transpose() for Sequential module. """
|
||||
def __init__(self, shape: tuple):
|
||||
super(Transpose, self).__init__()
|
||||
self.shape = shape
|
||||
|
||||
def forward(self, x):
|
||||
return x.transpose(*self.shape)
|
||||
|
||||
|
||||
class DepthwiseConv1d(nn.Module):
|
||||
"""
|
||||
When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
|
||||
this operation is termed in literature as depthwise convolution.
|
||||
Args:
|
||||
in_channels (int): Number of channels in the input
|
||||
out_channels (int): Number of channels produced by the convolution
|
||||
kernel_size (int or tuple): Size of the convolving kernel
|
||||
stride (int, optional): Stride of the convolution. Default: 1
|
||||
padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
|
||||
bias (bool, optional): If True, adds a learnable bias to the output. Default: True
|
||||
Inputs: inputs
|
||||
- **inputs** (batch, in_channels, time): Tensor containing input vector
|
||||
Returns: outputs
|
||||
- **outputs** (batch, out_channels, time): Tensor produces by depthwise 1-D convolution.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
padding: int = 0,
|
||||
bias: bool = False,
|
||||
) -> None:
|
||||
super(DepthwiseConv1d, self).__init__()
|
||||
assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
|
||||
self.conv = nn.Conv1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
groups=in_channels,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.conv(inputs)
|
||||
|
||||
|
||||
class ConvModule(nn.Module):
|
||||
"""
|
||||
Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
|
||||
This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
|
||||
to aid training deep models.
|
||||
Args:
|
||||
in_channels (int): Number of channels in the input
|
||||
kernel_size (int or tuple, optional): Size of the convolving kernel Default: 31
|
||||
dropout_p (float, optional): probability of dropout
|
||||
Inputs: inputs
|
||||
inputs (batch, time, dim): Tensor contains input sequences
|
||||
Outputs: outputs
|
||||
outputs (batch, time, dim): Tensor produces by conformer convolution module.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
kernel_size: int = 17,
|
||||
expansion_factor: int = 2,
|
||||
dropout_p: float = 0.1,
|
||||
) -> None:
|
||||
super(ConvModule, self).__init__()
|
||||
assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
|
||||
assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
|
||||
|
||||
self.sequential = nn.Sequential(
|
||||
Transpose(shape=(1, 2)),
|
||||
DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
return inputs + self.sequential(inputs).transpose(1, 2)
|
||||
|
||||
|
||||
class OffsetScale(nn.Module):
|
||||
def __init__(self, dim, heads = 1):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.ones(heads, dim))
|
||||
self.beta = nn.Parameter(torch.zeros(heads, dim))
|
||||
nn.init.normal_(self.gamma, std = 0.02)
|
||||
|
||||
def forward(self, x):
|
||||
out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
|
||||
return out.unbind(dim = -2)
|
||||
|
||||
|
||||
class FFConvM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in,
|
||||
dim_out,
|
||||
norm_klass = nn.LayerNorm,
|
||||
dropout = 0.1
|
||||
):
|
||||
super().__init__()
|
||||
self.mdl = nn.Sequential(
|
||||
norm_klass(dim_in),
|
||||
nn.Linear(dim_in, dim_out),
|
||||
nn.SiLU(),
|
||||
ConvModule(dim_out),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
):
|
||||
output = self.mdl(x)
|
||||
return output
|
||||
|
||||
|
||||
class FLASH_ShareA_FFConvM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
group_size = 256,
|
||||
query_key_dim = 128,
|
||||
expansion_factor = 1.,
|
||||
causal = False,
|
||||
dropout = 0.1,
|
||||
rotary_pos_emb = None,
|
||||
norm_klass = nn.LayerNorm,
|
||||
shift_tokens = True
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(dim * expansion_factor)
|
||||
self.group_size = group_size
|
||||
self.causal = causal
|
||||
self.shift_tokens = shift_tokens
|
||||
|
||||
# positional embeddings
|
||||
self.rotary_pos_emb = rotary_pos_emb
|
||||
# norm
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
# projections
|
||||
|
||||
self.to_hidden = FFConvM(
|
||||
dim_in = dim,
|
||||
dim_out = hidden_dim,
|
||||
norm_klass = norm_klass,
|
||||
dropout = dropout,
|
||||
)
|
||||
self.to_qk = FFConvM(
|
||||
dim_in = dim,
|
||||
dim_out = query_key_dim,
|
||||
norm_klass = norm_klass,
|
||||
dropout = dropout,
|
||||
)
|
||||
|
||||
self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)
|
||||
|
||||
self.to_out = FFConvM(
|
||||
dim_in = dim*2,
|
||||
dim_out = dim,
|
||||
norm_klass = norm_klass,
|
||||
dropout = dropout,
|
||||
)
|
||||
|
||||
self.gateActivate=nn.Sigmoid()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
*,
|
||||
mask = None
|
||||
):
|
||||
|
||||
"""
|
||||
b - batch
|
||||
n - sequence length (within groups)
|
||||
g - group dimension
|
||||
d - feature dimension (keys)
|
||||
e - feature dimension (values)
|
||||
i - sequence dimension (source)
|
||||
j - sequence dimension (target)
|
||||
"""
|
||||
|
||||
normed_x = x
|
||||
|
||||
# do token shift - a great, costless trick from an independent AI researcher in Shenzhen
|
||||
residual = x
|
||||
|
||||
if self.shift_tokens:
|
||||
x_shift, x_pass = normed_x.chunk(2, dim = -1)
|
||||
x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
|
||||
normed_x = torch.cat((x_shift, x_pass), dim = -1)
|
||||
|
||||
# initial projections
|
||||
|
||||
v, u = self.to_hidden(normed_x).chunk(2, dim = -1)
|
||||
qk = self.to_qk(normed_x)
|
||||
|
||||
# offset and scale
|
||||
quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
|
||||
att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u)
|
||||
out = (att_u*v ) * self.gateActivate(att_v*u)
|
||||
x = x + self.to_out(out)
|
||||
return x
|
||||
|
||||
def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask = None):
|
||||
b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
|
||||
|
||||
if exists(mask):
|
||||
lin_mask = rearrange(mask, '... -> ... 1')
|
||||
lin_k = lin_k.masked_fill(~lin_mask, 0.)
|
||||
|
||||
# rotate queries and keys
|
||||
|
||||
if exists(self.rotary_pos_emb):
|
||||
quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))
|
||||
|
||||
# padding for groups
|
||||
|
||||
padding = padding_to_multiple_of(n, g)
|
||||
|
||||
if padding > 0:
|
||||
quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v, u))
|
||||
|
||||
mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
|
||||
mask = F.pad(mask, (0, padding), value = False)
|
||||
|
||||
# group along sequence
|
||||
|
||||
quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v, u))
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)
|
||||
|
||||
# calculate quadratic attention output
|
||||
|
||||
sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
|
||||
|
||||
attn = F.relu(sim) ** 2
|
||||
attn = self.dropout(attn)
|
||||
|
||||
if exists(mask):
|
||||
attn = attn.masked_fill(~mask, 0.)
|
||||
|
||||
if self.causal:
|
||||
causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
|
||||
attn = attn.masked_fill(causal_mask, 0.)
|
||||
|
||||
quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
|
||||
quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)
|
||||
|
||||
# calculate linear attention output
|
||||
|
||||
if self.causal:
|
||||
lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
|
||||
# exclusive cumulative sum along group dimension
|
||||
lin_kv = lin_kv.cumsum(dim = 1)
|
||||
lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)
|
||||
lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
|
||||
|
||||
lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
|
||||
# exclusive cumulative sum along group dimension
|
||||
lin_ku = lin_ku.cumsum(dim = 1)
|
||||
lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value = 0.)
|
||||
lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
|
||||
else:
|
||||
lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
|
||||
lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
|
||||
|
||||
lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
|
||||
lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)
|
||||
|
||||
# fold back groups into full sequence, and excise out padding
|
||||
return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v+lin_out_v, quad_out_u+lin_out_u))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user