mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
export
This commit is contained in:
parent
3cd71a385a
commit
d0cd484fdc
@ -23,7 +23,7 @@ from funasr.torch_utils.set_all_random_seed import set_all_random_seed
|
||||
from funasr.utils import config_argparse
|
||||
from funasr.utils.types import str2triple_str
|
||||
from funasr.utils.types import str_or_none
|
||||
from funasr.punctuation.text_preprocessor import split_to_mini_sentence
|
||||
from funasr.datasets.preprocessor import split_to_mini_sentence
|
||||
|
||||
|
||||
class Text2Punc:
|
||||
|
||||
@ -23,7 +23,7 @@ from funasr.torch_utils.set_all_random_seed import set_all_random_seed
|
||||
from funasr.utils import config_argparse
|
||||
from funasr.utils.types import str2triple_str
|
||||
from funasr.utils.types import str_or_none
|
||||
from funasr.punctuation.text_preprocessor import split_to_mini_sentence
|
||||
from funasr.datasets.preprocessor import split_to_mini_sentence
|
||||
|
||||
|
||||
class Text2Punc:
|
||||
|
||||
@ -800,3 +800,17 @@ class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
|
||||
data[self.vad_name] = np.array([vad], dtype=np.int64)
|
||||
text_ints = self.token_id_converter[i].tokens2ids(tokens)
|
||||
data[text_name] = np.array(text_ints, dtype=np.int64)
|
||||
|
||||
|
||||
def split_to_mini_sentence(words: list, word_limit: int = 20):
|
||||
assert word_limit > 1
|
||||
if len(words) <= word_limit:
|
||||
return [words]
|
||||
sentences = []
|
||||
length = len(words)
|
||||
sentence_len = length // word_limit
|
||||
for i in range(sentence_len):
|
||||
sentences.append(words[i * word_limit:(i + 1) * word_limit])
|
||||
if length % word_limit > 0:
|
||||
sentences.append(words[sentence_len * word_limit:])
|
||||
return sentences
|
||||
@ -3,10 +3,10 @@ from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_exp
|
||||
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
|
||||
from funasr.models.e2e_vad import E2EVadModel
|
||||
from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
|
||||
from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
|
||||
from funasr.models.target_delay_transformer import TargetDelayTransformer
|
||||
from funasr.export.models.target_delay_transformer import TargetDelayTransformer as TargetDelayTransformer_export
|
||||
from funasr.punctuation.espnet_model import ESPnetPunctuationModel
|
||||
from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
|
||||
from funasr.train.abs_model import PunctuationModel
|
||||
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
|
||||
from funasr.export.models.vad_realtime_transformer import VadRealtimeTransformer as VadRealtimeTransformer_export
|
||||
|
||||
def get_model(model, export_config=None):
|
||||
@ -16,7 +16,7 @@ def get_model(model, export_config=None):
|
||||
return Paraformer_export(model, **export_config)
|
||||
elif isinstance(model, E2EVadModel):
|
||||
return E2EVadModel_export(model, **export_config)
|
||||
elif isinstance(model, ESPnetPunctuationModel):
|
||||
elif isinstance(model, PunctuationModel):
|
||||
if isinstance(model.punc_model, TargetDelayTransformer):
|
||||
return TargetDelayTransformer_export(model.punc_model, **export_config)
|
||||
elif isinstance(model.punc_model, VadRealtimeTransformer):
|
||||
|
||||
@ -1,18 +1,8 @@
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from funasr.export.utils.torch_function import MakePadMask
|
||||
from funasr.export.utils.torch_function import sequence_mask
|
||||
#from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
|
||||
from funasr.punctuation.sanm_encoder import SANMEncoder
|
||||
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
|
||||
from funasr.punctuation.abs_model import AbsPunctuation
|
||||
|
||||
|
||||
class TargetDelayTransformer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@ -32,85 +22,10 @@ class TargetDelayTransformer(nn.Module):
|
||||
self.feats_dim = self.embed.embedding_dim
|
||||
self.num_embeddings = self.embed.num_embeddings
|
||||
self.model_name = model_name
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from funasr.export.utils.torch_function import MakePadMask
|
||||
from funasr.export.utils.torch_function import sequence_mask
|
||||
# from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
|
||||
from funasr.punctuation.sanm_encoder import SANMEncoder
|
||||
from funasr.models.encoder.sanm_encoder import SANMEncoder
|
||||
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
|
||||
from funasr.punctuation.abs_model import AbsPunctuation
|
||||
|
||||
# class TargetDelayTransformer(nn.Module):
|
||||
#
|
||||
# def __init__(
|
||||
# self,
|
||||
# model,
|
||||
# max_seq_len=512,
|
||||
# model_name='punc_model',
|
||||
# **kwargs,
|
||||
# ):
|
||||
# super().__init__()
|
||||
# onnx = False
|
||||
# if "onnx" in kwargs:
|
||||
# onnx = kwargs["onnx"]
|
||||
# self.embed = model.embed
|
||||
# self.decoder = model.decoder
|
||||
# self.model = model
|
||||
# self.feats_dim = self.embed.embedding_dim
|
||||
# self.num_embeddings = self.embed.num_embeddings
|
||||
# self.model_name = model_name
|
||||
#
|
||||
# if isinstance(model.encoder, SANMEncoder):
|
||||
# self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
|
||||
# else:
|
||||
# assert False, "Only support samn encode."
|
||||
#
|
||||
# def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
|
||||
# """Compute loss value from buffer sequences.
|
||||
#
|
||||
# Args:
|
||||
# input (torch.Tensor): Input ids. (batch, len)
|
||||
# hidden (torch.Tensor): Target ids. (batch, len)
|
||||
#
|
||||
# """
|
||||
# x = self.embed(input)
|
||||
# # mask = self._target_mask(input)
|
||||
# h, _ = self.encoder(x, text_lengths)
|
||||
# y = self.decoder(h)
|
||||
# return y
|
||||
#
|
||||
# def get_dummy_inputs(self):
|
||||
# length = 120
|
||||
# text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
|
||||
# text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
|
||||
# return (text_indexes, text_lengths)
|
||||
#
|
||||
# def get_input_names(self):
|
||||
# return ['input', 'text_lengths']
|
||||
#
|
||||
# def get_output_names(self):
|
||||
# return ['logits']
|
||||
#
|
||||
# def get_dynamic_axes(self):
|
||||
# return {
|
||||
# 'input': {
|
||||
# 0: 'batch_size',
|
||||
# 1: 'feats_length'
|
||||
# },
|
||||
# 'text_lengths': {
|
||||
# 0: 'batch_size',
|
||||
# },
|
||||
# 'logits': {
|
||||
# 0: 'batch_size',
|
||||
# 1: 'logits_length'
|
||||
# },
|
||||
# }
|
||||
|
||||
if isinstance(model.encoder, SANMEncoder):
|
||||
self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
|
||||
|
||||
@ -1,14 +1,9 @@
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from funasr.modules.embedding import SinusoidalPositionEncoder
|
||||
from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
|
||||
from funasr.punctuation.abs_model import AbsPunctuation
|
||||
from funasr.punctuation.sanm_encoder import SANMVadEncoder
|
||||
from funasr.models.encoder.sanm_encoder import SANMVadEncoder
|
||||
from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
|
||||
|
||||
class VadRealtimeTransformer(nn.Module):
|
||||
|
||||
@ -12,7 +12,7 @@ from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.train.abs_espnet_model import AbsESPnetModel
|
||||
|
||||
|
||||
class ESPnetLanguageModel(AbsESPnetModel):
|
||||
class LanguageModel(AbsESPnetModel):
|
||||
def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
|
||||
@ -10,7 +10,7 @@ from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
|
||||
from typeguard import check_argument_types
|
||||
import numpy as np
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
|
||||
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
|
||||
from funasr.modules.embedding import SinusoidalPositionEncoder
|
||||
from funasr.modules.layer_norm import LayerNorm
|
||||
from funasr.modules.multi_layer_conv import Conv1dLinear
|
||||
@ -27,7 +27,7 @@ from funasr.modules.subsampling import TooShortUttError
|
||||
from funasr.modules.subsampling import check_short_utt
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
|
||||
from funasr.modules.mask import subsequent_mask, vad_mask
|
||||
|
||||
class EncoderLayerSANM(nn.Module):
|
||||
def __init__(
|
||||
@ -958,3 +958,231 @@ class SANMEncoderChunkOpt(AbsEncoder):
|
||||
var_dict_tf[name_tf].shape))
|
||||
|
||||
return var_dict_torch_update
|
||||
|
||||
|
||||
class SANMVadEncoder(AbsEncoder):
|
||||
"""
|
||||
author: Speech Lab, Alibaba Group, China
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
input_layer: Optional[str] = "conv2d",
|
||||
pos_enc_class=SinusoidalPositionEncoder,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
positionwise_layer_type: str = "linear",
|
||||
positionwise_conv_kernel_size: int = 1,
|
||||
padding_idx: int = -1,
|
||||
interctc_layer_idx: List[int] = [],
|
||||
interctc_use_conditioning: bool = False,
|
||||
kernel_size : int = 11,
|
||||
sanm_shfit : int = 0,
|
||||
selfattention_layer_type: str = "sanm",
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
self._output_size = output_size
|
||||
|
||||
if input_layer == "linear":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(input_size, output_size),
|
||||
torch.nn.LayerNorm(output_size),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
torch.nn.ReLU(),
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == "conv2d":
|
||||
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "conv2d2":
|
||||
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "conv2d6":
|
||||
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "conv2d8":
|
||||
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "embed":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
|
||||
SinusoidalPositionEncoder(),
|
||||
)
|
||||
elif input_layer is None:
|
||||
if input_size == output_size:
|
||||
self.embed = None
|
||||
else:
|
||||
self.embed = torch.nn.Linear(input_size, output_size)
|
||||
elif input_layer == "pe":
|
||||
self.embed = SinusoidalPositionEncoder()
|
||||
else:
|
||||
raise ValueError("unknown input_layer: " + input_layer)
|
||||
self.normalize_before = normalize_before
|
||||
if positionwise_layer_type == "linear":
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d":
|
||||
positionwise_layer = MultiLayeredConv1d
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d-linear":
|
||||
positionwise_layer = Conv1dLinear
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Support only linear or conv1d.")
|
||||
|
||||
if selfattention_layer_type == "selfattn":
|
||||
encoder_selfattn_layer = MultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
|
||||
elif selfattention_layer_type == "sanm":
|
||||
self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
|
||||
encoder_selfattn_layer_args0 = (
|
||||
attention_heads,
|
||||
input_size,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit,
|
||||
)
|
||||
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit,
|
||||
)
|
||||
|
||||
self.encoders0 = repeat(
|
||||
1,
|
||||
lambda lnum: EncoderLayerSANM(
|
||||
input_size,
|
||||
output_size,
|
||||
self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks-1,
|
||||
lambda lnum: EncoderLayerSANM(
|
||||
output_size,
|
||||
output_size,
|
||||
self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(output_size)
|
||||
|
||||
self.interctc_layer_idx = interctc_layer_idx
|
||||
if len(interctc_layer_idx) > 0:
|
||||
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
|
||||
self.interctc_use_conditioning = interctc_use_conditioning
|
||||
self.conditioning_layer = None
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self._output_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
vad_indexes: torch.Tensor,
|
||||
prev_states: torch.Tensor = None,
|
||||
ctc: CTC = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Embed positions in tensor.
|
||||
|
||||
Args:
|
||||
xs_pad: input tensor (B, L, D)
|
||||
ilens: input length (B)
|
||||
prev_states: Not to be used now.
|
||||
Returns:
|
||||
position embedded tensor and mask
|
||||
"""
|
||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
|
||||
sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
|
||||
no_future_masks = masks & sub_masks
|
||||
xs_pad *= self.output_size()**0.5
|
||||
if self.embed is None:
|
||||
xs_pad = xs_pad
|
||||
elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
|
||||
or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
|
||||
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
|
||||
if short_status:
|
||||
raise TooShortUttError(
|
||||
f"has {xs_pad.size(1)} frames and is too short for subsampling " +
|
||||
f"(it needs more than {limit_size} frames), return empty results",
|
||||
xs_pad.size(1),
|
||||
limit_size,
|
||||
)
|
||||
xs_pad, masks = self.embed(xs_pad, masks)
|
||||
else:
|
||||
xs_pad = self.embed(xs_pad)
|
||||
|
||||
# xs_pad = self.dropout(xs_pad)
|
||||
mask_tup0 = [masks, no_future_masks]
|
||||
encoder_outs = self.encoders0(xs_pad, mask_tup0)
|
||||
xs_pad, _ = encoder_outs[0], encoder_outs[1]
|
||||
intermediate_outs = []
|
||||
|
||||
|
||||
for layer_idx, encoder_layer in enumerate(self.encoders):
|
||||
if layer_idx + 1 == len(self.encoders):
|
||||
# This is last layer.
|
||||
coner_mask = torch.ones(masks.size(0),
|
||||
masks.size(-1),
|
||||
masks.size(-1),
|
||||
device=xs_pad.device,
|
||||
dtype=torch.bool)
|
||||
for word_index, length in enumerate(ilens):
|
||||
coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
|
||||
vad_indexes[word_index],
|
||||
device=xs_pad.device)
|
||||
layer_mask = masks & coner_mask
|
||||
else:
|
||||
layer_mask = no_future_masks
|
||||
mask_tup1 = [masks, layer_mask]
|
||||
encoder_outs = encoder_layer(xs_pad, mask_tup1)
|
||||
xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
|
||||
|
||||
if self.normalize_before:
|
||||
xs_pad = self.after_norm(xs_pad)
|
||||
|
||||
olens = masks.squeeze(1).sum(1)
|
||||
if len(intermediate_outs) > 0:
|
||||
return (xs_pad, intermediate_outs), olens, None
|
||||
return xs_pad, olens, None
|
||||
|
||||
@ -5,12 +5,11 @@ from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from funasr.modules.embedding import PositionalEncoding
|
||||
from funasr.modules.embedding import SinusoidalPositionEncoder
|
||||
#from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
|
||||
from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder
|
||||
#from funasr.modules.mask import subsequent_n_mask
|
||||
from funasr.punctuation.abs_model import AbsPunctuation
|
||||
from funasr.train.abs_model import AbsPunctuation
|
||||
|
||||
|
||||
class TargetDelayTransformer(AbsPunctuation):
|
||||
@ -7,7 +7,7 @@ import torch.nn as nn
|
||||
|
||||
from funasr.modules.embedding import SinusoidalPositionEncoder
|
||||
from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
|
||||
from funasr.punctuation.abs_model import AbsPunctuation
|
||||
from funasr.train.abs_model import AbsPunctuation
|
||||
|
||||
|
||||
class VadRealtimeTransformer(AbsPunctuation):
|
||||
@ -1,31 +0,0 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
|
||||
|
||||
|
||||
class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
|
||||
"""The abstract class
|
||||
|
||||
To share the loss calculation way among different models,
|
||||
We uses delegate pattern here:
|
||||
The instance of this class should be passed to "LanguageModel"
|
||||
|
||||
>>> from funasr.punctuation.abs_model import AbsPunctuation
|
||||
>>> punc = AbsPunctuation()
|
||||
>>> model = ESPnetPunctuationModel(punc=punc)
|
||||
|
||||
This "model" is one of mediator objects for "Task" class.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def with_vad(self) -> bool:
|
||||
raise NotImplementedError
|
||||
@ -1,590 +0,0 @@
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
|
||||
from typeguard import check_argument_types
|
||||
import numpy as np
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
|
||||
from funasr.modules.embedding import SinusoidalPositionEncoder
|
||||
from funasr.modules.layer_norm import LayerNorm
|
||||
from funasr.modules.multi_layer_conv import Conv1dLinear
|
||||
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
|
||||
from funasr.modules.positionwise_feed_forward import (
|
||||
PositionwiseFeedForward, # noqa: H301
|
||||
)
|
||||
from funasr.modules.repeat import repeat
|
||||
from funasr.modules.subsampling import Conv2dSubsampling
|
||||
from funasr.modules.subsampling import Conv2dSubsampling2
|
||||
from funasr.modules.subsampling import Conv2dSubsampling6
|
||||
from funasr.modules.subsampling import Conv2dSubsampling8
|
||||
from funasr.modules.subsampling import TooShortUttError
|
||||
from funasr.modules.subsampling import check_short_utt
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
from funasr.modules.mask import subsequent_mask, vad_mask
|
||||
|
||||
class EncoderLayerSANM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_size,
|
||||
size,
|
||||
self_attn,
|
||||
feed_forward,
|
||||
dropout_rate,
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
stochastic_depth_rate=0.0,
|
||||
):
|
||||
"""Construct an EncoderLayer object."""
|
||||
super(EncoderLayerSANM, self).__init__()
|
||||
self.self_attn = self_attn
|
||||
self.feed_forward = feed_forward
|
||||
self.norm1 = LayerNorm(in_size)
|
||||
self.norm2 = LayerNorm(size)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.in_size = in_size
|
||||
self.size = size
|
||||
self.normalize_before = normalize_before
|
||||
self.concat_after = concat_after
|
||||
if self.concat_after:
|
||||
self.concat_linear = nn.Linear(size + size, size)
|
||||
self.stochastic_depth_rate = stochastic_depth_rate
|
||||
self.dropout_rate = dropout_rate
|
||||
|
||||
def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
|
||||
"""Compute encoded features.
|
||||
|
||||
Args:
|
||||
x_input (torch.Tensor): Input tensor (#batch, time, size).
|
||||
mask (torch.Tensor): Mask tensor for the input (#batch, time).
|
||||
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time, size).
|
||||
torch.Tensor: Mask tensor (#batch, time).
|
||||
|
||||
"""
|
||||
skip_layer = False
|
||||
# with stochastic depth, residual connection `x + f(x)` becomes
|
||||
# `x <- x + 1 / (1 - p) * f(x)` at training time.
|
||||
stoch_layer_coeff = 1.0
|
||||
if self.training and self.stochastic_depth_rate > 0:
|
||||
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
|
||||
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
|
||||
|
||||
if skip_layer:
|
||||
if cache is not None:
|
||||
x = torch.cat([cache, x], dim=1)
|
||||
return x, mask
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
if self.concat_after:
|
||||
x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
|
||||
if self.in_size == self.size:
|
||||
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
|
||||
else:
|
||||
x = stoch_layer_coeff * self.concat_linear(x_concat)
|
||||
else:
|
||||
if self.in_size == self.size:
|
||||
x = residual + stoch_layer_coeff * self.dropout(
|
||||
self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
|
||||
)
|
||||
else:
|
||||
x = stoch_layer_coeff * self.dropout(
|
||||
self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
|
||||
)
|
||||
if not self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
|
||||
|
||||
return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
|
||||
|
||||
class SANMEncoder(AbsEncoder):
|
||||
"""
|
||||
author: Speech Lab, Alibaba Group, China
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
input_layer: Optional[str] = "conv2d",
|
||||
pos_enc_class=SinusoidalPositionEncoder,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
positionwise_layer_type: str = "linear",
|
||||
positionwise_conv_kernel_size: int = 1,
|
||||
padding_idx: int = -1,
|
||||
interctc_layer_idx: List[int] = [],
|
||||
interctc_use_conditioning: bool = False,
|
||||
kernel_size : int = 11,
|
||||
sanm_shfit : int = 0,
|
||||
selfattention_layer_type: str = "sanm",
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
self._output_size = output_size
|
||||
|
||||
if input_layer == "linear":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(input_size, output_size),
|
||||
torch.nn.LayerNorm(output_size),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
torch.nn.ReLU(),
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == "conv2d":
|
||||
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "conv2d2":
|
||||
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "conv2d6":
|
||||
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "conv2d8":
|
||||
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "embed":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
|
||||
SinusoidalPositionEncoder(),
|
||||
)
|
||||
elif input_layer is None:
|
||||
if input_size == output_size:
|
||||
self.embed = None
|
||||
else:
|
||||
self.embed = torch.nn.Linear(input_size, output_size)
|
||||
elif input_layer == "pe":
|
||||
self.embed = SinusoidalPositionEncoder()
|
||||
else:
|
||||
raise ValueError("unknown input_layer: " + input_layer)
|
||||
self.normalize_before = normalize_before
|
||||
if positionwise_layer_type == "linear":
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d":
|
||||
positionwise_layer = MultiLayeredConv1d
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d-linear":
|
||||
positionwise_layer = Conv1dLinear
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Support only linear or conv1d.")
|
||||
|
||||
if selfattention_layer_type == "selfattn":
|
||||
encoder_selfattn_layer = MultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
|
||||
elif selfattention_layer_type == "sanm":
|
||||
self.encoder_selfattn_layer = MultiHeadedAttentionSANM
|
||||
encoder_selfattn_layer_args0 = (
|
||||
attention_heads,
|
||||
input_size,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit,
|
||||
)
|
||||
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit,
|
||||
)
|
||||
|
||||
self.encoders0 = repeat(
|
||||
1,
|
||||
lambda lnum: EncoderLayerSANM(
|
||||
input_size,
|
||||
output_size,
|
||||
self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks-1,
|
||||
lambda lnum: EncoderLayerSANM(
|
||||
output_size,
|
||||
output_size,
|
||||
self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(output_size)
|
||||
|
||||
self.interctc_layer_idx = interctc_layer_idx
|
||||
if len(interctc_layer_idx) > 0:
|
||||
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
|
||||
self.interctc_use_conditioning = interctc_use_conditioning
|
||||
self.conditioning_layer = None
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self._output_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
prev_states: torch.Tensor = None,
|
||||
ctc: CTC = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Embed positions in tensor.
|
||||
|
||||
Args:
|
||||
xs_pad: input tensor (B, L, D)
|
||||
ilens: input length (B)
|
||||
prev_states: Not to be used now.
|
||||
Returns:
|
||||
position embedded tensor and mask
|
||||
"""
|
||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
|
||||
xs_pad *= self.output_size()**0.5
|
||||
if self.embed is None:
|
||||
xs_pad = xs_pad
|
||||
elif (
|
||||
isinstance(self.embed, Conv2dSubsampling)
|
||||
or isinstance(self.embed, Conv2dSubsampling2)
|
||||
or isinstance(self.embed, Conv2dSubsampling6)
|
||||
or isinstance(self.embed, Conv2dSubsampling8)
|
||||
):
|
||||
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
|
||||
if short_status:
|
||||
raise TooShortUttError(
|
||||
f"has {xs_pad.size(1)} frames and is too short for subsampling "
|
||||
+ f"(it needs more than {limit_size} frames), return empty results",
|
||||
xs_pad.size(1),
|
||||
limit_size,
|
||||
)
|
||||
xs_pad, masks = self.embed(xs_pad, masks)
|
||||
else:
|
||||
xs_pad = self.embed(xs_pad)
|
||||
|
||||
# xs_pad = self.dropout(xs_pad)
|
||||
encoder_outs = self.encoders0(xs_pad, masks)
|
||||
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
||||
intermediate_outs = []
|
||||
if len(self.interctc_layer_idx) == 0:
|
||||
encoder_outs = self.encoders(xs_pad, masks)
|
||||
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
||||
else:
|
||||
for layer_idx, encoder_layer in enumerate(self.encoders):
|
||||
encoder_outs = encoder_layer(xs_pad, masks)
|
||||
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
||||
|
||||
if layer_idx + 1 in self.interctc_layer_idx:
|
||||
encoder_out = xs_pad
|
||||
|
||||
# intermediate outputs are also normalized
|
||||
if self.normalize_before:
|
||||
encoder_out = self.after_norm(encoder_out)
|
||||
|
||||
intermediate_outs.append((layer_idx + 1, encoder_out))
|
||||
|
||||
if self.interctc_use_conditioning:
|
||||
ctc_out = ctc.softmax(encoder_out)
|
||||
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
|
||||
|
||||
if self.normalize_before:
|
||||
xs_pad = self.after_norm(xs_pad)
|
||||
|
||||
olens = masks.squeeze(1).sum(1)
|
||||
if len(intermediate_outs) > 0:
|
||||
return (xs_pad, intermediate_outs), olens, None
|
||||
return xs_pad, olens, None
|
||||
|
||||
class SANMVadEncoder(AbsEncoder):
|
||||
"""
|
||||
author: Speech Lab, Alibaba Group, China
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
input_layer: Optional[str] = "conv2d",
|
||||
pos_enc_class=SinusoidalPositionEncoder,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
positionwise_layer_type: str = "linear",
|
||||
positionwise_conv_kernel_size: int = 1,
|
||||
padding_idx: int = -1,
|
||||
interctc_layer_idx: List[int] = [],
|
||||
interctc_use_conditioning: bool = False,
|
||||
kernel_size : int = 11,
|
||||
sanm_shfit : int = 0,
|
||||
selfattention_layer_type: str = "sanm",
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
self._output_size = output_size
|
||||
|
||||
if input_layer == "linear":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(input_size, output_size),
|
||||
torch.nn.LayerNorm(output_size),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
torch.nn.ReLU(),
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == "conv2d":
|
||||
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "conv2d2":
|
||||
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "conv2d6":
|
||||
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "conv2d8":
|
||||
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "embed":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
|
||||
SinusoidalPositionEncoder(),
|
||||
)
|
||||
elif input_layer is None:
|
||||
if input_size == output_size:
|
||||
self.embed = None
|
||||
else:
|
||||
self.embed = torch.nn.Linear(input_size, output_size)
|
||||
elif input_layer == "pe":
|
||||
self.embed = SinusoidalPositionEncoder()
|
||||
else:
|
||||
raise ValueError("unknown input_layer: " + input_layer)
|
||||
self.normalize_before = normalize_before
|
||||
if positionwise_layer_type == "linear":
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d":
|
||||
positionwise_layer = MultiLayeredConv1d
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d-linear":
|
||||
positionwise_layer = Conv1dLinear
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Support only linear or conv1d.")
|
||||
|
||||
if selfattention_layer_type == "selfattn":
|
||||
encoder_selfattn_layer = MultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
|
||||
elif selfattention_layer_type == "sanm":
|
||||
self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
|
||||
encoder_selfattn_layer_args0 = (
|
||||
attention_heads,
|
||||
input_size,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit,
|
||||
)
|
||||
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit,
|
||||
)
|
||||
|
||||
self.encoders0 = repeat(
|
||||
1,
|
||||
lambda lnum: EncoderLayerSANM(
|
||||
input_size,
|
||||
output_size,
|
||||
self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks-1,
|
||||
lambda lnum: EncoderLayerSANM(
|
||||
output_size,
|
||||
output_size,
|
||||
self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(output_size)
|
||||
|
||||
self.interctc_layer_idx = interctc_layer_idx
|
||||
if len(interctc_layer_idx) > 0:
|
||||
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
|
||||
self.interctc_use_conditioning = interctc_use_conditioning
|
||||
self.conditioning_layer = None
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self._output_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
vad_indexes: torch.Tensor,
|
||||
prev_states: torch.Tensor = None,
|
||||
ctc: CTC = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Embed positions in tensor.
|
||||
|
||||
Args:
|
||||
xs_pad: input tensor (B, L, D)
|
||||
ilens: input length (B)
|
||||
prev_states: Not to be used now.
|
||||
Returns:
|
||||
position embedded tensor and mask
|
||||
"""
|
||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
|
||||
sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
|
||||
no_future_masks = masks & sub_masks
|
||||
xs_pad *= self.output_size()**0.5
|
||||
if self.embed is None:
|
||||
xs_pad = xs_pad
|
||||
elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
|
||||
or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
|
||||
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
|
||||
if short_status:
|
||||
raise TooShortUttError(
|
||||
f"has {xs_pad.size(1)} frames and is too short for subsampling " +
|
||||
f"(it needs more than {limit_size} frames), return empty results",
|
||||
xs_pad.size(1),
|
||||
limit_size,
|
||||
)
|
||||
xs_pad, masks = self.embed(xs_pad, masks)
|
||||
else:
|
||||
xs_pad = self.embed(xs_pad)
|
||||
|
||||
# xs_pad = self.dropout(xs_pad)
|
||||
mask_tup0 = [masks, no_future_masks]
|
||||
encoder_outs = self.encoders0(xs_pad, mask_tup0)
|
||||
xs_pad, _ = encoder_outs[0], encoder_outs[1]
|
||||
intermediate_outs = []
|
||||
#if len(self.interctc_layer_idx) == 0:
|
||||
if False:
|
||||
# Here, we should not use the repeat operation to do it for all layers.
|
||||
encoder_outs = self.encoders(xs_pad, masks)
|
||||
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
||||
else:
|
||||
for layer_idx, encoder_layer in enumerate(self.encoders):
|
||||
if layer_idx + 1 == len(self.encoders):
|
||||
# This is last layer.
|
||||
coner_mask = torch.ones(masks.size(0),
|
||||
masks.size(-1),
|
||||
masks.size(-1),
|
||||
device=xs_pad.device,
|
||||
dtype=torch.bool)
|
||||
for word_index, length in enumerate(ilens):
|
||||
coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
|
||||
vad_indexes[word_index],
|
||||
device=xs_pad.device)
|
||||
layer_mask = masks & coner_mask
|
||||
else:
|
||||
layer_mask = no_future_masks
|
||||
mask_tup1 = [masks, layer_mask]
|
||||
encoder_outs = encoder_layer(xs_pad, mask_tup1)
|
||||
xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
|
||||
|
||||
if layer_idx + 1 in self.interctc_layer_idx:
|
||||
encoder_out = xs_pad
|
||||
|
||||
# intermediate outputs are also normalized
|
||||
if self.normalize_before:
|
||||
encoder_out = self.after_norm(encoder_out)
|
||||
|
||||
intermediate_outs.append((layer_idx + 1, encoder_out))
|
||||
|
||||
if self.interctc_use_conditioning:
|
||||
ctc_out = ctc.softmax(encoder_out)
|
||||
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
|
||||
|
||||
if self.normalize_before:
|
||||
xs_pad = self.after_norm(xs_pad)
|
||||
|
||||
olens = masks.squeeze(1).sum(1)
|
||||
if len(intermediate_outs) > 0:
|
||||
return (xs_pad, intermediate_outs), olens, None
|
||||
return xs_pad, olens, None
|
||||
|
||||
@ -1,12 +1 @@
|
||||
def split_to_mini_sentence(words: list, word_limit: int = 20):
|
||||
assert word_limit > 1
|
||||
if len(words) <= word_limit:
|
||||
return [words]
|
||||
sentences = []
|
||||
length = len(words)
|
||||
sentence_len = length // word_limit
|
||||
for i in range(sentence_len):
|
||||
sentences.append(words[i * word_limit:(i + 1) * word_limit])
|
||||
if length % word_limit > 0:
|
||||
sentences.append(words[sentence_len * word_limit:])
|
||||
return sentences
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@ from typeguard import check_return_type
|
||||
from funasr.datasets.collate_fn import CommonCollateFn
|
||||
from funasr.datasets.preprocessor import CommonPreprocessor
|
||||
from funasr.lm.abs_model import AbsLM
|
||||
from funasr.lm.espnet_model import ESPnetLanguageModel
|
||||
from funasr.lm.espnet_model import LanguageModel
|
||||
from funasr.lm.seq_rnn_lm import SequentialRNNLM
|
||||
from funasr.lm.transformer_lm import TransformerLM
|
||||
from funasr.tasks.abs_task import AbsTask
|
||||
@ -83,7 +83,7 @@ class LMTask(AbsTask):
|
||||
group.add_argument(
|
||||
"--model_conf",
|
||||
action=NestedDictAction,
|
||||
default=get_default_kwargs(ESPnetLanguageModel),
|
||||
default=get_default_kwargs(LanguageModel),
|
||||
help="The keyword arguments for model class.",
|
||||
)
|
||||
|
||||
@ -178,7 +178,7 @@ class LMTask(AbsTask):
|
||||
return retval
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args: argparse.Namespace) -> ESPnetLanguageModel:
|
||||
def build_model(cls, args: argparse.Namespace) -> LanguageModel:
|
||||
assert check_argument_types()
|
||||
if isinstance(args.token_list, str):
|
||||
with open(args.token_list, encoding="utf-8") as f:
|
||||
@ -201,7 +201,7 @@ class LMTask(AbsTask):
|
||||
|
||||
# 2. Build ESPnetModel
|
||||
# Assume the last-id is sos_and_eos
|
||||
model = ESPnetLanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
|
||||
model = LanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
|
||||
|
||||
# 3. Initialize
|
||||
if args.init is not None:
|
||||
|
||||
@ -14,10 +14,10 @@ from typeguard import check_return_type
|
||||
|
||||
from funasr.datasets.collate_fn import CommonCollateFn
|
||||
from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
|
||||
from funasr.punctuation.abs_model import AbsPunctuation
|
||||
from funasr.punctuation.espnet_model import ESPnetPunctuationModel
|
||||
from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
|
||||
from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
|
||||
from funasr.train.abs_model import AbsPunctuation
|
||||
from funasr.train.abs_model import PunctuationModel
|
||||
from funasr.models.target_delay_transformer import TargetDelayTransformer
|
||||
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
|
||||
from funasr.tasks.abs_task import AbsTask
|
||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
||||
from funasr.torch_utils.initialize import initialize
|
||||
@ -79,7 +79,7 @@ class PunctuationTask(AbsTask):
|
||||
group.add_argument(
|
||||
"--model_conf",
|
||||
action=NestedDictAction,
|
||||
default=get_default_kwargs(ESPnetPunctuationModel),
|
||||
default=get_default_kwargs(PunctuationModel),
|
||||
help="The keyword arguments for model class.",
|
||||
)
|
||||
|
||||
@ -183,7 +183,7 @@ class PunctuationTask(AbsTask):
|
||||
return retval
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args: argparse.Namespace) -> ESPnetPunctuationModel:
|
||||
def build_model(cls, args: argparse.Namespace) -> PunctuationModel:
|
||||
assert check_argument_types()
|
||||
if isinstance(args.token_list, str):
|
||||
with open(args.token_list, encoding="utf-8") as f:
|
||||
@ -218,7 +218,7 @@ class PunctuationTask(AbsTask):
|
||||
# Assume the last-id is sos_and_eos
|
||||
if "punc_weight" in args.model_conf:
|
||||
args.model_conf.pop("punc_weight")
|
||||
model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
|
||||
model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
|
||||
|
||||
# FIXME(kamo): Should be done in model?
|
||||
# 3. Initialize
|
||||
|
||||
@ -1,3 +1,9 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
@ -7,13 +13,34 @@ import torch.nn.functional as F
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
from funasr.punctuation.abs_model import AbsPunctuation
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.train.abs_espnet_model import AbsESPnetModel
|
||||
|
||||
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
|
||||
|
||||
class ESPnetPunctuationModel(AbsESPnetModel):
|
||||
|
||||
class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
|
||||
"""The abstract class
|
||||
|
||||
To share the loss calculation way among different models,
|
||||
We uses delegate pattern here:
|
||||
The instance of this class should be passed to "LanguageModel"
|
||||
|
||||
This "model" is one of mediator objects for "Task" class.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def with_vad(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PunctuationModel(AbsESPnetModel):
|
||||
|
||||
def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
@ -21,12 +48,12 @@ class ESPnetPunctuationModel(AbsESPnetModel):
|
||||
self.punc_weight = torch.Tensor(punc_weight)
|
||||
self.sos = 1
|
||||
self.eos = 2
|
||||
|
||||
|
||||
# ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
|
||||
self.ignore_id = ignore_id
|
||||
#if self.punc_model.with_vad():
|
||||
# if self.punc_model.with_vad():
|
||||
# print("This is a vad puncuation model.")
|
||||
|
||||
|
||||
def nll(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
@ -54,7 +81,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
|
||||
else:
|
||||
text = text[:, :max_length]
|
||||
punc = punc[:, :max_length]
|
||||
|
||||
|
||||
if self.punc_model.with_vad():
|
||||
# Should be VadRealtimeTransformer
|
||||
assert vad_indexes is not None
|
||||
@ -62,7 +89,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
|
||||
else:
|
||||
# Should be TargetDelayTransformer,
|
||||
y, _ = self.punc_model(text, text_lengths)
|
||||
|
||||
|
||||
# Calc negative log likelihood
|
||||
# nll: (BxL,)
|
||||
if self.training == False:
|
||||
@ -75,7 +102,8 @@ class ESPnetPunctuationModel(AbsESPnetModel):
|
||||
return nll, text_lengths
|
||||
else:
|
||||
self.punc_weight = self.punc_weight.to(punc.device)
|
||||
nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id)
|
||||
nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
|
||||
ignore_index=self.ignore_id)
|
||||
# nll: (BxL,) -> (BxL,)
|
||||
if max_length is None:
|
||||
nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
|
||||
@ -87,7 +115,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
|
||||
# nll: (BxL,) -> (B, L)
|
||||
nll = nll.view(batch_size, -1)
|
||||
return nll, text_lengths
|
||||
|
||||
|
||||
def batchify_nll(self,
|
||||
text: torch.Tensor,
|
||||
punc: torch.Tensor,
|
||||
@ -113,7 +141,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
|
||||
nlls = []
|
||||
x_lengths = []
|
||||
max_length = text_lengths.max()
|
||||
|
||||
|
||||
start_idx = 0
|
||||
while True:
|
||||
end_idx = min(start_idx + batch_size, total_num)
|
||||
@ -132,7 +160,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
|
||||
assert nll.size(0) == total_num
|
||||
assert x_lengths.size(0) == total_num
|
||||
return nll, x_lengths
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
@ -146,15 +174,15 @@ class ESPnetPunctuationModel(AbsESPnetModel):
|
||||
ntokens = y_lengths.sum()
|
||||
loss = nll.sum() / ntokens
|
||||
stats = dict(loss=loss.detach())
|
||||
|
||||
|
||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
|
||||
return loss, stats, weight
|
||||
|
||||
|
||||
def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
|
||||
text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
|
||||
def inference(self,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
Loading…
Reference in New Issue
Block a user