Merge branch 'dev_gzf_deepspeed' of http://gitlab.alibaba-inc.com/zhifu.gzf/FunASR into dev_gzf_deepspeed

This commit is contained in:
dcaaaa 2024-07-10 17:43:15 +08:00
commit 8551c7f419
9 changed files with 2748 additions and 620 deletions

View File

@ -224,18 +224,21 @@ class AutoModel:
# init_param
init_param = kwargs.get("init_param", None)
if init_param is not None:
if os.path.exists(init_param):
logging.info(f"Loading pretrained params from {init_param}")
load_pretrained_model(
model=model,
path=init_param,
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
oss_bucket=kwargs.get("oss_bucket", None),
scope_map=kwargs.get("scope_map", []),
excludes=kwargs.get("excludes", None),
)
else:
print(f"error, init_param does not exist!: {init_param}")
if isinstance(init_param, str):
init_param = init_param.split(",")
for i, init_param_i in enumerate(init_param):
if os.path.exists(init_param_i):
logging.info(f"Loading pretrained params from ckpt-{i}: {init_param_i}")
load_pretrained_model(
model=model,
path=init_param_i,
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
oss_bucket=kwargs.get("oss_bucket", None),
scope_map=kwargs.get("scope_map", []),
excludes=kwargs.get("excludes", None),
)
else:
print(f"error, init_param from ckpt-{i} does not exist!: {init_param_i}")
# fp16
if kwargs.get("fp16", False):

View File

@ -610,6 +610,8 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset):
fake_token_len_i = 0
fbank_beg_i = -1
fbank_lens_i = []
speech = []
speech_lengths = []
for k, sub_str in enumerate(splits):
if not sub_str.startswith("<|startofspeech|>"):
sub_token = self.tokenizer.encode(sub_str)
@ -688,9 +690,11 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset):
input_ids += source_ids + target_ids
labels += source_mask + target_ids
fbank.append(speech[0, :, :])
fbank_mask += fbank_mask_i
fbank_lens.append(speech_lengths)
if len(speech) > 0:
fbank.append(speech[0, :, :])
fbank_lens.append(speech_lengths)
if badcase_flag:
continue
@ -706,8 +710,6 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset):
fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32)
output = {
"speech": fbank,
"speech_lengths": fbank_lens,
"fbank_mask": fbank_mask,
"fbank_beg": fbank_beg,
"fake_token_len": fake_token_len,
@ -719,6 +721,10 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset):
codec_len = torch.tensor(codec_len, dtype=torch.int32)
output["codec"] = codec
output["codec_len"] = codec_len
if len(fbank) > 0:
output["speech"] = fbank
output["speech_lengths"] = fbank_lens
break
return output

View File

@ -55,7 +55,7 @@ class OpenAIIndexDSJsonl(torch.utils.data.Dataset): # torch.utils.data.Dataset
text_length = data_dict.get("text_length", 0)
if speech_length > self.max_source_length:
logging.info(
"speech_length: {speech_length} > {self.max_source_length}, drop it"
f"speech_length: {speech_length} > {self.max_source_length}, drop it"
)
continue
if text_length > self.max_target_length:

View File

@ -59,10 +59,19 @@ def download_from_ms(**kwargs):
elif os.path.exists(os.path.join(model_or_path, "config.yaml")):
config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
kwargs = OmegaConf.merge(config, kwargs)
init_param = os.path.join(model_or_path, "model.pt")
if "init_param" not in kwargs or not os.path.exists(kwargs["init_param"]):
kwargs["init_param"] = init_param
assert os.path.exists(kwargs["init_param"]), "init_param does not exist"
init_param = kwargs.get("init_param", "")
if not os.path.exists(init_param):
init_param_new = init_param
if isinstance(init_param, str):
init_param = init_param.split(",")
for init_param_i in init_param:
if not os.path.exists(init_param_i):
print(f"init_param: {init_param_i}, does not exist")
init_param_i = os.path.join(model_or_path, "model.pt")
init_param_new = f"{init_param_new},{init_param_i}"
kwargs["init_param"] = init_param_new
# assert os.path.exists(kwargs["init_param"]), "init_param does not exist"
if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
if os.path.exists(os.path.join(model_or_path, "tokens.json")):

View File

