diff --git a/funasr/bin/punctuation_infer.py b/funasr/bin/punctuation_infer.py
index a801ee8c6..dd28ef8da 100644
--- a/funasr/bin/punctuation_infer.py
+++ b/funasr/bin/punctuation_infer.py
@@ -23,7 +23,7 @@ from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import config_argparse
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
-from funasr.punctuation.text_preprocessor import split_to_mini_sentence
+from funasr.datasets.preprocessor import split_to_mini_sentence
 
 
 class Text2Punc:
diff --git a/funasr/bin/punctuation_infer_vadrealtime.py b/funasr/bin/punctuation_infer_vadrealtime.py
index ce1cee8b0..81f9d7ae8 100644
--- a/funasr/bin/punctuation_infer_vadrealtime.py
+++ b/funasr/bin/punctuation_infer_vadrealtime.py
@@ -23,7 +23,7 @@ from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import config_argparse
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
-from funasr.punctuation.text_preprocessor import split_to_mini_sentence
+from funasr.datasets.preprocessor import split_to_mini_sentence
 
 
 class Text2Punc:
diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py
index 98cca1dcd..afeff4ee6 100644
--- a/funasr/datasets/preprocessor.py
+++ b/funasr/datasets/preprocessor.py
@@ -800,3 +800,17 @@ class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
             data[self.vad_name] = np.array([vad], dtype=np.int64)
         text_ints = self.token_id_converter[i].tokens2ids(tokens)
         data[text_name] = np.array(text_ints, dtype=np.int64)
+
+
+def split_to_mini_sentence(words: list, word_limit: int = 20):
+    assert word_limit > 1
+    if len(words) <= word_limit:
+        return [words]
+    sentences = []
+    length = len(words)
+    sentence_len = length // word_limit
+    for i in range(sentence_len):
+        sentences.append(words[i * word_limit:(i + 1) * word_limit])
+    if length % word_limit > 0:
+        sentences.append(words[sentence_len * word_limit:])
+    return sentences
\ No newline at end of file
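A minimal usage sketch of the relocated split_to_mini_sentence helper (illustrative only; behavior inferred from the function body added above, using the new import path from this patch):

# Split a 45-token list with the default word_limit of 20.
from funasr.datasets.preprocessor import split_to_mini_sentence

words = ["w%d" % i for i in range(45)]
chunks = split_to_mini_sentence(words, word_limit=20)
print([len(c) for c in chunks])  # -> [20, 20, 5]
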
diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
index 62ee72354..4ac0456b9 100644
--- a/funasr/export/models/__init__.py
+++ b/funasr/export/models/__init__.py
@@ -3,10 +3,10 @@ from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
 from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
 from funasr.models.e2e_vad import E2EVadModel
 from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
-from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
+from funasr.models.target_delay_transformer import TargetDelayTransformer
 from funasr.export.models.target_delay_transformer import TargetDelayTransformer as TargetDelayTransformer_export
-from funasr.punctuation.espnet_model import ESPnetPunctuationModel
-from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
+from funasr.train.abs_model import PunctuationModel
+from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
 from funasr.export.models.vad_realtime_transformer import VadRealtimeTransformer as VadRealtimeTransformer_export
 
 def get_model(model, export_config=None):
@@ -16,7 +16,7 @@ def get_model(model, export_config=None):
         return Paraformer_export(model, **export_config)
     elif isinstance(model, E2EVadModel):
         return E2EVadModel_export(model, **export_config)
-    elif isinstance(model, ESPnetPunctuationModel):
+    elif isinstance(model, PunctuationModel):
         if isinstance(model.punc_model, TargetDelayTransformer):
             return TargetDelayTransformer_export(model.punc_model, **export_config)
         elif isinstance(model.punc_model, VadRealtimeTransformer):
diff --git a/funasr/export/models/target_delay_transformer.py b/funasr/export/models/target_delay_transformer.py
index fd90835c9..bfe3ec423 100644
--- a/funasr/export/models/target_delay_transformer.py
+++ b/funasr/export/models/target_delay_transformer.py
@@ -1,18 +1,8 @@
-from typing import Any
-from typing import List
 from typing import Tuple
 
 import torch
 import torch.nn as nn
 
-from funasr.export.utils.torch_function import MakePadMask
-from funasr.export.utils.torch_function import sequence_mask
-#from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
-from funasr.punctuation.sanm_encoder import SANMEncoder
-from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
-from funasr.punctuation.abs_model import AbsPunctuation
-
-
 class TargetDelayTransformer(nn.Module):
 
     def __init__(
@@ -32,85 +22,10 @@ class TargetDelayTransformer(nn.Module):
         self.feats_dim = self.embed.embedding_dim
         self.num_embeddings = self.embed.num_embeddings
         self.model_name = model_name
-        from typing import Any
-        from typing import List
-        from typing import Tuple
-
-        import torch
-        import torch.nn as nn
-
-        from funasr.export.utils.torch_function import MakePadMask
-        from funasr.export.utils.torch_function import sequence_mask
         # from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
-        from funasr.punctuation.sanm_encoder import SANMEncoder
+        from funasr.models.encoder.sanm_encoder import SANMEncoder
         from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
-        from funasr.punctuation.abs_model import AbsPunctuation
-
-        # class TargetDelayTransformer(nn.Module):
-        #
-        #     def __init__(
-        #             self,
-        #             model,
-        #             max_seq_len=512,
-        #             model_name='punc_model',
-        #             **kwargs,
-        #     ):
-        #         super().__init__()
-        #         onnx = False
-        #         if "onnx" in kwargs:
-        #             onnx = kwargs["onnx"]
-        #         self.embed = model.embed
-        #         self.decoder = model.decoder
-        #         self.model = model
-        #         self.feats_dim = self.embed.embedding_dim
-        #         self.num_embeddings = self.embed.num_embeddings
-        #         self.model_name = model_name
-        #
-        #         if isinstance(model.encoder, SANMEncoder):
-        #             self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
-        #         else:
-        #             assert False, "Only support samn encode."
-        #
-        #     def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
-        #         """Compute loss value from buffer sequences.
-        #
-        #         Args:
-        #             input (torch.Tensor): Input ids. (batch, len)
-        #             hidden (torch.Tensor): Target ids. (batch, len)
-        #
-        #         """
-        #         x = self.embed(input)
-        #         # mask = self._target_mask(input)
-        #         h, _ = self.encoder(x, text_lengths)
-        #         y = self.decoder(h)
-        #         return y
-        #
-        #     def get_dummy_inputs(self):
-        #         length = 120
-        #         text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
-        #         text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
-        #         return (text_indexes, text_lengths)
-        #
-        #     def get_input_names(self):
-        #         return ['input', 'text_lengths']
-        #
-        #     def get_output_names(self):
-        #         return ['logits']
-        #
-        #     def get_dynamic_axes(self):
-        #         return {
-        #             'input': {
-        #                 0: 'batch_size',
-        #                 1: 'feats_length'
-        #             },
-        #             'text_lengths': {
-        #                 0: 'batch_size',
-        #             },
-        #             'logits': {
-        #                 0: 'batch_size',
-        #                 1: 'logits_length'
-        #             },
-        #         }
 
         if isinstance(model.encoder, SANMEncoder):
             self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
diff --git a/funasr/export/models/vad_realtime_transformer.py b/funasr/export/models/vad_realtime_transformer.py
index 093e71de1..693b9c844 100644
--- a/funasr/export/models/vad_realtime_transformer.py
+++ b/funasr/export/models/vad_realtime_transformer.py
@@ -1,14 +1,9 @@
-from typing import Any
-from typing import List
 from typing import Tuple
 
 import torch
 import torch.nn as nn
 
-from funasr.modules.embedding import SinusoidalPositionEncoder
-from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
-from funasr.punctuation.abs_model import AbsPunctuation
-from funasr.punctuation.sanm_encoder import SANMVadEncoder
+from funasr.models.encoder.sanm_encoder import SANMVadEncoder
 from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
 
 class VadRealtimeTransformer(nn.Module):
diff --git a/funasr/lm/espnet_model.py b/funasr/lm/espnet_model.py
index db11b6741..a9b8130c6 100644
--- a/funasr/lm/espnet_model.py
+++ b/funasr/lm/espnet_model.py
@@ -12,7 +12,7 @@ from funasr.torch_utils.device_funcs import force_gatherable
 from funasr.train.abs_espnet_model import AbsESPnetModel
 
 
-class ESPnetLanguageModel(AbsESPnetModel):
+class LanguageModel(AbsESPnetModel):
     def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
         assert check_argument_types()
         super().__init__()
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 57890efe6..2a3a35353 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -10,7 +10,7 @@ from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
 from typeguard import check_argument_types
 import numpy as np
 from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
+from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
 from funasr.modules.embedding import SinusoidalPositionEncoder
 from funasr.modules.layer_norm import LayerNorm
 from funasr.modules.multi_layer_conv import Conv1dLinear
@@ -27,7 +27,7 @@ from funasr.modules.subsampling import TooShortUttError
 from funasr.modules.subsampling import check_short_utt
 from funasr.models.ctc import CTC
 from funasr.models.encoder.abs_encoder import AbsEncoder
-
+from funasr.modules.mask import subsequent_mask, vad_mask
 
 class EncoderLayerSANM(nn.Module):
     def __init__(
@@ -958,3 +958,231 @@ class SANMEncoderChunkOpt(AbsEncoder):
                                                                var_dict_tf[name_tf].shape))
 
         return var_dict_torch_update
+
+
+class SANMVadEncoder(AbsEncoder):
+    """
+    author: Speech Lab, Alibaba Group, China
+
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        input_layer: Optional[str] = "conv2d",
+        pos_enc_class=SinusoidalPositionEncoder,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 1,
+        padding_idx: int = -1,
+        interctc_layer_idx: List[int] = [],
+        interctc_use_conditioning: bool = False,
+        kernel_size : int = 11,
+        sanm_shfit : int = 0,
+        selfattention_layer_type: str = "sanm",
+    ):
+        assert check_argument_types()
+        super().__init__()
+        self._output_size = output_size
+
+        if input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d":
+            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d2":
+            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d6":
+            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d8":
+            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+        elif input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+                SinusoidalPositionEncoder(),
+            )
+        elif input_layer is None:
+            if input_size == output_size:
+                self.embed = None
+            else:
+                self.embed = torch.nn.Linear(input_size, output_size)
+        elif input_layer == "pe":
+            self.embed = SinusoidalPositionEncoder()
+        else:
+            raise ValueError("unknown input_layer: " + input_layer)
+        self.normalize_before = normalize_before
+        if positionwise_layer_type == "linear":
+            positionwise_layer = PositionwiseFeedForward
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d":
+            positionwise_layer = MultiLayeredConv1d
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d-linear":
+            positionwise_layer = Conv1dLinear
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        else:
+            raise NotImplementedError("Support only linear or conv1d.")
+
+        if selfattention_layer_type == "selfattn":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+
+        elif selfattention_layer_type == "sanm":
+            self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
+            encoder_selfattn_layer_args0 = (
+                attention_heads,
+                input_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+
+        self.encoders0 = repeat(
+            1,
+            lambda lnum: EncoderLayerSANM(
+                input_size,
+                output_size,
+                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+        self.encoders = repeat(
+            num_blocks-1,
+            lambda lnum: EncoderLayerSANM(
+                output_size,
+                output_size,
+                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(output_size)
+
+        self.interctc_layer_idx = interctc_layer_idx
+        if len(interctc_layer_idx) > 0:
+            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+        self.interctc_use_conditioning = interctc_use_conditioning
+        self.conditioning_layer = None
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        vad_indexes: torch.Tensor,
+        prev_states: torch.Tensor = None,
+        ctc: CTC = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Embed positions in tensor.
+
+        Args:
+            xs_pad: input tensor (B, L, D)
+            ilens: input length (B)
+            prev_states: Not to be used now.
+        Returns:
+            position embedded tensor and mask
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
+        no_future_masks = masks & sub_masks
+        xs_pad *= self.output_size()**0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
+              or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling " +
+                    f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+
+        # xs_pad = self.dropout(xs_pad)
+        mask_tup0 = [masks, no_future_masks]
+        encoder_outs = self.encoders0(xs_pad, mask_tup0)
+        xs_pad, _ = encoder_outs[0], encoder_outs[1]
+        intermediate_outs = []
+
+
+        for layer_idx, encoder_layer in enumerate(self.encoders):
+            if layer_idx + 1 == len(self.encoders):
+                # This is last layer.
+                coner_mask = torch.ones(masks.size(0),
+                                        masks.size(-1),
+                                        masks.size(-1),
+                                        device=xs_pad.device,
+                                        dtype=torch.bool)
+                for word_index, length in enumerate(ilens):
+                    coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
+                                                            vad_indexes[word_index],
+                                                            device=xs_pad.device)
+                layer_mask = masks & coner_mask
+            else:
+                layer_mask = no_future_masks
+            mask_tup1 = [masks, layer_mask]
+            encoder_outs = encoder_layer(xs_pad, mask_tup1)
+            xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        olens = masks.squeeze(1).sum(1)
+        if len(intermediate_outs) > 0:
+            return (xs_pad, intermediate_outs), olens, None
+        return xs_pad, olens, None
diff --git a/funasr/punctuation/target_delay_transformer.py b/funasr/models/target_delay_transformer.py
similarity index 97%
rename from funasr/punctuation/target_delay_transformer.py
rename to funasr/models/target_delay_transformer.py
index 219af263f..a71952b15 100644
--- a/funasr/punctuation/target_delay_transformer.py
+++ b/funasr/models/target_delay_transformer.py
@@ -5,12 +5,11 @@ from typing import Tuple
 import torch
 import torch.nn as nn
 
-from funasr.modules.embedding import PositionalEncoding
 from funasr.modules.embedding import SinusoidalPositionEncoder
 #from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
 from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder
 #from funasr.modules.mask import subsequent_n_mask
-from funasr.punctuation.abs_model import AbsPunctuation
+from funasr.train.abs_model import AbsPunctuation
 
 
 class TargetDelayTransformer(AbsPunctuation):
diff --git a/funasr/punctuation/vad_realtime_transformer.py b/funasr/models/vad_realtime_transformer.py
similarity index 98%
rename from funasr/punctuation/vad_realtime_transformer.py
rename to funasr/models/vad_realtime_transformer.py
index 35224f9bd..2945572f5 100644
--- a/funasr/punctuation/vad_realtime_transformer.py
+++ b/funasr/models/vad_realtime_transformer.py
@@ -7,7 +7,7 @@ import torch.nn as nn
 
 from funasr.modules.embedding import SinusoidalPositionEncoder
 from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
-from funasr.punctuation.abs_model import AbsPunctuation
+from funasr.train.abs_model import AbsPunctuation
 
 
 class VadRealtimeTransformer(AbsPunctuation):
diff --git a/funasr/punctuation/abs_model.py b/funasr/punctuation/abs_model.py
deleted file mode 100644
index 404d5e893..000000000
--- a/funasr/punctuation/abs_model.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from abc import ABC
-from abc import abstractmethod
-from typing import Tuple
-
-import torch
-
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-
-
-class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
-    """The abstract class
-
-    To share the loss calculation way among different models,
-    We uses delegate pattern here:
-    The instance of this class should be passed to "LanguageModel"
-
-    >>> from funasr.punctuation.abs_model import AbsPunctuation
-    >>> punc = AbsPunctuation()
-    >>> model = ESPnetPunctuationModel(punc=punc)
-
-    This "model" is one of mediator objects for "Task" class.
- - """ - - @abstractmethod - def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - raise NotImplementedError - - @abstractmethod - def with_vad(self) -> bool: - raise NotImplementedError diff --git a/funasr/punctuation/sanm_encoder.py b/funasr/punctuation/sanm_encoder.py deleted file mode 100644 index 896209323..000000000 --- a/funasr/punctuation/sanm_encoder.py +++ /dev/null @@ -1,590 +0,0 @@ -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union -import logging -import torch -import torch.nn as nn -from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk -from typeguard import check_argument_types -import numpy as np -from funasr.modules.nets_utils import make_pad_mask -from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask -from funasr.modules.embedding import SinusoidalPositionEncoder -from funasr.modules.layer_norm import LayerNorm -from funasr.modules.multi_layer_conv import Conv1dLinear -from funasr.modules.multi_layer_conv import MultiLayeredConv1d -from funasr.modules.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) -from funasr.modules.repeat import repeat -from funasr.modules.subsampling import Conv2dSubsampling -from funasr.modules.subsampling import Conv2dSubsampling2 -from funasr.modules.subsampling import Conv2dSubsampling6 -from funasr.modules.subsampling import Conv2dSubsampling8 -from funasr.modules.subsampling import TooShortUttError -from funasr.modules.subsampling import check_short_utt -from funasr.models.ctc import CTC -from funasr.models.encoder.abs_encoder import AbsEncoder - -from funasr.modules.nets_utils import make_pad_mask -from funasr.modules.mask import subsequent_mask, vad_mask - -class EncoderLayerSANM(nn.Module): - def __init__( - self, - in_size, - size, - self_attn, - feed_forward, - dropout_rate, - normalize_before=True, - concat_after=False, - stochastic_depth_rate=0.0, - ): - """Construct an EncoderLayer object.""" - super(EncoderLayerSANM, self).__init__() - self.self_attn = self_attn - self.feed_forward = feed_forward - self.norm1 = LayerNorm(in_size) - self.norm2 = LayerNorm(size) - self.dropout = nn.Dropout(dropout_rate) - self.in_size = in_size - self.size = size - self.normalize_before = normalize_before - self.concat_after = concat_after - if self.concat_after: - self.concat_linear = nn.Linear(size + size, size) - self.stochastic_depth_rate = stochastic_depth_rate - self.dropout_rate = dropout_rate - - def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None): - """Compute encoded features. - - Args: - x_input (torch.Tensor): Input tensor (#batch, time, size). - mask (torch.Tensor): Mask tensor for the input (#batch, time). - cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). - - Returns: - torch.Tensor: Output tensor (#batch, time, size). - torch.Tensor: Mask tensor (#batch, time). - - """ - skip_layer = False - # with stochastic depth, residual connection `x + f(x)` becomes - # `x <- x + 1 / (1 - p) * f(x)` at training time. 
- stoch_layer_coeff = 1.0 - if self.training and self.stochastic_depth_rate > 0: - skip_layer = torch.rand(1).item() < self.stochastic_depth_rate - stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) - - if skip_layer: - if cache is not None: - x = torch.cat([cache, x], dim=1) - return x, mask - - residual = x - if self.normalize_before: - x = self.norm1(x) - - if self.concat_after: - x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1) - if self.in_size == self.size: - x = residual + stoch_layer_coeff * self.concat_linear(x_concat) - else: - x = stoch_layer_coeff * self.concat_linear(x_concat) - else: - if self.in_size == self.size: - x = residual + stoch_layer_coeff * self.dropout( - self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder) - ) - else: - x = stoch_layer_coeff * self.dropout( - self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder) - ) - if not self.normalize_before: - x = self.norm1(x) - - residual = x - if self.normalize_before: - x = self.norm2(x) - x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) - if not self.normalize_before: - x = self.norm2(x) - - - return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder - -class SANMEncoder(AbsEncoder): - """ - author: Speech Lab, Alibaba Group, China - - """ - - def __init__( - self, - input_size: int, - output_size: int = 256, - attention_heads: int = 4, - linear_units: int = 2048, - num_blocks: int = 6, - dropout_rate: float = 0.1, - positional_dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.0, - input_layer: Optional[str] = "conv2d", - pos_enc_class=SinusoidalPositionEncoder, - normalize_before: bool = True, - concat_after: bool = False, - positionwise_layer_type: str = "linear", - positionwise_conv_kernel_size: int = 1, - padding_idx: int = -1, - interctc_layer_idx: List[int] = [], - interctc_use_conditioning: bool = False, - kernel_size : int = 11, - sanm_shfit : int = 0, - selfattention_layer_type: str = "sanm", - ): - assert check_argument_types() - super().__init__() - self._output_size = output_size - - if input_layer == "linear": - self.embed = torch.nn.Sequential( - torch.nn.Linear(input_size, output_size), - torch.nn.LayerNorm(output_size), - torch.nn.Dropout(dropout_rate), - torch.nn.ReLU(), - pos_enc_class(output_size, positional_dropout_rate), - ) - elif input_layer == "conv2d": - self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate) - elif input_layer == "conv2d2": - self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate) - elif input_layer == "conv2d6": - self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate) - elif input_layer == "conv2d8": - self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate) - elif input_layer == "embed": - self.embed = torch.nn.Sequential( - torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx), - SinusoidalPositionEncoder(), - ) - elif input_layer is None: - if input_size == output_size: - self.embed = None - else: - self.embed = torch.nn.Linear(input_size, output_size) - elif input_layer == "pe": - self.embed = SinusoidalPositionEncoder() - else: - raise ValueError("unknown input_layer: " + input_layer) - self.normalize_before = normalize_before - if positionwise_layer_type == "linear": - positionwise_layer = PositionwiseFeedForward - positionwise_layer_args = ( - output_size, 
- linear_units, - dropout_rate, - ) - elif positionwise_layer_type == "conv1d": - positionwise_layer = MultiLayeredConv1d - positionwise_layer_args = ( - output_size, - linear_units, - positionwise_conv_kernel_size, - dropout_rate, - ) - elif positionwise_layer_type == "conv1d-linear": - positionwise_layer = Conv1dLinear - positionwise_layer_args = ( - output_size, - linear_units, - positionwise_conv_kernel_size, - dropout_rate, - ) - else: - raise NotImplementedError("Support only linear or conv1d.") - - if selfattention_layer_type == "selfattn": - encoder_selfattn_layer = MultiHeadedAttention - encoder_selfattn_layer_args = ( - attention_heads, - output_size, - attention_dropout_rate, - ) - - elif selfattention_layer_type == "sanm": - self.encoder_selfattn_layer = MultiHeadedAttentionSANM - encoder_selfattn_layer_args0 = ( - attention_heads, - input_size, - output_size, - attention_dropout_rate, - kernel_size, - sanm_shfit, - ) - - encoder_selfattn_layer_args = ( - attention_heads, - output_size, - output_size, - attention_dropout_rate, - kernel_size, - sanm_shfit, - ) - - self.encoders0 = repeat( - 1, - lambda lnum: EncoderLayerSANM( - input_size, - output_size, - self.encoder_selfattn_layer(*encoder_selfattn_layer_args0), - positionwise_layer(*positionwise_layer_args), - dropout_rate, - normalize_before, - concat_after, - ), - ) - - self.encoders = repeat( - num_blocks-1, - lambda lnum: EncoderLayerSANM( - output_size, - output_size, - self.encoder_selfattn_layer(*encoder_selfattn_layer_args), - positionwise_layer(*positionwise_layer_args), - dropout_rate, - normalize_before, - concat_after, - ), - ) - if self.normalize_before: - self.after_norm = LayerNorm(output_size) - - self.interctc_layer_idx = interctc_layer_idx - if len(interctc_layer_idx) > 0: - assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks - self.interctc_use_conditioning = interctc_use_conditioning - self.conditioning_layer = None - self.dropout = nn.Dropout(dropout_rate) - - def output_size(self) -> int: - return self._output_size - - def forward( - self, - xs_pad: torch.Tensor, - ilens: torch.Tensor, - prev_states: torch.Tensor = None, - ctc: CTC = None, - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """Embed positions in tensor. - - Args: - xs_pad: input tensor (B, L, D) - ilens: input length (B) - prev_states: Not to be used now. 
- Returns: - position embedded tensor and mask - """ - masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) - xs_pad *= self.output_size()**0.5 - if self.embed is None: - xs_pad = xs_pad - elif ( - isinstance(self.embed, Conv2dSubsampling) - or isinstance(self.embed, Conv2dSubsampling2) - or isinstance(self.embed, Conv2dSubsampling6) - or isinstance(self.embed, Conv2dSubsampling8) - ): - short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) - if short_status: - raise TooShortUttError( - f"has {xs_pad.size(1)} frames and is too short for subsampling " - + f"(it needs more than {limit_size} frames), return empty results", - xs_pad.size(1), - limit_size, - ) - xs_pad, masks = self.embed(xs_pad, masks) - else: - xs_pad = self.embed(xs_pad) - - # xs_pad = self.dropout(xs_pad) - encoder_outs = self.encoders0(xs_pad, masks) - xs_pad, masks = encoder_outs[0], encoder_outs[1] - intermediate_outs = [] - if len(self.interctc_layer_idx) == 0: - encoder_outs = self.encoders(xs_pad, masks) - xs_pad, masks = encoder_outs[0], encoder_outs[1] - else: - for layer_idx, encoder_layer in enumerate(self.encoders): - encoder_outs = encoder_layer(xs_pad, masks) - xs_pad, masks = encoder_outs[0], encoder_outs[1] - - if layer_idx + 1 in self.interctc_layer_idx: - encoder_out = xs_pad - - # intermediate outputs are also normalized - if self.normalize_before: - encoder_out = self.after_norm(encoder_out) - - intermediate_outs.append((layer_idx + 1, encoder_out)) - - if self.interctc_use_conditioning: - ctc_out = ctc.softmax(encoder_out) - xs_pad = xs_pad + self.conditioning_layer(ctc_out) - - if self.normalize_before: - xs_pad = self.after_norm(xs_pad) - - olens = masks.squeeze(1).sum(1) - if len(intermediate_outs) > 0: - return (xs_pad, intermediate_outs), olens, None - return xs_pad, olens, None - -class SANMVadEncoder(AbsEncoder): - """ - author: Speech Lab, Alibaba Group, China - - """ - - def __init__( - self, - input_size: int, - output_size: int = 256, - attention_heads: int = 4, - linear_units: int = 2048, - num_blocks: int = 6, - dropout_rate: float = 0.1, - positional_dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.0, - input_layer: Optional[str] = "conv2d", - pos_enc_class=SinusoidalPositionEncoder, - normalize_before: bool = True, - concat_after: bool = False, - positionwise_layer_type: str = "linear", - positionwise_conv_kernel_size: int = 1, - padding_idx: int = -1, - interctc_layer_idx: List[int] = [], - interctc_use_conditioning: bool = False, - kernel_size : int = 11, - sanm_shfit : int = 0, - selfattention_layer_type: str = "sanm", - ): - assert check_argument_types() - super().__init__() - self._output_size = output_size - - if input_layer == "linear": - self.embed = torch.nn.Sequential( - torch.nn.Linear(input_size, output_size), - torch.nn.LayerNorm(output_size), - torch.nn.Dropout(dropout_rate), - torch.nn.ReLU(), - pos_enc_class(output_size, positional_dropout_rate), - ) - elif input_layer == "conv2d": - self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate) - elif input_layer == "conv2d2": - self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate) - elif input_layer == "conv2d6": - self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate) - elif input_layer == "conv2d8": - self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate) - elif input_layer == "embed": - self.embed = torch.nn.Sequential( - torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx), - SinusoidalPositionEncoder(), - ) 
- elif input_layer is None: - if input_size == output_size: - self.embed = None - else: - self.embed = torch.nn.Linear(input_size, output_size) - elif input_layer == "pe": - self.embed = SinusoidalPositionEncoder() - else: - raise ValueError("unknown input_layer: " + input_layer) - self.normalize_before = normalize_before - if positionwise_layer_type == "linear": - positionwise_layer = PositionwiseFeedForward - positionwise_layer_args = ( - output_size, - linear_units, - dropout_rate, - ) - elif positionwise_layer_type == "conv1d": - positionwise_layer = MultiLayeredConv1d - positionwise_layer_args = ( - output_size, - linear_units, - positionwise_conv_kernel_size, - dropout_rate, - ) - elif positionwise_layer_type == "conv1d-linear": - positionwise_layer = Conv1dLinear - positionwise_layer_args = ( - output_size, - linear_units, - positionwise_conv_kernel_size, - dropout_rate, - ) - else: - raise NotImplementedError("Support only linear or conv1d.") - - if selfattention_layer_type == "selfattn": - encoder_selfattn_layer = MultiHeadedAttention - encoder_selfattn_layer_args = ( - attention_heads, - output_size, - attention_dropout_rate, - ) - - elif selfattention_layer_type == "sanm": - self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask - encoder_selfattn_layer_args0 = ( - attention_heads, - input_size, - output_size, - attention_dropout_rate, - kernel_size, - sanm_shfit, - ) - - encoder_selfattn_layer_args = ( - attention_heads, - output_size, - output_size, - attention_dropout_rate, - kernel_size, - sanm_shfit, - ) - - self.encoders0 = repeat( - 1, - lambda lnum: EncoderLayerSANM( - input_size, - output_size, - self.encoder_selfattn_layer(*encoder_selfattn_layer_args0), - positionwise_layer(*positionwise_layer_args), - dropout_rate, - normalize_before, - concat_after, - ), - ) - - self.encoders = repeat( - num_blocks-1, - lambda lnum: EncoderLayerSANM( - output_size, - output_size, - self.encoder_selfattn_layer(*encoder_selfattn_layer_args), - positionwise_layer(*positionwise_layer_args), - dropout_rate, - normalize_before, - concat_after, - ), - ) - if self.normalize_before: - self.after_norm = LayerNorm(output_size) - - self.interctc_layer_idx = interctc_layer_idx - if len(interctc_layer_idx) > 0: - assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks - self.interctc_use_conditioning = interctc_use_conditioning - self.conditioning_layer = None - self.dropout = nn.Dropout(dropout_rate) - - def output_size(self) -> int: - return self._output_size - - def forward( - self, - xs_pad: torch.Tensor, - ilens: torch.Tensor, - vad_indexes: torch.Tensor, - prev_states: torch.Tensor = None, - ctc: CTC = None, - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """Embed positions in tensor. - - Args: - xs_pad: input tensor (B, L, D) - ilens: input length (B) - prev_states: Not to be used now. 
-        Returns:
-            position embedded tensor and mask
-        """
-        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
-        no_future_masks = masks & sub_masks
-        xs_pad *= self.output_size()**0.5
-        if self.embed is None:
-            xs_pad = xs_pad
-        elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
-              or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
-            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
-            if short_status:
-                raise TooShortUttError(
-                    f"has {xs_pad.size(1)} frames and is too short for subsampling " +
-                    f"(it needs more than {limit_size} frames), return empty results",
-                    xs_pad.size(1),
-                    limit_size,
-                )
-            xs_pad, masks = self.embed(xs_pad, masks)
-        else:
-            xs_pad = self.embed(xs_pad)
-
-        # xs_pad = self.dropout(xs_pad)
-        mask_tup0 = [masks, no_future_masks]
-        encoder_outs = self.encoders0(xs_pad, mask_tup0)
-        xs_pad, _ = encoder_outs[0], encoder_outs[1]
-        intermediate_outs = []
-        #if len(self.interctc_layer_idx) == 0:
-        if False:
-            # Here, we should not use the repeat operation to do it for all layers.
-            encoder_outs = self.encoders(xs_pad, masks)
-            xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        else:
-            for layer_idx, encoder_layer in enumerate(self.encoders):
-                if layer_idx + 1 == len(self.encoders):
-                    # This is last layer.
-                    coner_mask = torch.ones(masks.size(0),
-                                            masks.size(-1),
-                                            masks.size(-1),
-                                            device=xs_pad.device,
-                                            dtype=torch.bool)
-                    for word_index, length in enumerate(ilens):
-                        coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
-                                                                vad_indexes[word_index],
-                                                                device=xs_pad.device)
-                    layer_mask = masks & coner_mask
-                else:
-                    layer_mask = no_future_masks
-                mask_tup1 = [masks, layer_mask]
-                encoder_outs = encoder_layer(xs_pad, mask_tup1)
-                xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
-
-                if layer_idx + 1 in self.interctc_layer_idx:
-                    encoder_out = xs_pad
-
-                    # intermediate outputs are also normalized
-                    if self.normalize_before:
-                        encoder_out = self.after_norm(encoder_out)
-
-                    intermediate_outs.append((layer_idx + 1, encoder_out))
-
-                    if self.interctc_use_conditioning:
-                        ctc_out = ctc.softmax(encoder_out)
-                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
-        if self.normalize_before:
-            xs_pad = self.after_norm(xs_pad)
-
-        olens = masks.squeeze(1).sum(1)
-        if len(intermediate_outs) > 0:
-            return (xs_pad, intermediate_outs), olens, None
-        return xs_pad, olens, None
-
diff --git a/funasr/punctuation/text_preprocessor.py b/funasr/punctuation/text_preprocessor.py
index c9c4bac57..8b1378917 100644
--- a/funasr/punctuation/text_preprocessor.py
+++ b/funasr/punctuation/text_preprocessor.py
@@ -1,12 +1 @@
-def split_to_mini_sentence(words: list, word_limit: int = 20):
-    assert word_limit > 1
-    if len(words) <= word_limit:
-        return [words]
-    sentences = []
-    length = len(words)
-    sentence_len = length // word_limit
-    for i in range(sentence_len):
-        sentences.append(words[i * word_limit:(i + 1) * word_limit])
-    if length % word_limit > 0:
-        sentences.append(words[sentence_len * word_limit:])
-    return sentences
+
diff --git a/funasr/tasks/lm.py b/funasr/tasks/lm.py
index 608c1d3eb..dc8fd3e25 100644
--- a/funasr/tasks/lm.py
+++ b/funasr/tasks/lm.py
@@ -15,7 +15,7 @@ from typeguard import check_return_type
 from funasr.datasets.collate_fn import CommonCollateFn
 from funasr.datasets.preprocessor import CommonPreprocessor
 from funasr.lm.abs_model import AbsLM
-from funasr.lm.espnet_model import ESPnetLanguageModel
+from funasr.lm.espnet_model import LanguageModel
 from funasr.lm.seq_rnn_lm import SequentialRNNLM
 from funasr.lm.transformer_lm import TransformerLM
 from funasr.tasks.abs_task import AbsTask
@@ -83,7 +83,7 @@ class LMTask(AbsTask):
         group.add_argument(
             "--model_conf",
             action=NestedDictAction,
-            default=get_default_kwargs(ESPnetLanguageModel),
+            default=get_default_kwargs(LanguageModel),
             help="The keyword arguments for model class.",
         )
 
@@ -178,7 +178,7 @@ class LMTask(AbsTask):
         return retval
 
     @classmethod
-    def build_model(cls, args: argparse.Namespace) -> ESPnetLanguageModel:
+    def build_model(cls, args: argparse.Namespace) -> LanguageModel:
         assert check_argument_types()
         if isinstance(args.token_list, str):
             with open(args.token_list, encoding="utf-8") as f:
@@ -201,7 +201,7 @@ class LMTask(AbsTask):
 
         # 2. Build ESPnetModel
         # Assume the last-id is sos_and_eos
-        model = ESPnetLanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
+        model = LanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
 
         # 3. Initialize
         if args.init is not None:
diff --git a/funasr/tasks/punctuation.py b/funasr/tasks/punctuation.py
index ea1e10284..0170f28a8 100644
--- a/funasr/tasks/punctuation.py
+++ b/funasr/tasks/punctuation.py
@@ -14,10 +14,10 @@ from typeguard import check_return_type
 
 from funasr.datasets.collate_fn import CommonCollateFn
 from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
-from funasr.punctuation.abs_model import AbsPunctuation
-from funasr.punctuation.espnet_model import ESPnetPunctuationModel
-from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
-from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
+from funasr.train.abs_model import AbsPunctuation
+from funasr.train.abs_model import PunctuationModel
+from funasr.models.target_delay_transformer import TargetDelayTransformer
+from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
 from funasr.tasks.abs_task import AbsTask
 from funasr.text.phoneme_tokenizer import g2p_choices
 from funasr.torch_utils.initialize import initialize
@@ -79,7 +79,7 @@ class PunctuationTask(AbsTask):
         group.add_argument(
             "--model_conf",
             action=NestedDictAction,
-            default=get_default_kwargs(ESPnetPunctuationModel),
+            default=get_default_kwargs(PunctuationModel),
             help="The keyword arguments for model class.",
         )
 
@@ -183,7 +183,7 @@ class PunctuationTask(AbsTask):
         return retval
 
     @classmethod
-    def build_model(cls, args: argparse.Namespace) -> ESPnetPunctuationModel:
+    def build_model(cls, args: argparse.Namespace) -> PunctuationModel:
         assert check_argument_types()
         if isinstance(args.token_list, str):
             with open(args.token_list, encoding="utf-8") as f:
@@ -218,7 +218,7 @@ class PunctuationTask(AbsTask):
         # Assume the last-id is sos_and_eos
         if "punc_weight" in args.model_conf:
             args.model_conf.pop("punc_weight")
-        model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
+        model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
 
         # FIXME(kamo): Should be done in model?
         # 3. Initialize
diff --git a/funasr/punctuation/espnet_model.py b/funasr/train/abs_model.py
similarity index 86%
rename from funasr/punctuation/espnet_model.py
rename to funasr/train/abs_model.py
index 7266b387d..8bfba4513 100644
--- a/funasr/punctuation/espnet_model.py
+++ b/funasr/train/abs_model.py
@@ -1,3 +1,9 @@
+from abc import ABC
+from abc import abstractmethod
+from typing import Tuple
+
+import torch
+
 from typing import Dict
 from typing import Optional
 from typing import Tuple
@@ -7,13 +13,34 @@ import torch.nn.functional as F
 from typeguard import check_argument_types
 
 from funasr.modules.nets_utils import make_pad_mask
-from funasr.punctuation.abs_model import AbsPunctuation
 from funasr.torch_utils.device_funcs import force_gatherable
 from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
 
 
-class ESPnetPunctuationModel(AbsESPnetModel):
+class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
+    """The abstract class
+
+    To share the loss calculation way among different models,
+    We use the delegate pattern here:
+    The instance of this class should be passed to "LanguageModel"
+
+    This "model" is one of mediator objects for "Task" class.
+
+    """
+
+    @abstractmethod
+    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def with_vad(self) -> bool:
+        raise NotImplementedError
+
+
+class PunctuationModel(AbsESPnetModel):
+
     def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
         assert check_argument_types()
         super().__init__()
@@ -21,12 +48,12 @@ class ESPnetPunctuationModel(AbsESPnetModel):
         self.punc_weight = torch.Tensor(punc_weight)
         self.sos = 1
         self.eos = 2
-
+
         # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
         self.ignore_id = ignore_id
-        #if self.punc_model.with_vad():
+        # if self.punc_model.with_vad():
         #     print("This is a vad puncuation model.")
-
+
     def nll(
         self,
         text: torch.Tensor,
@@ -54,7 +81,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
         else:
             text = text[:, :max_length]
             punc = punc[:, :max_length]
-
+
         if self.punc_model.with_vad():
             # Should be VadRealtimeTransformer
             assert vad_indexes is not None
@@ -62,7 +89,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
         else:
             # Should be TargetDelayTransformer,
             y, _ = self.punc_model(text, text_lengths)
-
+
         # Calc negative log likelihood
         # nll: (BxL,)
         if self.training == False:
@@ -75,7 +102,8 @@ class ESPnetPunctuationModel(AbsESPnetModel):
             return nll, text_lengths
         else:
             self.punc_weight = self.punc_weight.to(punc.device)
-            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id)
+            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
+                                  ignore_index=self.ignore_id)
             # nll: (BxL,) -> (BxL,)
             if max_length is None:
                 nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
@@ -87,7 +115,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
         # nll: (BxL,) -> (B, L)
         nll = nll.view(batch_size, -1)
         return nll, text_lengths
-
+
     def batchify_nll(self,
                      text: torch.Tensor,
                      punc: torch.Tensor,
@@ -113,7 +141,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
         nlls = []
         x_lengths = []
         max_length = text_lengths.max()
-
+
         start_idx = 0
         while True:
             end_idx = min(start_idx + batch_size, total_num)
@@ -132,7 +160,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
         assert nll.size(0) == total_num
         assert x_lengths.size(0) == total_num
         return nll, x_lengths
-
+
     def forward(
         self,
         text: torch.Tensor,
@@ -146,15 +174,15 @@ class ESPnetPunctuationModel(AbsESPnetModel):
         ntokens = y_lengths.sum()
         loss = nll.sum() / ntokens
         stats = dict(loss=loss.detach())
-
+
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
         return loss, stats, weight
-
+
     def collect_feats(self, text: torch.Tensor, punc: torch.Tensor, text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
         return {}
-
+
     def inference(self,
                   text: torch.Tensor,
                   text_lengths: torch.Tensor,