mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge branch 'dev_gzf_deepspeed' of http://gitlab.alibaba-inc.com/zhifu.gzf/FunASR into dev_gzf_deepspeed
This commit is contained in:
commit
8551c7f419
@ -224,18 +224,21 @@ class AutoModel:
|
||||
# init_param
|
||||
init_param = kwargs.get("init_param", None)
|
||||
if init_param is not None:
|
||||
if os.path.exists(init_param):
|
||||
logging.info(f"Loading pretrained params from {init_param}")
|
||||
load_pretrained_model(
|
||||
model=model,
|
||||
path=init_param,
|
||||
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
|
||||
oss_bucket=kwargs.get("oss_bucket", None),
|
||||
scope_map=kwargs.get("scope_map", []),
|
||||
excludes=kwargs.get("excludes", None),
|
||||
)
|
||||
else:
|
||||
print(f"error, init_param does not exist!: {init_param}")
|
||||
if isinstance(init_param, str):
|
||||
init_param = init_param.split(",")
|
||||
for i, init_param_i in enumerate(init_param):
|
||||
if os.path.exists(init_param_i):
|
||||
logging.info(f"Loading pretrained params from ckpt-{i}: {init_param_i}")
|
||||
load_pretrained_model(
|
||||
model=model,
|
||||
path=init_param_i,
|
||||
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
|
||||
oss_bucket=kwargs.get("oss_bucket", None),
|
||||
scope_map=kwargs.get("scope_map", []),
|
||||
excludes=kwargs.get("excludes", None),
|
||||
)
|
||||
else:
|
||||
print(f"error, init_param from ckpt-{i} does not exist!: {init_param_i}")
|
||||
|
||||
# fp16
|
||||
if kwargs.get("fp16", False):
|
||||
|
||||
@ -610,6 +610,8 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset):
|
||||
fake_token_len_i = 0
|
||||
fbank_beg_i = -1
|
||||
fbank_lens_i = []
|
||||
speech = []
|
||||
speech_lengths = []
|
||||
for k, sub_str in enumerate(splits):
|
||||
if not sub_str.startswith("<|startofspeech|>"):
|
||||
sub_token = self.tokenizer.encode(sub_str)
|
||||
@ -688,9 +690,11 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset):
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
fbank.append(speech[0, :, :])
|
||||
|
||||
fbank_mask += fbank_mask_i
|
||||
fbank_lens.append(speech_lengths)
|
||||
if len(speech) > 0:
|
||||
fbank.append(speech[0, :, :])
|
||||
fbank_lens.append(speech_lengths)
|
||||
|
||||
if badcase_flag:
|
||||
continue
|
||||
@ -706,8 +710,6 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset):
|
||||
fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32)
|
||||
|
||||
output = {
|
||||
"speech": fbank,
|
||||
"speech_lengths": fbank_lens,
|
||||
"fbank_mask": fbank_mask,
|
||||
"fbank_beg": fbank_beg,
|
||||
"fake_token_len": fake_token_len,
|
||||
@ -719,6 +721,10 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset):
|
||||
codec_len = torch.tensor(codec_len, dtype=torch.int32)
|
||||
output["codec"] = codec
|
||||
output["codec_len"] = codec_len
|
||||
if len(fbank) > 0:
|
||||
output["speech"] = fbank
|
||||
output["speech_lengths"] = fbank_lens
|
||||
|
||||
break
|
||||
|
||||
return output
|
||||
|
||||
@ -55,7 +55,7 @@ class OpenAIIndexDSJsonl(torch.utils.data.Dataset): # torch.utils.data.Dataset
|
||||
text_length = data_dict.get("text_length", 0)
|
||||
if speech_length > self.max_source_length:
|
||||
logging.info(
|
||||
"speech_length: {speech_length} > {self.max_source_length}, drop it"
|
||||
f"speech_length: {speech_length} > {self.max_source_length}, drop it"
|
||||
)
|
||||
continue
|
||||
if text_length > self.max_target_length:
|
||||
|
||||
@ -59,10 +59,19 @@ def download_from_ms(**kwargs):
|
||||
elif os.path.exists(os.path.join(model_or_path, "config.yaml")):
|
||||
config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
|
||||
kwargs = OmegaConf.merge(config, kwargs)
|
||||
init_param = os.path.join(model_or_path, "model.pt")
|
||||
if "init_param" not in kwargs or not os.path.exists(kwargs["init_param"]):
|
||||
kwargs["init_param"] = init_param
|
||||
assert os.path.exists(kwargs["init_param"]), "init_param does not exist"
|
||||
|
||||
init_param = kwargs.get("init_param", "")
|
||||
if not os.path.exists(init_param):
|
||||
init_param_new = init_param
|
||||
if isinstance(init_param, str):
|
||||
init_param = init_param.split(",")
|
||||
for init_param_i in init_param:
|
||||
if not os.path.exists(init_param_i):
|
||||
print(f"init_param: {init_param_i}, does not exist")
|
||||
init_param_i = os.path.join(model_or_path, "model.pt")
|
||||
init_param_new = f"{init_param_new},{init_param_i}"
|
||||
kwargs["init_param"] = init_param_new
|
||||
# assert os.path.exists(kwargs["init_param"]), "init_param does not exist"
|
||||
if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
|
||||
kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
|
||||
if os.path.exists(os.path.join(model_or_path, "tokens.json")):
|
||||
|
||||
82
funasr/models/llm_asr/label_smoothing_loss.py
Normal file
82
funasr/models/llm_asr/label_smoothing_loss.py
Normal file
@ -0,0 +1,82 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Label smoothing module."""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class LabelSmoothingLoss(nn.Module):
    """Label-smoothing loss.

    Smooths the one-hot target distribution: the true class gets probability
    ``1 - smoothing`` and the remaining mass is spread uniformly over the
    other ``size - 1`` classes, then a divergence (KL by default) is taken
    against the model's log-softmax output.

    :param int size: the number of classes
    :param int padding_idx: target id to ignore (masked out of the loss)
    :param float smoothing: smoothing rate (0.0 means the conventional CE)
    :param bool normalize_length: normalize loss by number of non-padding
        tokens if True, by batch size if False
    :param torch.nn.Module criterion: element-wise loss to be smoothed;
        must use ``reduction="none"``. ``None`` selects
        ``nn.KLDivLoss(reduction="none")``.
    :param bool reduction: if False, return the unreduced element-wise loss
    """

    def __init__(
        self,
        size,
        padding_idx,
        smoothing,
        normalize_length=False,
        criterion=None,
        reduction=True,
    ):
        """Construct an LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        # FIX: the previous default `criterion=nn.KLDivLoss(reduction="none")`
        # instantiated a module at class-definition time that was shared by
        # every default-constructed instance (mutable default argument).
        # A None sentinel keeps the behavior while avoiding the shared object.
        self.criterion = nn.KLDivLoss(reduction="none") if criterion is None else criterion
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        # kept for interface compatibility; not used by forward()
        self.true_dist = None
        self.normalize_length = normalize_length
        self.reduction = reduction

    def forward(self, x, target):
        """Compute loss between x and target.

        :param torch.Tensor x: prediction logits (batch, seqlen, class)
        :param torch.Tensor target:
            target signal masked with self.padding_idx (batch, seqlen)
        :return: scalar float value (or element-wise loss if reduction=False)
        :rtype torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.reshape(-1, self.size)
        target = target.reshape(-1)
        # Build the smoothed target distribution without tracking gradients.
        with torch.no_grad():
            true_dist = x.clone()
            true_dist.fill_(self.smoothing / (self.size - 1))
            ignore = target == self.padding_idx  # (B,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        if not self.reduction:
            return kl
        else:
            # NOTE(review): if every token is padding and normalize_length is
            # True, `total` is 0 and this divides by zero — unchanged from the
            # original; confirm callers never pass all-padding batches.
            denom = total if self.normalize_length else batch_size
            return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
|
||||
class SequenceBinaryCrossEntropy(nn.Module):
    """Length-masked binary cross-entropy over padded sequences.

    Applies the element-wise criterion (BCE-with-logits by default), zeroes
    out positions beyond each sequence's length, and normalizes the summed
    loss by the number of valid frames (``normalize_length=True``) or by the
    batch size (``normalize_length=False``).
    """

    def __init__(self, normalize_length=False, criterion=nn.BCEWithLogitsLoss(reduction="none")):
        """Construct the loss; `criterion` must use reduction="none"."""
        super().__init__()
        self.normalize_length = normalize_length
        self.criterion = criterion

    def forward(self, pred, label, lengths):
        """Return the masked, normalized scalar loss.

        pred/label: (batch, time, dim) tensors; lengths: (batch,) valid frames.
        """
        pad_positions = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
        elementwise = self.criterion(pred, label)
        masked = elementwise.masked_fill(pad_positions.unsqueeze(-1), 0)
        if self.normalize_length:
            denom = (~pad_positions).sum()
        else:
            denom = pred.shape[0]
        return masked.sum() / denom
|
||||
File diff suppressed because it is too large
Load Diff
751
funasr/models/llm_asr/transformer_encoder.py
Normal file
751
funasr/models/llm_asr/transformer_encoder.py
Normal file
@ -0,0 +1,751 @@
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Transformer encoder definition."""
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
import torch
|
||||
from torch import nn
|
||||
import logging
|
||||
from funasr.models.transformer.attention import (
|
||||
MultiHeadedAttention,
|
||||
RelPositionMultiHeadedAttention, # noqa: H301
|
||||
LegacyRelPositionMultiHeadedAttention, # noqa: H301
|
||||
)
|
||||
from funasr.models.transformer.embedding import (
|
||||
PositionalEncoding, # noqa: H301
|
||||
ScaledPositionalEncoding, # noqa: H301
|
||||
RelPositionalEncoding, # noqa: H301
|
||||
LegacyRelPositionalEncoding, # noqa: H301
|
||||
)
|
||||
from funasr.models.transformer.layer_norm import LayerNorm
|
||||
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
|
||||
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
from funasr.models.transformer.positionwise_feed_forward import (
|
||||
PositionwiseFeedForward, # noqa: H301
|
||||
)
|
||||
from funasr.models.transformer.utils.repeat import repeat
|
||||
from funasr.models.transformer.utils.nets_utils import rename_state_dict
|
||||
from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution
|
||||
from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D
|
||||
from funasr.models.transformer.utils.lightconv import LightweightConvolution
|
||||
from funasr.models.transformer.utils.lightconv2d import LightweightConvolution2D
|
||||
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
|
||||
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
|
||||
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
|
||||
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
|
||||
from funasr.models.transformer.utils.subsampling import TooShortUttError
|
||||
from funasr.models.transformer.utils.subsampling import check_short_utt
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
    """Encoder layer module.

    One pre/post-norm Transformer block: self-attention with residual,
    then a position-wise feed-forward with residual. Supports stochastic
    depth (randomly skipping the whole layer during training), an optional
    concat-then-project attention merge, and incremental decoding via a
    cache of previous frames.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear`
            instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip residual computation and
            return input as-is with given probability.
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x, mask, cache=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor or tuple): Input tensor (#batch, time, size), or a
                (tensor, positional-embedding) tuple for relative attention.
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor (or tuple with pos_emb): Output (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        if isinstance(x, tuple):
            x, pos_emb = x[0], x[1]
        else:
            pos_emb = None

        # Stochastic depth: with probability p, skip the whole layer at
        # training time; surviving layers are scaled by 1 / (1 - p) so the
        # expected residual contribution is unchanged.
        drop_layer = False
        layer_scale = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            drop_layer = torch.rand(1).item() < self.stochastic_depth_rate
            layer_scale = 1.0 / (1 - self.stochastic_depth_rate)

        if drop_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return ((x, pos_emb), mask) if pos_emb is not None else (x, mask)

        # --- self-attention sub-block ---
        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if cache is None:
            x_q = x
        else:
            # Incremental decoding: only the newest frame is the query.
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if pos_emb is not None:
            attn_out = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            attn_out = self.self_attn(x_q, x, x, mask)

        if self.concat_after:
            merged = torch.cat((x, attn_out), dim=-1)
            x = residual + layer_scale * self.concat_linear(merged)
        else:
            x = residual + layer_scale * self.dropout(attn_out)
        if not self.normalize_before:
            x = self.norm1(x)

        # --- feed-forward sub-block ---
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + layer_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return ((x, pos_emb), mask) if pos_emb is not None else (x, mask)
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
    """Transformer encoder module.

    Args:
        input_size: input dim
        output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of encoder blocks
        dropout_rate: dropout rate
        positional_dropout_rate: dropout rate after adding positional encoding
        attention_dropout_rate: dropout rate in attention
        input_layer: input layer type
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to use layer_norm before the first block
        concat_after: whether to concat attention layer's input and output
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied.
            i.e. x -> x + att(x)
        positionwise_layer_type: linear or conv1d
        positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
        padding_idx: padding_idx for input_layer=embed
        interctc_layer_idx: 1-based block indices whose (normalized) outputs
            are additionally returned for intermediate CTC; None means none
        interctc_use_conditioning: feed the CTC posteriors of intermediate
            outputs back into the encoder stream via `conditioning_layer`
        causal_mode: "None" for full-context attention, "causal" to restrict
            each position to attend only to itself and earlier positions
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: Optional[List[int]] = None,
        interctc_use_conditioning: bool = False,
        causal_mode: str = "None",
    ):
        super().__init__()
        self._output_size = output_size
        self.causal_mode = causal_mode

        # FIX: default was a mutable list literal ([]), shared across all
        # default constructions; a None sentinel preserves behavior safely.
        if interctc_layer_idx is None:
            interctc_layer_idx = []

        # Input embedding / subsampling front-end.
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            # No front-end: identity when dims match, else a projection.
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                MultiHeadedAttention(attention_heads, output_size, attention_dropout_rate),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        # Set externally (e.g. by the model) when conditioning is enabled.
        self.conditioning_layer = None

    def output_size(self) -> int:
        """Return the encoder output feature dimension."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc=None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
            ctc: CTC module used only when interctc_use_conditioning is True.
        Returns:
            position embedded tensor (or (tensor, intermediate_outs)) and
            output lengths; third element is always None.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if self.causal_mode == "None":
            pass
        elif self.causal_mode == "causal":
            # Lower-triangular mask: position i may attend to j only if j <= i.
            tt = xs_pad.shape[1]
            pos_idx = torch.arange(tt)
            causal_mask = torch.less_equal(pos_idx.unsqueeze(0), pos_idx.unsqueeze(1))
            causal_mask = causal_mask.unsqueeze(0).to(xs_pad.device)
            masks = masks * causal_mask

        if self.embed is None:
            pass  # identity front-end: input already has output_size features
        elif isinstance(
            self.embed,
            (Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, Conv2dSubsampling8),
        ):
            # Conv subsampling shortens time; reject inputs too short for it.
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            xs_pad, masks = self.encoders(xs_pad, masks)
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs_pad, masks = encoder_layer(xs_pad, masks)

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad

                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)

                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        # Self-conditioning: project CTC posteriors back in.
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
|
||||
|
||||
|
||||
def _pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Rename legacy checkpoint keys in-place before state-dict loading.

    Registered via `_register_load_state_dict_pre_hook`; maps submodule
    names used by older ESPnet checkpoints onto the current attribute names.
    """
    # https://github.com/espnet/espnet/commit/21d70286c354c66c0350e65dc098d2ee236faccc#diff-bffb1396f038b317b2b64dd96e6d3563
    # https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563
    legacy_renames = (
        ("input_layer.", "embed."),
        ("norm.", "after_norm."),
    )
    for old_name, new_name in legacy_renames:
        rename_state_dict(prefix + old_name, prefix + new_name, state_dict)