@ -0,0 +1,82 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Label smoothing module."""
import torch
from torch import nn
from funasr.models.transformer.utils.nets_utils import make_pad_mask
class LabelSmoothingLoss(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    :param int size: the number of classes
    :param int padding_idx: target id that marks padding and is ignored
    :param float smoothing: smoothing rate (0.0 means the conventional CE)
    :param bool normalize_length: if True, normalize by the number of
        non-padding tokens instead of the batch size
    :param torch.nn.Module criterion: elementwise loss to be smoothed
    :param bool reduction: if False, return the unreduced elementwise loss
    """

    def __init__(
        self,
        size,
        padding_idx,
        smoothing,
        normalize_length=False,
        criterion=nn.KLDivLoss(reduction="none"),
        reduction=True,
    ):
        """Construct an LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = criterion
        self.padding_idx = padding_idx
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing
        self.size = size
        # Kept for interface compatibility with the original; never assigned.
        self.true_dist = None
        self.normalize_length = normalize_length
        self.reduction = reduction

    def forward(self, x, target):
        """Compute the smoothed loss between x and target.

        :param torch.Tensor x: prediction (batch, seqlen, class)
        :param torch.Tensor target:
            target signal masked with self.padding_idx (batch, seqlen)
        :return: scalar loss (or elementwise loss when reduction=False)
        :rtype torch.Tensor
        """
        assert x.size(2) == self.size
        n_batch = x.size(0)
        logits = x.reshape(-1, self.size)
        flat_target = target.reshape(-1)
        with torch.no_grad():
            # Smoothed one-hot target: spread the smoothing mass uniformly
            # over the other (size - 1) classes, confidence on the label.
            smooth_dist = logits.clone()
            smooth_dist.fill_(self.smoothing / (self.size - 1))
            pad_positions = flat_target == self.padding_idx  # (B*T,)
            total = len(flat_target) - pad_positions.sum().item()
            safe_target = flat_target.masked_fill(pad_positions, 0)  # avoid -1 index
            smooth_dist.scatter_(1, safe_target.unsqueeze(1), self.confidence)
        elementwise = self.criterion(torch.log_softmax(logits, dim=1), smooth_dist)
        if not self.reduction:
            return elementwise
        denom = total if self.normalize_length else n_batch
        return elementwise.masked_fill(pad_positions.unsqueeze(1), 0).sum() / denom
class SequenceBinaryCrossEntropy(nn.Module):
    """Binary cross-entropy over padded sequences.

    Padding positions (derived from ``lengths``) are zeroed out before
    summation; the sum is normalized by the number of valid frames when
    ``normalize_length`` is True, otherwise by the batch size.
    """

    def __init__(self, normalize_length=False, criterion=nn.BCEWithLogitsLoss(reduction="none")):
        super().__init__()
        self.normalize_length = normalize_length
        self.criterion = criterion

    def forward(self, pred, label, lengths):
        """Return the masked, normalized BCE loss.

        :param torch.Tensor pred: logits (batch, maxlen, ...)
        :param torch.Tensor label: binary targets broadcastable to pred
        :param torch.Tensor lengths: valid length of each sequence (batch,)
        """
        padding = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
        elementwise = self.criterion(pred, label)
        if self.normalize_length:
            denom = (~padding).sum()
        else:
            denom = pred.shape[0]
        return elementwise.masked_fill(padding.unsqueeze(-1), 0).sum() / denom

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,751 @@
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Transformer encoder definition."""
from typing import List
from typing import Optional
from typing import Tuple
import torch
from torch import nn
import logging
from funasr.models.transformer.attention import (
MultiHeadedAttention,
RelPositionMultiHeadedAttention, # noqa: H301
LegacyRelPositionMultiHeadedAttention, # noqa: H301
)
from funasr.models.transformer.embedding import (
PositionalEncoding, # noqa: H301
ScaledPositionalEncoding, # noqa: H301
RelPositionalEncoding, # noqa: H301
LegacyRelPositionalEncoding, # noqa: H301
)
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.utils.nets_utils import rename_state_dict
from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution
from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D
from funasr.models.transformer.utils.lightconv import LightweightConvolution
from funasr.models.transformer.utils.lightconv2d import LightweightConvolution2D
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
from funasr.models.transformer.utils.subsampling import TooShortUttError
from funasr.models.transformer.utils.subsampling import check_short_utt
class EncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip residual computation and return input
            as-is with given probability.
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)  # norm around the self-attention sublayer
        self.norm2 = LayerNorm(size)  # norm around the feed-forward sublayer
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            # Projects concat(x, att(x)) back down to `size`.
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x, mask, cache=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor or tuple): Input tensor (#batch, time, size), or a
                (tensor, pos_emb) tuple when relative positional attention is used.
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        if isinstance(x, tuple):
            x, pos_emb = x[0], x[1]
        else:
            x, pos_emb = x, None
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
        if skip_layer:
            # Layer skipped: return the (cache-extended) input unchanged.
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            if pos_emb is not None:
                return (x, pos_emb), mask
            return x, mask
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        if cache is None:
            x_q = x
        else:
            # Incremental decoding: only the newest frame acts as the query;
            # all frames (cached + new) serve as keys/values.
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]
        if pos_emb is not None:
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)
        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm1(x)
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)
        if cache is not None:
            x = torch.cat([cache, x], dim=1)
        if pos_emb is not None:
            return (x, pos_emb), mask
        return x, mask
class TransformerEncoder(nn.Module):
    """Transformer encoder module.

    Args:
        input_size: input dim
        output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of decoder blocks
        dropout_rate: dropout rate
        attention_dropout_rate: dropout rate in attention
        positional_dropout_rate: dropout rate after adding positional encoding
        input_layer: input layer type
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to use layer_norm before the first block
        concat_after: whether to concat attention layer's input and output
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied.
            i.e. x -> x + att(x)
        positionwise_layer_type: linear of conv1d
        positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
        padding_idx: padding_idx for input_layer=embed
        interctc_layer_idx: 1-based block indices after which intermediate CTC
            outputs are collected
        interctc_use_conditioning: feed intermediate CTC posteriors back into
            the encoder (self-conditioned CTC); `conditioning_layer` must be
            set externally before forward
        causal_mode: "None" for full attention, "causal" to additionally apply
            a lower-triangular attention mask
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        causal_mode: str = "None",
    ):
        super().__init__()
        self._output_size = output_size
        self.causal_mode = causal_mode
        # Input front-end: projection / conv subsampling / embedding lookup.
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            # No front-end; a Linear is still needed if the dims differ.
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                MultiHeadedAttention(attention_heads, output_size, attention_dropout_rate),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        # Set from outside (model-level code) when interctc_use_conditioning.
        self.conditioning_layer = None

    def output_size(self) -> int:
        """Return the encoder output dimension."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc=None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor and encode.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
            ctc: CTC module; only used when interctc_use_conditioning is True.

        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if self.causal_mode == "causal":
            # Combine the padding mask with a lower-triangular mask so each
            # frame only attends to itself and earlier frames.
            # NOTE(review): this is applied before the subsampling front-end;
            # conv2d-style embeds then reduce the time axis, so the causal
            # mask is only meaningful for resolution-preserving embeds —
            # confirm intended usage.
            tt = xs_pad.shape[1]
            pos_idx = torch.arange(tt)
            causal_mask = torch.less_equal(pos_idx.unsqueeze(0), pos_idx.unsqueeze(1))
            masks = masks * causal_mask.unsqueeze(0).to(xs_pad.device)
        elif self.causal_mode != "None":
            # Was silently ignored before; make misconfiguration explicit.
            raise ValueError("unknown causal_mode: " + str(self.causal_mode))
        if self.embed is None:
            pass  # feature dim already matches output_size
        elif isinstance(
            self.embed,
            (Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, Conv2dSubsampling8),
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            xs_pad, masks = self.encoders(xs_pad, masks)
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs_pad, masks = encoder_layer(xs_pad, masks)
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        # Self-conditioned CTC: feed posteriors back in.
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
def _pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Rename legacy checkpoint keys to the current layout on state-dict load.

    Mirrors the espnet renames of ``input_layer.`` -> ``embed.``
    (commit 21d70286) and ``norm.`` -> ``after_norm.`` (commit 3d422f6d)
    so old checkpoints keep loading.
    """
    renames = (
        ("input_layer.", "embed."),
        ("norm.", "after_norm."),
    )
    for old_key, new_key in renames:
        rename_state_dict(prefix + old_key, prefix + new_key, state_dict)
class TransformerEncoder_s0(nn.Module):
    """Transformer encoder module.

    Args:
        idim (int): Input dimension.
        attention_dim (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        conv_wshare (int): The number of kernel of convolution. Only used in
            selfattention_layer_type == "lightconv*" or "dynamiconv*".
        conv_kernel_length (Union[int, str]): Kernel size str of convolution
            (e.g. 71_71_71_71_71_71). Only used in selfattention_layer_type
            == "lightconv*" or "dynamiconv*".
        conv_usebias (bool): Whether to use bias in convolution. Only used in
            selfattention_layer_type == "lightconv*" or "dynamiconv*".
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of decoder blocks.
        dropout_rate (float): Dropout rate.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        attention_dropout_rate (float): Dropout rate in attention.
        input_layer (Union[str, torch.nn.Module]): Input layer type.
        pos_enc_class (torch.nn.Module): Positional encoding module class.
            `PositionalEncoding` or `ScaledPositionalEncoding`
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        selfattention_layer_type (str): Encoder attention layer type.
        padding_idx (int): Padding idx for input_layer=embed.
        stochastic_depth_rate (float): Maximum probability to skip the encoder layer.
        intermediate_layers (Union[List[int], None]): indices of intermediate CTC layer.
            indices start from 1.
            if not None, intermediate outputs are returned (which changes return type
            signature.)
        ctc_softmax: callable mapping encoder output to CTC posteriors; when
            given, enables self-conditioned CTC.
        conditioning_layer_dim: input dim of the conditioning projection
            (CTC vocabulary size); required when ctc_softmax is given.
        zero_triu (bool): passed to RelPositionMultiHeadedAttention.
    """

    def __init__(
        self,
        idim,
        attention_dim=256,
        attention_heads=4,
        conv_wshare=4,
        conv_kernel_length="11",
        conv_usebias=False,
        linear_units=2048,
        num_blocks=6,
        dropout_rate=0.1,
        positional_dropout_rate=0.1,
        attention_dropout_rate=0.0,
        input_layer="conv2d",
        pos_enc_class=PositionalEncoding,
        normalize_before=True,
        concat_after=False,
        positionwise_layer_type="linear",
        positionwise_conv_kernel_size=1,
        selfattention_layer_type="selfattn",
        padding_idx=-1,
        stochastic_depth_rate=0.0,
        intermediate_layers=None,
        ctc_softmax=None,
        conditioning_layer_dim=None,
        zero_triu: bool = False,
    ):
        """Construct an Encoder object."""
        super(TransformerEncoder_s0, self).__init__()
        # Remap legacy checkpoint keys when loading old state dicts.
        self._register_load_state_dict_pre_hook(_pre_hook)
        self.conv_subsampling_factor = 1
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(idim, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(idim, attention_dim, dropout_rate)
            self.conv_subsampling_factor = 4
        elif input_layer == "conv2d-scaled-pos-enc":
            self.embed = Conv2dSubsampling(
                idim,
                attention_dim,
                dropout_rate,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
            self.conv_subsampling_factor = 4
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(idim, attention_dim, dropout_rate)
            self.conv_subsampling_factor = 6
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(idim, attention_dim, dropout_rate)
            self.conv_subsampling_factor = 8
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(attention_dim, positional_dropout_rate)
            )
        elif input_layer == "none":
            # Pass-through: caller supplies already-embedded features.
            self.embed = torch.nn.Identity()
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        positionwise_layer, positionwise_layer_args = self.get_positionwise_layer(
            positionwise_layer_type,
            attention_dim,
            linear_units,
            dropout_rate,
            positionwise_conv_kernel_size,
        )
        if selfattention_layer_type == "selfattn":
            logging.info("encoder self-attention layer type = self-attention")
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = [
                (
                    attention_heads,
                    attention_dim,
                    attention_dropout_rate,
                )
            ] * num_blocks
        elif selfattention_layer_type == "legacy_rel_selfattn":
            logging.info("encoder self-attention layer type = legacy relative self-attention")
            assert pos_enc_class == LegacyRelPositionalEncoding
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = [
                (
                    attention_heads,
                    attention_dim,
                    attention_dropout_rate,
                )
            ] * num_blocks
            logging.warning("Using legacy_rel_selfattn and it will be deprecated in the future.")
        elif selfattention_layer_type == "rel_selfattn":
            logging.info("encoder self-attention layer type = relative self-attention")
            assert pos_enc_class == RelPositionalEncoding
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = [
                (
                    attention_heads,
                    attention_dim,
                    attention_dropout_rate,
                    zero_triu,
                )
            ] * num_blocks
        elif selfattention_layer_type in ("lightconv", "lightconv2d", "dynamicconv", "dynamicconv2d"):
            # The four convolution variants share the same per-layer argument
            # layout and differ only in the layer class and log message, so
            # build them from one table instead of four duplicated branches.
            conv_layers = {
                "lightconv": (LightweightConvolution, "lightweight convolution"),
                "lightconv2d": (
                    LightweightConvolution2D,
                    "lightweight convolution 2-dimensional",
                ),
                "dynamicconv": (DynamicConvolution, "dynamic convolution"),
                "dynamicconv2d": (
                    DynamicConvolution2D,
                    "dynamic convolution 2-dimensional",
                ),
            }
            encoder_selfattn_layer, layer_name = conv_layers[selfattention_layer_type]
            logging.info("encoder self-attention layer type = " + layer_name)
            encoder_selfattn_layer_args = [
                (
                    conv_wshare,
                    attention_dim,
                    attention_dropout_rate,
                    int(conv_kernel_length.split("_")[lnum]),
                    False,
                    conv_usebias,
                )
                for lnum in range(num_blocks)
            ]
        else:
            raise NotImplementedError(selfattention_layer_type)
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                attention_dim,
                encoder_selfattn_layer(*encoder_selfattn_layer_args[lnum]),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
                # Linearly increasing skip probability over depth.
                stochastic_depth_rate * float(1 + lnum) / num_blocks,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        self.intermediate_layers = intermediate_layers
        self.use_conditioning = ctc_softmax is not None
        if self.use_conditioning:
            self.ctc_softmax = ctc_softmax
            self.conditioning_layer = torch.nn.Linear(conditioning_layer_dim, attention_dim)

    def get_positionwise_layer(
        self,
        positionwise_layer_type="linear",
        attention_dim=256,
        linear_units=2048,
        dropout_rate=0.1,
        positionwise_conv_kernel_size=1,
    ):
        """Define positionwise layer.

        Returns:
            (class, args-tuple) used to instantiate one feed-forward module
            per encoder block.
        """
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (attention_dim, linear_units, dropout_rate)
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
        return positionwise_layer, positionwise_layer_args

    def forward(self, xs, masks):
        """Encode input sequence.

        Args:
            xs (torch.Tensor): Input tensor (#batch, time, idim).
            masks (torch.Tensor): Mask tensor (#batch, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time, attention_dim).
            torch.Tensor: Mask tensor (#batch, time).
        """
        if isinstance(
            self.embed,
            (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8),
        ):
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)
        if self.intermediate_layers is None:
            xs, masks = self.encoders(xs, masks)
        else:
            intermediate_outputs = []
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs, masks = encoder_layer(xs, masks)
                if (
                    self.intermediate_layers is not None
                    and layer_idx + 1 in self.intermediate_layers
                ):
                    # With rel-pos attention, xs is a (tensor, pos_emb) tuple.
                    if isinstance(xs, tuple):
                        encoder_output = xs[0]
                    else:
                        encoder_output = xs
                    # intermediate branches also require normalization.
                    if self.normalize_before:
                        encoder_output = self.after_norm(encoder_output)
                    intermediate_outputs.append(encoder_output)
                    if self.use_conditioning:
                        # Self-conditioned CTC: feed posteriors back in.
                        intermediate_result = self.ctc_softmax(encoder_output)
                        xs = xs + self.conditioning_layer(intermediate_result)
        if isinstance(xs, tuple):
            xs = xs[0]
        if self.normalize_before:
            xs = self.after_norm(xs)
        if self.intermediate_layers is not None:
            return xs, masks, intermediate_outputs
        return xs, masks

    def forward_one_step(self, xs, masks, cache=None):
        """Encode input frame.

        Args:
            xs (torch.Tensor): Input tensor.
            masks (torch.Tensor): Mask tensor.
            cache (List[torch.Tensor]): List of cache tensors.

        Returns:
            torch.Tensor: Output tensor.
            torch.Tensor: Mask tensor.
            List[torch.Tensor]: List of new cache tensors.
        """
        if isinstance(self.embed, (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8)):
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)
        if cache is None:
            cache = [None for _ in range(len(self.encoders))]
        new_cache = []
        for c, e in zip(cache, self.encoders):
            xs, masks = e(xs, masks, cache=c)
            if isinstance(xs, tuple):
                new_cache.append(xs[0])
            else:
                new_cache.append(xs)
        if isinstance(xs, tuple):
            xs = xs[0]
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, masks, new_cache

View File

@ -0,0 +1,345 @@
from typing import Any
from typing import List
from typing import Tuple
import torch
import torch.nn as nn
from funasr.models.transformer.embedding import PositionalEncoding, ScaledPositionalEncoding, RelPositionalEncoding, LegacyRelPositionalEncoding
from funasr.models.llm_asr.transformer_encoder import TransformerEncoder_s0 as Encoder
from funasr.models.transformer.utils.mask import subsequent_mask
from funasr.models.transformer.utils.nets_utils import make_pad_mask
import logging
from distutils.version import LooseVersion
from contextlib import contextmanager
# AMP compatibility shim: torch >= 1.6.0 provides autocast; older versions
# get a no-op context manager with the same call signature.
# NOTE(review): distutils.version.LooseVersion is deprecated (PEP 632);
# consider parsing torch.__version__ directly once distutils is dropped.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield
class TransformerEmbedLM(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        pos_enc: str = None,
        embed_unit: int = 128,
        att_unit: int = 256,
        head: int = 2,
        unit: int = 1024,
        layer: int = 4,
        dropout_rate: float = 0.5,
        attention_dropout_rate: float = 0.0,
        pe_type: str = "split",
        bidirectional_inputs: bool = False,
        text_vocab_size: int = 4000,
        input_aug_conf: dict = None,
        output_aug_conf: dict = None,
        codec_groups: int = 4,
        selfattention_layer_type: str = "selfattn",
        input_normalize: bool = False,
        use_decoder: bool = True,
        encoder_type: str = "transformer",
        **kwargs
    ):
        """Build a Transformer LM over pre-embedded inputs.

        Args:
            vocab_size: size of the output projection (``self.decoder``).
            pos_enc: positional-encoding type: "sinusoidal"/"abs_pos",
                "scaled_abs_pos", "rel_pos", "legacy_rel_pos", or None
                (identity, i.e. no positional encoding).
            embed_unit: dimension of the incoming embeddings.
            att_unit: attention (model) dimension of the encoder.
            head / unit / layer: encoder heads, FFN units, and block count.
            dropout_rate / attention_dropout_rate: dropout settings.
            pe_type: "split" applies the positional encoding outside the
                encoder (``self.pos_enc_func``) followed by a Linear to
                att_unit; otherwise the encoder's own "linear" input layer
                is used.
            bidirectional_inputs: stored flag; presumably controls masking in
                forward — confirm against the (unseen) forward body.
            text_vocab_size / codec_groups: stored for use by callers/forward.
            input_aug_conf / output_aug_conf: optional SpecAug configs applied
                to inputs/outputs.
            selfattention_layer_type: encoder attention variant; must match
                pos_enc (asserted below for the relative variants).
            input_normalize: if True, apply utterance-level MVN to inputs.
            use_decoder: if False, no output projection is created.
            encoder_type: "transformer" (default) or "llama" (unimplemented).
        """
        super().__init__()
        # Resolve the positional-encoding class; "sinusoidal" and "abs_pos"
        # are aliases for the same absolute encoding.
        if pos_enc == "sinusoidal":
            pos_enc_class = PositionalEncoding
        elif pos_enc == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc == "rel_pos":
            # Relative encodings only make sense with the matching attention.
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc == "legacy_rel_pos":
            assert selfattention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        elif pos_enc is None:
            def pos_enc_class(*args, **kwargs):
                return nn.Sequential()  # identity

        else:
            raise ValueError(f"unknown pos-enc option: {pos_enc}")
        self.embed_unit = embed_unit
        self.pe_type = pe_type
        self.encoder_type = encoder_type
        if encoder_type == "llama":
            raise NotImplementedError("llama encoder has not been implemented")
            # from cosyvoice.nets.encoder.llama_encoder import LlamaEncoder
            # # set causal to false, using mask to control causal mode.
            # self.encoder = LlamaEncoder(
            #     input_size=embed_unit,
            #     output_size=att_unit,
            #     attention_heads=head,
            #     num_blocks=layer,
            #     dropout_rate=dropout_rate,
            #     attention_dropout_rate=attention_dropout_rate,
            #     causal=False,
            #     linear_units=unit,
            # )
        else:
            # "none" input layer when pe_type == "split": positional encoding
            # and projection to att_unit are handled below instead.
            self.encoder = Encoder(
                idim=embed_unit,
                attention_dim=att_unit,
                attention_heads=head,
                linear_units=unit,
                num_blocks=layer,
                dropout_rate=dropout_rate,
                positional_dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                input_layer="none" if pe_type == "split" else "linear",
                pos_enc_class=pos_enc_class,
                selfattention_layer_type=selfattention_layer_type,
            )
        if use_decoder:
            self.decoder = nn.Linear(att_unit, vocab_size)
        else:
            self.decoder = None
        self.attn_unit = att_unit
        self.pos_enc_func = None
        if pe_type == "split":
            assert pos_enc == "sinusoidal" or pos_enc == "abs_pos" or pos_enc == "scaled_abs_pos", \
                "Different positional embedding for inputs and outputs " \
                "only supports sinusoidal, abs_pos and scaled_abs_pos."
            self.pos_enc_func = pos_enc_class(embed_unit, 0.1)
            self.input_layer = torch.nn.Linear(embed_unit, att_unit)
        self.bidirectional_inputs = bidirectional_inputs
        self.text_vocab_size = text_vocab_size
        self.codec_groups = codec_groups
        # Optional SpecAug on the input embeddings.
        self.input_aug = None
        if input_aug_conf is not None:
            from funasr.models.specaug.specaug import SpecAug

            self.input_aug = SpecAug(**input_aug_conf)
        # Optional SpecAug on the output side.
        self.output_aug = None
        if output_aug_conf is not None:
            from funasr.models.specaug.specaug import SpecAug

            self.output_aug = SpecAug(**output_aug_conf)
        self.normalize = None
        if input_normalize:
            from funasr.models.normalize.utterance_mvn import UtteranceMVN

            self.normalize = UtteranceMVN()
        # Config for clac_first_package_mask; None disables first-pack masking.
        self.first_pack_mask_conf: dict = kwargs.get("first_pack_mask_conf", None)
def output_size(self):
return self.attn_unit
def _target_mask(self, lengths):
ys_mask = ~make_pad_mask(lengths)
m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
return ys_mask.unsqueeze(-2) & m
def clac_first_package_mask(self, mask, input_lengths, cond_lengths):
device = mask.device
mask_type = self.first_pack_mask_conf.get("mask_type", "first_pack")
fp_token_len = self.first_pack_mask_conf["fp_token_len"]
fp_text_len = self.first_pack_mask_conf["fp_text_len"]
# NOTE: fp_text_len excluding sos, xvec, only including text
# NOTE: cond_lengths including sos, xvec and text
if mask_type == "streaming":
for i, (seq_len, cond_len) in enumerate(zip(input_lengths, cond_lengths)):
# 1 for task_id
token_len = seq_len - cond_len - 1
if token_len > 0:
target_text_len = torch.ceil(torch.arange(1, token_len+1, device=device) / fp_token_len) * fp_text_len
# 2 for sos and xvec, M -> M x 1
target_text_len = torch.minimum(target_text_len + 2, cond_len).unsqueeze(1)
# 1 x N
pos_range = torch.arange(0, cond_len, device=device).unsqueeze(0)
# M x N
text_mask = pos_range < target_text_len
# 1 for <task_id>
mask[i, cond_len+1:seq_len, :cond_len] = mask[i, cond_len+1:seq_len, :cond_len] * text_mask
else:
for i, (seq_len, cond_len) in enumerate(zip(input_lengths, cond_lengths)):
mask_token_end = min(cond_len+1+fp_token_len, seq_len)
mask[i, cond_len+1:mask_token_end, fp_text_len+2:cond_len] = 0
return mask
    def forward(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor,
        cond_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the conditional LM encoder over embedded input sequences.

        Builds a causal mask (optionally with first-package restrictions and a
        bidirectional condition prefix), applies per-utterance SpecAug and
        normalization to the condition span, applies split positional encodings
        when ``pe_type == "split"``, then performs one encoder pass.

        Args:
            input (torch.Tensor): Input embeddings. (batch, len, dim).
                NOTE: modified in place by augmentation / positional encoding.
            input_lengths (torch.Tensor): length of input. (batch,)
            cond_lengths (torch.Tensor): length of conditions (including sos, excluding taskid). (batch,)

        Returns:
            Tuple of (decoder logits, encoder hidden states); when
            ``self.decoder`` is None the hidden states are returned twice.
        """
        # Padding + causal mask over the full sequence.
        mask = self._target_mask(input_lengths).to(input.device)
        if self.first_pack_mask_conf is not None:
            mask = self.clac_first_package_mask(mask, input_lengths, cond_lengths)
        if self.bidirectional_inputs:
            # Let the condition prefix attend to itself bidirectionally.
            for i, length in enumerate(cond_lengths):
                mask[i, :length, :length] = True
        pos_emb = None
        if self.pe_type == "split":
            # 2*len-1 slots: relative positional embeddings for both segments.
            pos_emb = torch.zeros((input.shape[0], input.shape[1]*2-1, self.attn_unit)).to(input)
        kk = self.codec_groups
        # with torch.no_grad():
        # Augmentation/normalization is done with autocast disabled (full precision).
        with autocast(False):
            for i, length in enumerate(cond_lengths):
                # perform specaug for each frame including multi-group.
                # Skip position 0 (sos); fold the kk codec groups into the feature dim.
                raw_feat = input[i:i + 1, 1:length].clone()
                bb, tt, dd = raw_feat.shape
                raw_feat = raw_feat.reshape(bb, tt // kk, kk, dd).reshape(bb, tt // kk, kk * dd)
                if self.input_aug is not None and self.training:
                    raw_feat = self.input_aug(raw_feat, (cond_lengths[i:i+1] - 1) // kk)[0]
                if self.normalize is not None:
                    raw_feat = self.normalize(raw_feat, None)[0]
                # Unfold back to per-frame layout and write into the condition span.
                input[i:i + 1, 1:length] = raw_feat.reshape(bb, tt//kk, kk, dd).reshape(bb, tt, dd)
                if self.output_aug is not None and self.training:
                    # Augment the target span (after <task_id> at position `length`).
                    raw_feat = input[i:i + 1, length+1:].clone()
                    aug_feat = self.output_aug(raw_feat, input_lengths[i:i+1] - length - 2)[0]
                    input[i:i + 1, length + 1:] = aug_feat
                # add positional encoding
                if self.pe_type == "split" and self.pos_enc_func is not None:
                    # Condition and target segments each restart their positions.
                    posed_input = self.pos_enc_func(input[i:i + 1, :length].clone())
                    if isinstance(posed_input, tuple):
                        # Relative PE variants return (x, pos_emb).
                        pos_emb[i:i+1, :length*2-1] = posed_input[1]
                        posed_input = posed_input[0]
                    input[i:i + 1, :length] = posed_input
                    posed_output = self.pos_enc_func(input[i:i + 1, length + 1:].clone())
                    if isinstance(posed_output, tuple):
                        pos_emb[i:i+1, length*2: length*2+posed_output[1].shape[1]] = posed_output[1]
                        posed_output = posed_output[0]
                    input[i:i + 1, length + 1:] = posed_output
        if self.pe_type == "split":
            input = self.input_layer(input)
            if isinstance(self.pos_enc_func, (RelPositionalEncoding, LegacyRelPositionalEncoding)):
                # Relative-PE encoders take (x, pos_emb) tuples.
                input = (input, pos_emb)
        # logging.info(f"shapes {input.shape} {mask.shape} {input_lengths}")
        h, _ = self.encoder(input, mask)
        if self.decoder is None:
            return h, h
        y = self.decoder(h)
        return y, h
def init_state(self, x: torch.Tensor):
return None
    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 2D torch.float prefix embeddings.
            state: Scorer state for prefix tokens (layer-wise encoder cache).
            x (torch.Tensor): encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (vocab_size)
                and next state for ys — here ``(cache, last_hidden_state)``
                when a decoder head is present.
        """
        # this implementation is much faster than the blow!!
        # Plain lower-triangular causal mask over the current prefix length.
        mask = torch.tril(torch.ones((1, y.shape[0], y.shape[0]), device=y.device)).to(torch.bool)
        y_emb = y.unsqueeze(0).to(x.device)
        # lengths = y_emb.new_full([1], dtype=torch.long, fill_value=y_emb.size(1))
        # mask = self._target_mask(lengths).to(y_emb.device)
        # x includes <sos>, feat, <task_id>
        input_length = x.shape[0] - 1
        if self.bidirectional_inputs:
            # Condition prefix attends to itself bidirectionally, mirroring forward().
            mask[:1, :input_length, :input_length] = True
        # if self.first_pack_mask_conf is not None:
        #     mask = self.clac_first_package_mask(
        #         mask,
        #         torch.tensor([y.shape[0]], device=y.device),
        #         torch.tensor([input_length], device=y.device),
        #     )
        if self.pe_type == "split" and self.pos_enc_func is not None:
            # Separate positional encodings for the condition and target spans,
            # matching the training-time layout in forward().
            pos_emb = torch.zeros((y_emb.shape[0], y_emb.shape[1], self.attn_unit)).to(y_emb)
            posed_input = self.pos_enc_func(y_emb[:1, :input_length])
            if isinstance(posed_input, tuple):
                pos_emb[:1, :input_length] = posed_input[1]
                posed_input = posed_input[0]
            y_emb[:1, :input_length] = posed_input
            posed_output = self.pos_enc_func(y_emb[:1, input_length + 1:])
            if isinstance(posed_output, tuple):
                pos_emb[:1, input_length + 1:] = posed_output[1]
                posed_output = posed_output[0]
            y_emb[:1, input_length + 1:] = posed_output
        if self.pe_type == "split":
            y_emb = self.input_layer(y_emb)
            if isinstance(self.pos_enc_func, (RelPositionalEncoding, LegacyRelPositionalEncoding)):
                y_emb = (y_emb, pos_emb)
        # Single incremental step; `cache` carries the per-layer key/value state.
        lm_hidden_states, _, cache = self.encoder.forward_one_step(
            y_emb, mask, cache=state
        )
        if self.decoder is None:
            return lm_hidden_states[:, -1], cache
        # Truncate logits to the text vocabulary before normalizing.
        h = self.decoder(lm_hidden_states[:, -1])[:, :self.text_vocab_size]
        logp = h.log_softmax(dim=-1).squeeze(0)
        # return logp, cache
        return logp, (cache, lm_hidden_states[:, -1])
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, vocab_size)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]
        # batch decoding
        # NOTE(review): `self.embed` is not defined anywhere in this class's visible
        # code, and `_target_mask` expects a lengths tensor but receives token ids
        # here. This method looks inherited from a TransformerLM and may be
        # dead/broken — verify before relying on it. Also, unlike score(), this
        # path assumes `self.decoder` is not None.
        h, _, states = self.encoder.forward_one_step(
            self.embed(ys), self._target_mask(ys), cache=batch_state
        )
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)
        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list

View File

@ -11,7 +11,7 @@ import pdb
def load_pretrained_model(
path: str,
path,
model: torch.nn.Module,
ignore_init_mismatch: bool = True,
map_location: str = "cpu",
@ -100,3 +100,30 @@ def load_pretrained_model(
flag = obj.load_state_dict(dst_state, strict=True)
logging.info(f"Loading ckpt: {path}, status: {flag}")
# def load_pretrained_model(
# path,
# model: torch.nn.Module,
# ignore_init_mismatch: bool = True,
# map_location: str = "cpu",
# oss_bucket=None,
# scope_map=[],
# excludes=None,
# **kwargs,
# ):
# if isinstance(path, str):
# path = path.split(",")
#
# for i, path_i in enumerate(path):
# logging.info(f"Loading ckpt-{i}: {path_i}")
# _load_pretrained_model(
# path_i,
# model=model,
# ignore_init_mismatch=ignore_init_mismatch,
# map_location=map_location,
# oss_bucket=oss_bucket,
# scope_map=scope_map,
# excludes=excludes,
# **kwargs,
# )