|
||||
|
||||
|
||||
class TransformerEncoder_s0(nn.Module):
|
||||
"""Transformer encoder module.
|
||||
|
||||
Args:
|
||||
idim (int): Input dimension.
|
||||
attention_dim (int): Dimension of attention.
|
||||
attention_heads (int): The number of heads of multi head attention.
|
||||
conv_wshare (int): The number of kernel of convolution. Only used in
|
||||
selfattention_layer_type == "lightconv*" or "dynamiconv*".
|
||||
conv_kernel_length (Union[int, str]): Kernel size str of convolution
|
||||
(e.g. 71_71_71_71_71_71). Only used in selfattention_layer_type
|
||||
== "lightconv*" or "dynamiconv*".
|
||||
conv_usebias (bool): Whether to use bias in convolution. Only used in
|
||||
selfattention_layer_type == "lightconv*" or "dynamiconv*".
|
||||
linear_units (int): The number of units of position-wise feed forward.
|
||||
num_blocks (int): The number of decoder blocks.
|
||||
dropout_rate (float): Dropout rate.
|
||||
positional_dropout_rate (float): Dropout rate after adding positional encoding.
|
||||
attention_dropout_rate (float): Dropout rate in attention.
|
||||
input_layer (Union[str, torch.nn.Module]): Input layer type.
|
||||
pos_enc_class (torch.nn.Module): Positional encoding module class.
|
||||
`PositionalEncoding `or `ScaledPositionalEncoding`
|
||||
normalize_before (bool): Whether to use layer_norm before the first block.
|
||||
concat_after (bool): Whether to concat attention layer's input and output.
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
|
||||
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
|
||||
selfattention_layer_type (str): Encoder attention layer type.
|
||||
padding_idx (int): Padding idx for input_layer=embed.
|
||||
stochastic_depth_rate (float): Maximum probability to skip the encoder layer.
|
||||
intermediate_layers (Union[List[int], None]): indices of intermediate CTC layer.
|
||||
indices start from 1.
|
||||
if not None, intermediate outputs are returned (which changes return type
|
||||
signature.)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
idim,
|
||||
attention_dim=256,
|
||||
attention_heads=4,
|
||||
conv_wshare=4,
|
||||
conv_kernel_length="11",
|
||||
conv_usebias=False,
|
||||
linear_units=2048,
|
||||
num_blocks=6,
|
||||
dropout_rate=0.1,
|
||||
positional_dropout_rate=0.1,
|
||||
attention_dropout_rate=0.0,
|
||||
input_layer="conv2d",
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
positionwise_layer_type="linear",
|
||||
positionwise_conv_kernel_size=1,
|
||||
selfattention_layer_type="selfattn",
|
||||
padding_idx=-1,
|
||||
stochastic_depth_rate=0.0,
|
||||
intermediate_layers=None,
|
||||
ctc_softmax=None,
|
||||
conditioning_layer_dim=None,
|
||||
zero_triu: bool = False,
|
||||
):
|
||||
"""Construct an Encoder object."""
|
||||
super(TransformerEncoder_s0, self).__init__()
|
||||
self._register_load_state_dict_pre_hook(_pre_hook)
|
||||
|
||||
self.conv_subsampling_factor = 1
|
||||
if input_layer == "linear":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(idim, attention_dim),
|
||||
torch.nn.LayerNorm(attention_dim),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
torch.nn.ReLU(),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == "conv2d":
|
||||
self.embed = Conv2dSubsampling(idim, attention_dim, dropout_rate)
|
||||
self.conv_subsampling_factor = 4
|
||||
elif input_layer == "conv2d-scaled-pos-enc":
|
||||
self.embed = Conv2dSubsampling(
|
||||
idim,
|
||||
attention_dim,
|
||||
dropout_rate,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
self.conv_subsampling_factor = 4
|
||||
elif input_layer == "conv2d6":
|
||||
self.embed = Conv2dSubsampling6(idim, attention_dim, dropout_rate)
|
||||
self.conv_subsampling_factor = 6
|
||||
elif input_layer == "conv2d8":
|
||||
self.embed = Conv2dSubsampling8(idim, attention_dim, dropout_rate)
|
||||
self.conv_subsampling_factor = 8
|
||||
elif input_layer == "embed":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif isinstance(input_layer, torch.nn.Module):
|
||||
self.embed = torch.nn.Sequential(
|
||||
input_layer,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer is None:
|
||||
self.embed = torch.nn.Sequential(
|
||||
pos_enc_class(attention_dim, positional_dropout_rate)
|
||||
)
|
||||
elif input_layer == "none":
|
||||
self.embed = torch.nn.Identity()
|
||||
else:
|
||||
raise ValueError("unknown input_layer: " + input_layer)
|
||||
self.normalize_before = normalize_before
|
||||
positionwise_layer, positionwise_layer_args = self.get_positionwise_layer(
|
||||
positionwise_layer_type,
|
||||
attention_dim,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
positionwise_conv_kernel_size,
|
||||
)
|
||||
# if selfattention_layer_type in [
|
||||
# "selfattn",
|
||||
# "rel_selfattn",
|
||||
# "legacy_rel_selfattn",
|
||||
# ]:
|
||||
# logging.info("encoder self-attention layer type = self-attention")
|
||||
# encoder_selfattn_layer = MultiHeadedAttention
|
||||
# encoder_selfattn_layer_args = [
|
||||
# (
|
||||
# attention_heads,
|
||||
# attention_dim,
|
||||
# attention_dropout_rate,
|
||||
# )
|
||||
# ] * num_blocks
|
||||
if selfattention_layer_type == "selfattn":
|
||||
logging.info("encoder self-attention layer type = self-attention")
|
||||
encoder_selfattn_layer = MultiHeadedAttention
|
||||
encoder_selfattn_layer_args = [(
|
||||
attention_heads,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
)] * num_blocks
|
||||
elif selfattention_layer_type == "legacy_rel_selfattn":
|
||||
logging.info("encoder self-attention layer type = legacy relative self-attention")
|
||||
assert pos_enc_class == LegacyRelPositionalEncoding
|
||||
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
|
||||
encoder_selfattn_layer_args = [(
|
||||
attention_heads,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
)] * num_blocks
|
||||
logging.warning(
|
||||
"Using legacy_rel_selfattn and it will be deprecated in the future."
|
||||
)
|
||||
elif selfattention_layer_type == "rel_selfattn":
|
||||
logging.info("encoder self-attention layer type = relative self-attention")
|
||||
assert pos_enc_class == RelPositionalEncoding
|
||||
encoder_selfattn_layer = RelPositionMultiHeadedAttention
|
||||
encoder_selfattn_layer_args = [(
|
||||
attention_heads,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
zero_triu,
|
||||
)] * num_blocks
|
||||
elif selfattention_layer_type == "lightconv":
|
||||
logging.info("encoder self-attention layer type = lightweight convolution")
|
||||
encoder_selfattn_layer = LightweightConvolution
|
||||
encoder_selfattn_layer_args = [
|
||||
(
|
||||
conv_wshare,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
int(conv_kernel_length.split("_")[lnum]),
|
||||
False,
|
||||
conv_usebias,
|
||||
)
|
||||
for lnum in range(num_blocks)
|
||||
]
|
||||
elif selfattention_layer_type == "lightconv2d":
|
||||
logging.info(
|
||||
"encoder self-attention layer "
|
||||
"type = lightweight convolution 2-dimensional"
|
||||
)
|
||||
encoder_selfattn_layer = LightweightConvolution2D
|
||||
encoder_selfattn_layer_args = [
|
||||
(
|
||||
conv_wshare,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
int(conv_kernel_length.split("_")[lnum]),
|
||||
False,
|
||||
conv_usebias,
|
||||
)
|
||||
for lnum in range(num_blocks)
|
||||
]
|
||||
elif selfattention_layer_type == "dynamicconv":
|
||||
logging.info("encoder self-attention layer type = dynamic convolution")
|
||||
encoder_selfattn_layer = DynamicConvolution
|
||||
encoder_selfattn_layer_args = [
|
||||
(
|
||||
conv_wshare,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
int(conv_kernel_length.split("_")[lnum]),
|
||||
False,
|
||||
conv_usebias,
|
||||
)
|
||||
for lnum in range(num_blocks)
|
||||
]
|
||||
elif selfattention_layer_type == "dynamicconv2d":
|
||||
logging.info(
|
||||
"encoder self-attention layer type = dynamic convolution 2-dimensional"
|
||||
)
|
||||
encoder_selfattn_layer = DynamicConvolution2D
|
||||
encoder_selfattn_layer_args = [
|
||||
(
|
||||
conv_wshare,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
int(conv_kernel_length.split("_")[lnum]),
|
||||
False,
|
||||
conv_usebias,
|
||||
)
|
||||
for lnum in range(num_blocks)
|
||||
]
|
||||
else:
|
||||
raise NotImplementedError(selfattention_layer_type)
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: EncoderLayer(
|
||||
attention_dim,
|
||||
encoder_selfattn_layer(*encoder_selfattn_layer_args[lnum]),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
stochastic_depth_rate * float(1 + lnum) / num_blocks,
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(attention_dim)
|
||||
|
||||
self.intermediate_layers = intermediate_layers
|
||||
self.use_conditioning = True if ctc_softmax is not None else False
|
||||
if self.use_conditioning:
|
||||
self.ctc_softmax = ctc_softmax
|
||||
self.conditioning_layer = torch.nn.Linear(
|
||||
conditioning_layer_dim, attention_dim
|
||||
)
|
||||
|
||||
def get_positionwise_layer(
    self,
    positionwise_layer_type="linear",
    attention_dim=256,
    linear_units=2048,
    dropout_rate=0.1,
    positionwise_conv_kernel_size=1,
):
    """Resolve the position-wise (feed-forward) sublayer class and its ctor args.

    Args:
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        attention_dim (int): Model (input/output) dimension of the sublayer.
        linear_units (int): Hidden dimension of the feed-forward sublayer.
        dropout_rate (float): Dropout probability inside the sublayer.
        positionwise_conv_kernel_size (int): Kernel size for the conv variants.

    Returns:
        tuple: ``(layer_class, args)`` suitable for ``layer_class(*args)``.

    Raises:
        NotImplementedError: For an unsupported ``positionwise_layer_type``.
    """
    if positionwise_layer_type == "linear":
        return PositionwiseFeedForward, (attention_dim, linear_units, dropout_rate)
    if positionwise_layer_type == "conv1d":
        return MultiLayeredConv1d, (
            attention_dim,
            linear_units,
            positionwise_conv_kernel_size,
            dropout_rate,
        )
    if positionwise_layer_type == "conv1d-linear":
        return Conv1dLinear, (
            attention_dim,
            linear_units,
            positionwise_conv_kernel_size,
            dropout_rate,
        )
    raise NotImplementedError("Support only linear or conv1d.")
|
||||
|
||||
def forward(self, xs, masks):
    """Encode input sequence.

    Args:
        xs (torch.Tensor): Input tensor (#batch, time, idim).
        masks (torch.Tensor): Mask tensor (#batch, time).

    Returns:
        torch.Tensor: Output tensor (#batch, time, attention_dim).
        torch.Tensor: Mask tensor (#batch, time).

    """
    # Conv subsampling front-ends shorten the time axis, so they also take and
    # return the mask; any other embedding leaves the mask untouched.
    if isinstance(
        self.embed,
        (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8),
    ):
        xs, masks = self.embed(xs, masks)
    else:
        xs = self.embed(xs)

    if self.intermediate_layers is None:
        # Fast path: run the whole repeated stack in one call.
        xs, masks = self.encoders(xs, masks)
    else:
        # Slow path: iterate layer by layer so selected intermediate outputs
        # can be tapped (e.g. for auxiliary losses / conditioning).
        intermediate_outputs = []
        for layer_idx, encoder_layer in enumerate(self.encoders):
            xs, masks = encoder_layer(xs, masks)

            if (
                self.intermediate_layers is not None
                and layer_idx + 1 in self.intermediate_layers
            ):
                # xs may be (tensor, pos_emb) for relative-position encoders.
                if isinstance(xs, tuple):
                    encoder_output = xs[0]
                else:
                    encoder_output = xs
                # intermediate branches also require normalization.
                if self.normalize_before:
                    encoder_output = self.after_norm(encoder_output)
                intermediate_outputs.append(encoder_output)

                if self.use_conditioning:
                    # Self-conditioning: project the tapped layer's softmax back
                    # into the stream before the next layer.
                    # NOTE(review): if xs is a tuple here (rel-pos case) this
                    # addition would fail — presumably conditioning is only
                    # combined with plain-tensor layer outputs; confirm.
                    intermediate_result = self.ctc_softmax(encoder_output)
                    xs = xs + self.conditioning_layer(intermediate_result)

    if isinstance(xs, tuple):
        xs = xs[0]
    if self.normalize_before:
        xs = self.after_norm(xs)

    # A third return element is present only when intermediate taps were requested.
    if self.intermediate_layers is not None:
        return xs, masks, intermediate_outputs
    return xs, masks
|
||||
|
||||
def forward_one_step(self, xs, masks, cache=None):
    """Encode input frame incrementally, reusing per-layer caches.

    Args:
        xs (torch.Tensor): Input tensor.
        masks (torch.Tensor): Mask tensor.
        cache (List[torch.Tensor]): Per-layer cache tensors from the previous
            step, or None to start fresh.

    Returns:
        torch.Tensor: Output tensor.
        torch.Tensor: Mask tensor.
        List[torch.Tensor]: Updated per-layer cache tensors.

    """
    embed = self.embed
    if isinstance(embed, (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8)):
        xs, masks = embed(xs, masks)
    else:
        xs = embed(xs)

    if cache is None:
        cache = [None] * len(self.encoders)

    new_cache = []
    for layer_cache, layer in zip(cache, self.encoders):
        xs, masks = layer(xs, masks, cache=layer_cache)
        # Layers with relative positions return (tensor, pos_emb); cache the tensor.
        new_cache.append(xs[0] if isinstance(xs, tuple) else xs)

    if isinstance(xs, tuple):
        xs = xs[0]
    if self.normalize_before:
        xs = self.after_norm(xs)
    return xs, masks, new_cache
|
||||
345
funasr/models/llm_asr/transformer_lm.py
Normal file
345
funasr/models/llm_asr/transformer_lm.py
Normal file
@ -0,0 +1,345 @@
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from funasr.models.transformer.embedding import PositionalEncoding, ScaledPositionalEncoding, RelPositionalEncoding, LegacyRelPositionalEncoding
|
||||
from funasr.models.llm_asr.transformer_encoder import TransformerEncoder_s0 as Encoder
|
||||
from funasr.models.transformer.utils.mask import subsequent_mask
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
import logging
|
||||
from distutils.version import LooseVersion
|
||||
from contextlib import contextmanager
|
||||
# torch >= 1.6 ships native AMP; on older versions fall back to a no-op
# context manager so call sites can use `with autocast(False):` unconditionally.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        # Mirrors the signature of torch.cuda.amp.autocast; `enabled` is ignored.
        yield
|
||||
|
||||
|
||||
class TransformerEmbedLM(nn.Module):
    """Transformer language model that consumes continuous embeddings, not token ids.

    The assumed sequence layout (from the mask-building code) is:
    ``[<sos>, condition frames, <task_id>, output frames]``, where
    ``cond_lengths`` counts the <sos> + condition prefix.
    """

    def __init__(
        self,
        vocab_size: int,
        pos_enc: str = None,
        embed_unit: int = 128,
        att_unit: int = 256,
        head: int = 2,
        unit: int = 1024,
        layer: int = 4,
        dropout_rate: float = 0.5,
        attention_dropout_rate: float = 0.0,
        pe_type: str = "split",
        bidirectional_inputs: bool = False,
        text_vocab_size: int = 4000,
        input_aug_conf: dict = None,
        output_aug_conf: dict = None,
        codec_groups: int = 4,
        selfattention_layer_type: str = "selfattn",
        input_normalize: bool = False,
        use_decoder: bool = True,
        encoder_type: str = "transformer",
        **kwargs
    ):
        """Build the embedding-input Transformer LM.

        Args:
            vocab_size: Output vocabulary size of the projection head.
            pos_enc: Positional-encoding variant, or None for identity.
            embed_unit: Dimension of the incoming embeddings.
            att_unit: Transformer attention (model) dimension.
            head / unit / layer: Attention heads, FFN units, encoder depth.
            pe_type: "split" applies positional encoding separately to the
                condition and output segments (see forward()).
            bidirectional_inputs: Allow full attention within the condition prefix.
            codec_groups: Number of codec groups folded per frame for specaug.
            use_decoder: If False, no output projection is created.
            kwargs: May carry "first_pack_mask_conf" (see clac_first_package_mask).
        """
        super().__init__()
        # Select the positional-encoding class; "sinusoidal" and "abs_pos" are aliases.
        if pos_enc == "sinusoidal":
            pos_enc_class = PositionalEncoding
        elif pos_enc == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc == "legacy_rel_pos":
            assert selfattention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        elif pos_enc is None:

            def pos_enc_class(*args, **kwargs):
                return nn.Sequential()  # identity

        else:
            raise ValueError(f"unknown pos-enc option: {pos_enc}")

        self.embed_unit = embed_unit
        self.pe_type = pe_type
        self.encoder_type = encoder_type
        if encoder_type == "llama":
            raise NotImplementedError("llama encoder has not been implemented")
            # from cosyvoice.nets.encoder.llama_encoder import LlamaEncoder
            # # set causal to false, using mask to control causal mode.
            # self.encoder = LlamaEncoder(
            #     input_size=embed_unit,
            #     output_size=att_unit,
            #     attention_heads=head,
            #     num_blocks=layer,
            #     dropout_rate=dropout_rate,
            #     attention_dropout_rate=attention_dropout_rate,
            #     causal=False,
            #     linear_units=unit,
            # )
        else:
            self.encoder = Encoder(
                idim=embed_unit,
                attention_dim=att_unit,
                attention_heads=head,
                linear_units=unit,
                num_blocks=layer,
                dropout_rate=dropout_rate,
                positional_dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                # With "split" PE the encoder gets pre-projected inputs (see
                # self.input_layer below); otherwise the encoder projects them.
                input_layer="none" if pe_type == "split" else "linear",
                pos_enc_class=pos_enc_class,
                selfattention_layer_type=selfattention_layer_type,
            )
        if use_decoder:
            self.decoder = nn.Linear(att_unit, vocab_size)
        else:
            self.decoder = None
        self.attn_unit = att_unit
        self.pos_enc_func = None
        if pe_type == "split":
            # Split PE restarts the positional index at the output segment, which
            # only makes sense for absolute encodings.
            assert pos_enc == "sinusoidal" or pos_enc == "abs_pos" or pos_enc == "scaled_abs_pos", \
                "Different positional embedding for inputs and outputs " \
                "only supports sinusoidal, abs_pos and scaled_abs_pos."
            self.pos_enc_func = pos_enc_class(embed_unit, 0.1)
            self.input_layer = torch.nn.Linear(embed_unit, att_unit)
        self.bidirectional_inputs = bidirectional_inputs
        self.text_vocab_size = text_vocab_size
        self.codec_groups = codec_groups
        # Optional SpecAug on the condition segment (imported lazily to avoid a
        # hard dependency when unused).
        self.input_aug = None
        if input_aug_conf is not None:
            from funasr.models.specaug.specaug import SpecAug
            self.input_aug = SpecAug(**input_aug_conf)

        # Optional SpecAug on the output segment.
        self.output_aug = None
        if output_aug_conf is not None:
            from funasr.models.specaug.specaug import SpecAug
            self.output_aug = SpecAug(**output_aug_conf)

        # Optional per-utterance mean-variance normalization of the condition.
        self.normalize = None
        if input_normalize:
            from funasr.models.normalize.utterance_mvn import UtteranceMVN
            self.normalize = UtteranceMVN()

        # Config for first-package (streaming) mask restriction; None disables it.
        self.first_pack_mask_conf: dict = kwargs.get("first_pack_mask_conf", None)
|
||||
|
||||
def output_size(self):
    """Return the hidden dimension of the encoder states this LM produces."""
    return self.attn_unit
|
||||
|
||||
def _target_mask(self, lengths):
    """Build a combined padding + causal attention mask from sequence lengths.

    Args:
        lengths: Per-sample valid lengths, (batch,).

    Returns:
        Boolean mask of shape (batch, maxlen, maxlen): position (b, i, j) is
        True iff j is non-padded and j <= i (causal).
    """
    valid = ~make_pad_mask(lengths)
    causal = subsequent_mask(valid.size(-1), device=valid.device).unsqueeze(0)
    return valid.unsqueeze(-2) & causal
|
||||
|
||||
def clac_first_package_mask(self, mask, input_lengths, cond_lengths):
    """Restrict how generated tokens may attend to the condition (text) prefix.

    NOTE(review): the name looks like a typo for "calc_first_package_mask";
    kept as-is because external callers may reference it.

    Args:
        mask: Boolean attention mask (batch, maxlen, maxlen), modified in place.
        input_lengths: Full sequence lengths per sample.
        cond_lengths: Condition-prefix lengths (including sos and xvec and text).

    Returns:
        The same ``mask`` tensor, edited in place.
    """
    device = mask.device
    mask_type = self.first_pack_mask_conf.get("mask_type", "first_pack")
    # fp_token_len: generated tokens per "package"; fp_text_len: text positions
    # revealed per package.
    fp_token_len = self.first_pack_mask_conf["fp_token_len"]
    fp_text_len = self.first_pack_mask_conf["fp_text_len"]
    # NOTE: fp_text_len excluding sos, xvec, only including text
    # NOTE: cond_lengths including sos, xvec and text
    if mask_type == "streaming":
        # Streaming mode: the m-th generated token may see ceil(m / fp_token_len)
        # packages' worth of text (plus sos/xvec), capped at the full condition.
        for i, (seq_len, cond_len) in enumerate(zip(input_lengths, cond_lengths)):
            # 1 for task_id
            token_len = seq_len - cond_len - 1
            if token_len > 0:
                target_text_len = torch.ceil(torch.arange(1, token_len+1, device=device) / fp_token_len) * fp_text_len
                # 2 for sos and xvec, M -> M x 1
                target_text_len = torch.minimum(target_text_len + 2, cond_len).unsqueeze(1)
                # 1 x N
                pos_range = torch.arange(0, cond_len, device=device).unsqueeze(0)
                # M x N
                text_mask = pos_range < target_text_len
                # 1 for <task_id>
                mask[i, cond_len+1:seq_len, :cond_len] = mask[i, cond_len+1:seq_len, :cond_len] * text_mask
    else:
        # "first_pack" mode: only the first fp_token_len generated tokens are
        # blinded to text beyond the first package (sos/xvec at positions <2
        # stay visible); later tokens see the whole condition.
        for i, (seq_len, cond_len) in enumerate(zip(input_lengths, cond_lengths)):
            mask_token_end = min(cond_len+1+fp_token_len, seq_len)
            mask[i, cond_len+1:mask_token_end, fp_text_len+2:cond_len] = 0

    return mask
|
||||
|
||||
def forward(
    self,
    input: torch.Tensor,
    input_lengths: torch.Tensor,
    cond_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute LM loss value from buffer sequences.

    Args:
        input (torch.Tensor): Input ids. (batch, len, dim)
        input_lengths (torch.Tensor): length of input. (batch,)
        cond_lengths (torch.Tensor): length of conditions (including sos, excluding taskid). (batch,)

    Returns:
        Tuple of (decoder logits or hidden states, encoder hidden states).

    NOTE(review): `input` shadows the builtin and is mutated IN PLACE below
    (augmentation, normalization, positional encoding are written back into
    the caller's tensor) — confirm callers do not reuse it afterwards.
    """
    # Padding + causal mask; optionally tightened for first-package streaming.
    mask = self._target_mask(input_lengths).to(input.device)
    if self.first_pack_mask_conf is not None:
        mask = self.clac_first_package_mask(mask, input_lengths, cond_lengths)
    if self.bidirectional_inputs:
        # Full (non-causal) attention within the condition prefix.
        for i, length in enumerate(cond_lengths):
            mask[i, :length, :length] = True
    pos_emb = None
    if self.pe_type == "split":
        # Buffer for relative positional embeddings (2T-1 positions — TODO
        # confirm this matches the rel-pos encoding used).
        pos_emb = torch.zeros((input.shape[0], input.shape[1]*2-1, self.attn_unit)).to(input)
    kk = self.codec_groups
    # Augmentation/normalization is done per sample in fp32 (autocast off).
    # with torch.no_grad():
    with autocast(False):
        for i, length in enumerate(cond_lengths):
            # perform specaug for each frame including multi-group: fold kk
            # codec groups into the feature axis, augment, then unfold.
            raw_feat = input[i:i + 1, 1:length].clone()
            bb, tt, dd = raw_feat.shape
            raw_feat = raw_feat.reshape(bb, tt // kk, kk, dd).reshape(bb, tt // kk, kk * dd)

            if self.input_aug is not None and self.training:
                raw_feat = self.input_aug(raw_feat, (cond_lengths[i:i+1] - 1) // kk)[0]

            if self.normalize is not None:
                raw_feat = self.normalize(raw_feat, None)[0]

            input[i:i + 1, 1:length] = raw_feat.reshape(bb, tt//kk, kk, dd).reshape(bb, tt, dd)

            # Separate augmentation of the output segment (after <task_id>).
            if self.output_aug is not None and self.training:
                raw_feat = input[i:i + 1, length+1:].clone()
                aug_feat = self.output_aug(raw_feat, input_lengths[i:i+1] - length - 2)[0]
                input[i:i + 1, length + 1:] = aug_feat

            # add positional encoding: "split" restarts positions at the output
            # segment; condition and output are encoded independently.
            if self.pe_type == "split" and self.pos_enc_func is not None:
                posed_input = self.pos_enc_func(input[i:i + 1, :length].clone())
                if isinstance(posed_input, tuple):
                    pos_emb[i:i+1, :length*2-1] = posed_input[1]
                    posed_input = posed_input[0]
                input[i:i + 1, :length] = posed_input

                posed_output = self.pos_enc_func(input[i:i + 1, length + 1:].clone())
                if isinstance(posed_output, tuple):
                    pos_emb[i:i+1, length*2: length*2+posed_output[1].shape[1]] = posed_output[1]
                    posed_output = posed_output[0]
                input[i:i + 1, length + 1:] = posed_output

    if self.pe_type == "split":
        # Project embed_unit -> att_unit (encoder was built with input_layer="none").
        input = self.input_layer(input)
        if isinstance(self.pos_enc_func, (RelPositionalEncoding, LegacyRelPositionalEncoding)):
            input = (input, pos_emb)
    # logging.info(f"shapes {input.shape} {mask.shape} {input_lengths}")
    h, _ = self.encoder(input, mask)
    if self.decoder is None:
        # Without a projection head, return hidden states in both slots.
        return h, h

    y = self.decoder(h)
    return y, h
|
||||
|
||||
def init_state(self, x: torch.Tensor):
    """Return the initial scorer state: no cache exists before the first step."""
    return None
|
||||
|
||||
def score(
    self, y: torch.Tensor, state: Any, x: torch.Tensor
) -> Tuple[torch.Tensor, Any]:
    """Score new token.

    Args:
        y (torch.Tensor): 2D torch.float prefix embeddings.
        state: Scorer state for prefix tokens
        x (torch.Tensor): encoder feature that generates ys.

    Returns:
        tuple[torch.Tensor, Any]: Tuple of
            torch.float32 scores for next token (vocab_size)
            and next state for ys

    """
    # this implementation is much faster than the below!!
    # Lower-triangular (causal) mask over the whole prefix.
    mask = torch.tril(torch.ones((1, y.shape[0], y.shape[0]), device=y.device)).to(torch.bool)
    # NOTE(review): unsqueeze() yields a VIEW of `y`; the slice assignments
    # below therefore mutate the caller's prefix tensor in place — confirm
    # the beam-search driver rebuilds `y` each step.
    y_emb = y.unsqueeze(0).to(x.device)
    # lengths = y_emb.new_full([1], dtype=torch.long, fill_value=y_emb.size(1))
    # mask = self._target_mask(lengths).to(y_emb.device)
    # x includes <sos>, feat, <task_id>
    input_length = x.shape[0] - 1
    if self.bidirectional_inputs:
        # Full attention within the condition prefix, mirroring forward().
        mask[:1, :input_length, :input_length] = True
    # if self.first_pack_mask_conf is not None:
    #     mask = self.clac_first_package_mask(
    #         mask,
    #         torch.tensor([y.shape[0]], device=y.device),
    #         torch.tensor([input_length], device=y.device),
    #     )
    if self.pe_type == "split" and self.pos_enc_func is not None:
        pos_emb = torch.zeros((y_emb.shape[0], y_emb.shape[1], self.attn_unit)).to(y_emb)

        # Positional encoding restarts at the output segment (split mode).
        posed_input = self.pos_enc_func(y_emb[:1, :input_length])
        if isinstance(posed_input, tuple):
            pos_emb[:1, :input_length] = posed_input[1]
            posed_input = posed_input[0]
        y_emb[:1, :input_length] = posed_input

        posed_output = self.pos_enc_func(y_emb[:1, input_length + 1:])
        if isinstance(posed_output, tuple):
            pos_emb[:1, input_length + 1:] = posed_output[1]
            posed_output = posed_output[0]
        y_emb[:1, input_length + 1:] = posed_output

    if self.pe_type == "split":
        y_emb = self.input_layer(y_emb)
        if isinstance(self.pos_enc_func, (RelPositionalEncoding, LegacyRelPositionalEncoding)):
            y_emb = (y_emb, pos_emb)
    lm_hidden_states, _, cache = self.encoder.forward_one_step(
        y_emb, mask, cache=state
    )
    if self.decoder is None:
        return lm_hidden_states[:, -1], cache

    # Only the text sub-vocabulary is scored.
    h = self.decoder(lm_hidden_states[:, -1])[:, :self.text_vocab_size]

    logp = h.log_softmax(dim=-1).squeeze(0)
    # return logp, cache
    # State also carries the last hidden state for downstream consumers.
    return logp, (cache, lm_hidden_states[:, -1])
|
||||
|
||||
def batch_score(
    self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
) -> Tuple[torch.Tensor, List[Any]]:
    """Score new token batch.

    Args:
        ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
        states (List[Any]): Scorer states for prefix tokens.
        xs (torch.Tensor):
            The encoder feature that generates ys (n_batch, xlen, n_feat).

    Returns:
        tuple[torch.Tensor, List[Any]]: Tuple of
            batchfied scores for next token with shape of `(n_batch, vocab_size)`
            and next state list for ys.

    NOTE(review): this looks copy-pasted from a token-id TransformerLM and
    appears unusable as-is in this class:
      * ``self.embed`` is never defined in __init__ (this model consumes
        embeddings directly) — calling this raises AttributeError;
      * ``self._target_mask(ys)`` passes the token tensor where _target_mask
        expects per-sample lengths;
      * ``score()`` returns states as (cache, hidden) tuples, which this
        method's state transposition does not handle.
    Confirm whether this entry point is ever exercised before relying on it.
    """
    # merge states
    n_batch = len(ys)
    n_layers = len(self.encoder.encoders)
    if states[0] is None:
        batch_state = None
    else:
        # transpose state of [batch, layer] into [layer, batch]
        batch_state = [
            torch.stack([states[b][i] for b in range(n_batch)])
            for i in range(n_layers)
        ]

    # batch decoding
    h, _, states = self.encoder.forward_one_step(
        self.embed(ys), self._target_mask(ys), cache=batch_state
    )
    h = self.decoder(h[:, -1])
    logp = h.log_softmax(dim=-1)

    # transpose state of [layer, batch] into [batch, layer]
    state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
    return logp, state_list
|
||||
@ -11,7 +11,7 @@ import pdb
|
||||
|
||||
|
||||
def load_pretrained_model(
|
||||
path: str,
|
||||
path,
|
||||
model: torch.nn.Module,
|
||||
ignore_init_mismatch: bool = True,
|
||||
map_location: str = "cpu",
|
||||
@ -100,3 +100,30 @@ def load_pretrained_model(
|
||||
|
||||
flag = obj.load_state_dict(dst_state, strict=True)
|
||||
logging.info(f"Loading ckpt: {path}, status: {flag}")
|
||||
|
||||
|
||||
# def load_pretrained_model(
|
||||
# path,
|
||||
# model: torch.nn.Module,
|
||||
# ignore_init_mismatch: bool = True,
|
||||
# map_location: str = "cpu",
|
||||
# oss_bucket=None,
|
||||
# scope_map=[],
|
||||
# excludes=None,
|
||||
# **kwargs,
|
||||
# ):
|
||||
# if isinstance(path, str):
|
||||
# path = path.split(",")
|
||||
#
|
||||
# for i, path_i in enumerate(path):
|
||||
# logging.info(f"Loading ckpt-{i}: {path_i}")
|
||||
# _load_pretrained_model(
|
||||
# path_i,
|
||||
# model=model,
|
||||
# ignore_init_mismatch=ignore_init_mismatch,
|
||||
# map_location=map_location,
|
||||
# oss_bucket=oss_bucket,
|
||||
# scope_map=scope_map,
|
||||
# excludes=excludes,
|
||||
# **kwargs,
|
||||
# )
|
||||
|
||||
Loading…
Reference in New Issue
Block a user