游雁 2023-03-31 15:05:37 +08:00
parent 3cd71a385a
commit d0cd484fdc
16 changed files with 309 additions and 762 deletions

View File

@@ -23,7 +23,7 @@ from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.punctuation.text_preprocessor import split_to_mini_sentence
+from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:

View File

@@ -23,7 +23,7 @@ from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.punctuation.text_preprocessor import split_to_mini_sentence
+from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:

View File

@@ -800,3 +800,17 @@ class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
data[self.vad_name] = np.array([vad], dtype=np.int64)
text_ints = self.token_id_converter[i].tokens2ids(tokens)
data[text_name] = np.array(text_ints, dtype=np.int64)
+
+
+def split_to_mini_sentence(words: list, word_limit: int = 20):
+    assert word_limit > 1
+    if len(words) <= word_limit:
+        return [words]
+    sentences = []
+    length = len(words)
+    sentence_len = length // word_limit
+    for i in range(sentence_len):
+        sentences.append(words[i * word_limit:(i + 1) * word_limit])
+    if length % word_limit > 0:
+        sentences.append(words[sentence_len * word_limit:])
+    return sentences
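For reference, a quick sanity check of the helper added above (an illustrative snippet, not part of the commit; it assumes the new import path funasr.datasets.preprocessor resolves in an installed FunASR):

    from funasr.datasets.preprocessor import split_to_mini_sentence

    tokens = ["tok%d" % i for i in range(45)]
    chunks = split_to_mini_sentence(tokens, word_limit=20)
    # 45 tokens with word_limit=20 -> two full chunks of 20 plus a 5-token tail
    assert [len(c) for c in chunks] == [20, 20, 5]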

View File

@@ -3,10 +3,10 @@ from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
from funasr.models.e2e_vad import E2EVadModel
from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
-from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
+from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.export.models.target_delay_transformer import TargetDelayTransformer as TargetDelayTransformer_export
-from funasr.punctuation.espnet_model import ESPnetPunctuationModel
+from funasr.train.abs_model import PunctuationModel
-from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
+from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.export.models.vad_realtime_transformer import VadRealtimeTransformer as VadRealtimeTransformer_export
def get_model(model, export_config=None):
@@ -16,7 +16,7 @@ def get_model(model, export_config=None):
        return Paraformer_export(model, **export_config)
    elif isinstance(model, E2EVadModel):
        return E2EVadModel_export(model, **export_config)
-    elif isinstance(model, ESPnetPunctuationModel):
+    elif isinstance(model, PunctuationModel):
        if isinstance(model.punc_model, TargetDelayTransformer):
            return TargetDelayTransformer_export(model.punc_model, **export_config)
        elif isinstance(model.punc_model, VadRealtimeTransformer):

View File

@@ -1,18 +1,8 @@
from typing import Any
from typing import List
from typing import Tuple
import torch
import torch.nn as nn
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
#from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
from funasr.punctuation.sanm_encoder import SANMEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.punctuation.abs_model import AbsPunctuation
class TargetDelayTransformer(nn.Module):
    def __init__(
@@ -32,85 +22,10 @@ class TargetDelayTransformer(nn.Module):
        self.feats_dim = self.embed.embedding_dim
        self.num_embeddings = self.embed.num_embeddings
        self.model_name = model_name
from typing import Any
from typing import List
from typing import Tuple
import torch
import torch.nn as nn
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
# from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
-from funasr.punctuation.sanm_encoder import SANMEncoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.punctuation.abs_model import AbsPunctuation
# class TargetDelayTransformer(nn.Module):
#
# def __init__(
# self,
# model,
# max_seq_len=512,
# model_name='punc_model',
# **kwargs,
# ):
# super().__init__()
# onnx = False
# if "onnx" in kwargs:
# onnx = kwargs["onnx"]
# self.embed = model.embed
# self.decoder = model.decoder
# self.model = model
# self.feats_dim = self.embed.embedding_dim
# self.num_embeddings = self.embed.num_embeddings
# self.model_name = model_name
#
# if isinstance(model.encoder, SANMEncoder):
# self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
# else:
# assert False, "Only support samn encode."
#
# def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
# """Compute loss value from buffer sequences.
#
# Args:
# input (torch.Tensor): Input ids. (batch, len)
# hidden (torch.Tensor): Target ids. (batch, len)
#
# """
# x = self.embed(input)
# # mask = self._target_mask(input)
# h, _ = self.encoder(x, text_lengths)
# y = self.decoder(h)
# return y
#
# def get_dummy_inputs(self):
# length = 120
# text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
# text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
# return (text_indexes, text_lengths)
#
# def get_input_names(self):
# return ['input', 'text_lengths']
#
# def get_output_names(self):
# return ['logits']
#
# def get_dynamic_axes(self):
# return {
# 'input': {
# 0: 'batch_size',
# 1: 'feats_length'
# },
# 'text_lengths': {
# 0: 'batch_size',
# },
# 'logits': {
# 0: 'batch_size',
# 1: 'logits_length'
# },
# }
        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
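The commented-out block above documents the export wrapper's helper methods (get_dummy_inputs, get_input_names, get_output_names, get_dynamic_axes). As a rough sketch of how such a wrapper is typically driven through torch.onnx.export — the function name, output path and opset below are assumptions, not part of this commit:

    import torch

    def export_punc_onnx(export_model, out_path="punc_model.onnx"):
        # get_dummy_inputs() returns (text_indexes, text_lengths) per the block above
        dummy_inputs = export_model.get_dummy_inputs()
        torch.onnx.export(
            export_model,
            dummy_inputs,
            out_path,
            input_names=export_model.get_input_names(),
            output_names=export_model.get_output_names(),
            dynamic_axes=export_model.get_dynamic_axes(),
            opset_version=13,
        )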

View File

@@ -1,14 +1,9 @@
-from typing import Any
-from typing import List
from typing import Tuple
import torch
import torch.nn as nn
-from funasr.modules.embedding import SinusoidalPositionEncoder
+from funasr.models.encoder.sanm_encoder import SANMVadEncoder
-from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
-from funasr.punctuation.abs_model import AbsPunctuation
-from funasr.punctuation.sanm_encoder import SANMVadEncoder
from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
class VadRealtimeTransformer(nn.Module):

View File

@@ -12,7 +12,7 @@ from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
-class ESPnetLanguageModel(AbsESPnetModel):
+class LanguageModel(AbsESPnetModel):
    def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
        assert check_argument_types()
        super().__init__()

View File

@@ -10,7 +10,7 @@ from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
+from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
@@ -27,7 +27,7 @@ from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.modules.mask import subsequent_mask, vad_mask
class EncoderLayerSANM(nn.Module):
    def __init__(
@@ -958,3 +958,231 @@ class SANMEncoderChunkOpt(AbsEncoder):
var_dict_tf[name_tf].shape))
return var_dict_torch_update
class SANMVadEncoder(AbsEncoder):
"""
author: Speech Lab, Alibaba Group, China
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = "conv2d",
pos_enc_class=SinusoidalPositionEncoder,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
kernel_size : int = 11,
sanm_shfit : int = 0,
selfattention_layer_type: str = "sanm",
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
SinusoidalPositionEncoder(),
)
elif input_layer is None:
if input_size == output_size:
self.embed = None
else:
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
if selfattention_layer_type == "selfattn":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
elif selfattention_layer_type == "sanm":
self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
encoder_selfattn_layer_args0 = (
attention_heads,
input_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
encoder_selfattn_layer_args = (
attention_heads,
output_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
self.encoders0 = repeat(
1,
lambda lnum: EncoderLayerSANM(
input_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
self.encoders = repeat(
num_blocks-1,
lambda lnum: EncoderLayerSANM(
output_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
self.dropout = nn.Dropout(dropout_rate)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
vad_indexes: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
no_future_masks = masks & sub_masks
xs_pad *= self.output_size()**0.5
if self.embed is None:
xs_pad = xs_pad
elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling " +
f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
# xs_pad = self.dropout(xs_pad)
mask_tup0 = [masks, no_future_masks]
encoder_outs = self.encoders0(xs_pad, mask_tup0)
xs_pad, _ = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
for layer_idx, encoder_layer in enumerate(self.encoders):
if layer_idx + 1 == len(self.encoders):
# This is last layer.
coner_mask = torch.ones(masks.size(0),
masks.size(-1),
masks.size(-1),
device=xs_pad.device,
dtype=torch.bool)
for word_index, length in enumerate(ilens):
coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
vad_indexes[word_index],
device=xs_pad.device)
layer_mask = masks & coner_mask
else:
layer_mask = no_future_masks
mask_tup1 = [masks, layer_mask]
encoder_outs = encoder_layer(xs_pad, mask_tup1)
xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
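To make the masking in SANMVadEncoder.forward easier to follow, here is a minimal plain-PyTorch sketch of how the padding mask and the causal ("no future") mask combine; it does not use the funasr helpers and the example lengths are arbitrary:

    import torch

    ilens = torch.tensor([4, 2])                               # batch of two sequences
    L = int(ilens.max())
    pad_mask = torch.arange(L)[None, :] < ilens[:, None]       # (B, L), True on real tokens
    masks = pad_mask[:, None, :]                               # (B, 1, L), like ~make_pad_mask(ilens)[:, None, :]
    sub_masks = torch.tril(torch.ones(L, L, dtype=torch.bool))[None]  # (1, L, L), like subsequent_mask(L)
    no_future_masks = masks & sub_masks                        # (B, L, L): attend only to past, non-padded positions
    print(no_future_masks.int())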

View File

@@ -5,12 +5,11 @@ from typing import Tuple
import torch
import torch.nn as nn
-from funasr.modules.embedding import PositionalEncoding
from funasr.modules.embedding import SinusoidalPositionEncoder
#from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder
#from funasr.modules.mask import subsequent_n_mask
-from funasr.punctuation.abs_model import AbsPunctuation
+from funasr.train.abs_model import AbsPunctuation
class TargetDelayTransformer(AbsPunctuation):

View File

@@ -7,7 +7,7 @@ import torch.nn as nn
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
-from funasr.punctuation.abs_model import AbsPunctuation
+from funasr.train.abs_model import AbsPunctuation
class VadRealtimeTransformer(AbsPunctuation):

View File

@@ -1,31 +0,0 @@
-from abc import ABC
-from abc import abstractmethod
-from typing import Tuple
-import torch
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
-    """The abstract class
-    To share the loss calculation way among different models,
-    We uses delegate pattern here:
-    The instance of this class should be passed to "LanguageModel"
-    >>> from funasr.punctuation.abs_model import AbsPunctuation
-    >>> punc = AbsPunctuation()
-    >>> model = ESPnetPunctuationModel(punc=punc)
-    This "model" is one of mediator objects for "Task" class.
-    """
-    @abstractmethod
-    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        raise NotImplementedError
-    @abstractmethod
-    def with_vad(self) -> bool:
-        raise NotImplementedError
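For orientation, a minimal sketch of a concrete implementation of this interface (the class itself now lives in funasr.train.abs_model, as shown later in this commit; the toy layer sizes and class name are made up):

    import torch
    import torch.nn as nn
    from funasr.train.abs_model import AbsPunctuation

    class ToyPunctuator(AbsPunctuation):
        def __init__(self, vocab_size=100, punc_size=5, dim=32):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, dim)
            self.out = nn.Linear(dim, punc_size)

        def forward(self, input: torch.Tensor, hidden: torch.Tensor):
            # hidden is unused in this toy model; real models return (logits, state)
            return self.out(self.embed(input)), None

        def with_vad(self) -> bool:
            return False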

View File

@@ -1,590 +0,0 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.mask import subsequent_mask, vad_mask
class EncoderLayerSANM(nn.Module):
def __init__(
self,
in_size,
size,
self_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
stochastic_depth_rate=0.0,
):
"""Construct an EncoderLayer object."""
super(EncoderLayerSANM, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(in_size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.in_size = in_size
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
self.stochastic_depth_rate = stochastic_depth_rate
self.dropout_rate = dropout_rate
def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
"""Compute encoded features.
Args:
x_input (torch.Tensor): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
"""
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
stoch_layer_coeff = 1.0
if self.training and self.stochastic_depth_rate > 0:
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
if skip_layer:
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask
residual = x
if self.normalize_before:
x = self.norm1(x)
if self.concat_after:
x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
x = stoch_layer_coeff * self.concat_linear(x_concat)
else:
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.dropout(
self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
)
else:
x = stoch_layer_coeff * self.dropout(
self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
)
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
class SANMEncoder(AbsEncoder):
"""
author: Speech Lab, Alibaba Group, China
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = "conv2d",
pos_enc_class=SinusoidalPositionEncoder,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
kernel_size : int = 11,
sanm_shfit : int = 0,
selfattention_layer_type: str = "sanm",
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
SinusoidalPositionEncoder(),
)
elif input_layer is None:
if input_size == output_size:
self.embed = None
else:
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
if selfattention_layer_type == "selfattn":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
elif selfattention_layer_type == "sanm":
self.encoder_selfattn_layer = MultiHeadedAttentionSANM
encoder_selfattn_layer_args0 = (
attention_heads,
input_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
encoder_selfattn_layer_args = (
attention_heads,
output_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
self.encoders0 = repeat(
1,
lambda lnum: EncoderLayerSANM(
input_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
self.encoders = repeat(
num_blocks-1,
lambda lnum: EncoderLayerSANM(
output_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
self.dropout = nn.Dropout(dropout_rate)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
xs_pad *= self.output_size()**0.5
if self.embed is None:
xs_pad = xs_pad
elif (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
# xs_pad = self.dropout(xs_pad)
encoder_outs = self.encoders0(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
encoder_outs = self.encoders(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
# intermediate outputs are also normalized
if self.normalize_before:
encoder_out = self.after_norm(encoder_out)
intermediate_outs.append((layer_idx + 1, encoder_out))
if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
class SANMVadEncoder(AbsEncoder):
"""
author: Speech Lab, Alibaba Group, China
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = "conv2d",
pos_enc_class=SinusoidalPositionEncoder,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
kernel_size : int = 11,
sanm_shfit : int = 0,
selfattention_layer_type: str = "sanm",
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
SinusoidalPositionEncoder(),
)
elif input_layer is None:
if input_size == output_size:
self.embed = None
else:
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
if selfattention_layer_type == "selfattn":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
elif selfattention_layer_type == "sanm":
self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
encoder_selfattn_layer_args0 = (
attention_heads,
input_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
encoder_selfattn_layer_args = (
attention_heads,
output_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
self.encoders0 = repeat(
1,
lambda lnum: EncoderLayerSANM(
input_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
self.encoders = repeat(
num_blocks-1,
lambda lnum: EncoderLayerSANM(
output_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
self.dropout = nn.Dropout(dropout_rate)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
vad_indexes: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
no_future_masks = masks & sub_masks
xs_pad *= self.output_size()**0.5
if self.embed is None:
xs_pad = xs_pad
elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling " +
f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
# xs_pad = self.dropout(xs_pad)
mask_tup0 = [masks, no_future_masks]
encoder_outs = self.encoders0(xs_pad, mask_tup0)
xs_pad, _ = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
#if len(self.interctc_layer_idx) == 0:
if False:
# Here, we should not use the repeat operation to do it for all layers.
encoder_outs = self.encoders(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
if layer_idx + 1 == len(self.encoders):
# This is last layer.
coner_mask = torch.ones(masks.size(0),
masks.size(-1),
masks.size(-1),
device=xs_pad.device,
dtype=torch.bool)
for word_index, length in enumerate(ilens):
coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
vad_indexes[word_index],
device=xs_pad.device)
layer_mask = masks & coner_mask
else:
layer_mask = no_future_masks
mask_tup1 = [masks, layer_mask]
encoder_outs = encoder_layer(xs_pad, mask_tup1)
xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
# intermediate outputs are also normalized
if self.normalize_before:
encoder_out = self.after_norm(encoder_out)
intermediate_outs.append((layer_idx + 1, encoder_out))
if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None

View File

@@ -1,12 +1 @@
-def split_to_mini_sentence(words: list, word_limit: int = 20):
-    assert word_limit > 1
-    if len(words) <= word_limit:
-        return [words]
-    sentences = []
-    length = len(words)
-    sentence_len = length // word_limit
-    for i in range(sentence_len):
-        sentences.append(words[i * word_limit:(i + 1) * word_limit])
-    if length % word_limit > 0:
-        sentences.append(words[sentence_len * word_limit:])
-    return sentences

View File

@@ -15,7 +15,7 @@ from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.lm.abs_model import AbsLM
-from funasr.lm.espnet_model import ESPnetLanguageModel
+from funasr.lm.espnet_model import LanguageModel
from funasr.lm.seq_rnn_lm import SequentialRNNLM
from funasr.lm.transformer_lm import TransformerLM
from funasr.tasks.abs_task import AbsTask
@@ -83,7 +83,7 @@ class LMTask(AbsTask):
        group.add_argument(
            "--model_conf",
            action=NestedDictAction,
-            default=get_default_kwargs(ESPnetLanguageModel),
+            default=get_default_kwargs(LanguageModel),
            help="The keyword arguments for model class.",
        )
@@ -178,7 +178,7 @@ class LMTask(AbsTask):
        return retval
    @classmethod
-    def build_model(cls, args: argparse.Namespace) -> ESPnetLanguageModel:
+    def build_model(cls, args: argparse.Namespace) -> LanguageModel:
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
@@ -201,7 +201,7 @@ class LMTask(AbsTask):
        # 2. Build ESPnetModel
        # Assume the last-id is sos_and_eos
-        model = ESPnetLanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
+        model = LanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
        # 3. Initialize
        if args.init is not None:

View File

@@ -14,10 +14,10 @@ from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
-from funasr.punctuation.abs_model import AbsPunctuation
+from funasr.train.abs_model import AbsPunctuation
-from funasr.punctuation.espnet_model import ESPnetPunctuationModel
+from funasr.train.abs_model import PunctuationModel
-from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
+from funasr.models.target_delay_transformer import TargetDelayTransformer
-from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
+from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
@@ -79,7 +79,7 @@ class PunctuationTask(AbsTask):
        group.add_argument(
            "--model_conf",
            action=NestedDictAction,
-            default=get_default_kwargs(ESPnetPunctuationModel),
+            default=get_default_kwargs(PunctuationModel),
            help="The keyword arguments for model class.",
        )
@@ -183,7 +183,7 @@ class PunctuationTask(AbsTask):
        return retval
    @classmethod
-    def build_model(cls, args: argparse.Namespace) -> ESPnetPunctuationModel:
+    def build_model(cls, args: argparse.Namespace) -> PunctuationModel:
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
@@ -218,7 +218,7 @@ class PunctuationTask(AbsTask):
        # Assume the last-id is sos_and_eos
        if "punc_weight" in args.model_conf:
            args.model_conf.pop("punc_weight")
-        model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
+        model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
        # FIXME(kamo): Should be done in model?
        # 3. Initialize

View File

@@ -1,3 +1,9 @@
+from abc import ABC
+from abc import abstractmethod
+from typing import Tuple
+import torch
from typing import Dict
from typing import Optional
from typing import Tuple
@@ -7,13 +13,34 @@ import torch.nn.functional as F
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
-from funasr.punctuation.abs_model import AbsPunctuation
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-class ESPnetPunctuationModel(AbsESPnetModel):
+class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
+    """The abstract class
+    To share the loss calculation way among different models,
+    We uses delegate pattern here:
+    The instance of this class should be passed to "LanguageModel"
+    This "model" is one of mediator objects for "Task" class.
+    """
+    @abstractmethod
+    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError
+    @abstractmethod
+    def with_vad(self) -> bool:
+        raise NotImplementedError
+class PunctuationModel(AbsESPnetModel):
    def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
        assert check_argument_types()
        super().__init__()
@@ -21,12 +48,12 @@ class ESPnetPunctuationModel(AbsESPnetModel):
        self.punc_weight = torch.Tensor(punc_weight)
        self.sos = 1
        self.eos = 2
        # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
        self.ignore_id = ignore_id
-        #if self.punc_model.with_vad():
+        # if self.punc_model.with_vad():
        #     print("This is a vad puncuation model.")
    def nll(
        self,
        text: torch.Tensor,
@@ -54,7 +81,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
        else:
            text = text[:, :max_length]
            punc = punc[:, :max_length]
        if self.punc_model.with_vad():
            # Should be VadRealtimeTransformer
            assert vad_indexes is not None
@@ -62,7 +89,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
        else:
            # Should be TargetDelayTransformer,
            y, _ = self.punc_model(text, text_lengths)
        # Calc negative log likelihood
        # nll: (BxL,)
        if self.training == False:
@@ -75,7 +102,8 @@ class ESPnetPunctuationModel(AbsESPnetModel):
            return nll, text_lengths
        else:
            self.punc_weight = self.punc_weight.to(punc.device)
-            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id)
+            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
+                                  ignore_index=self.ignore_id)
        # nll: (BxL,) -> (BxL,)
        if max_length is None:
            nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
@@ -87,7 +115,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
        # nll: (BxL,) -> (B, L)
        nll = nll.view(batch_size, -1)
        return nll, text_lengths
    def batchify_nll(self,
                     text: torch.Tensor,
                     punc: torch.Tensor,
@@ -113,7 +141,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
        nlls = []
        x_lengths = []
        max_length = text_lengths.max()
        start_idx = 0
        while True:
            end_idx = min(start_idx + batch_size, total_num)
@@ -132,7 +160,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
        assert nll.size(0) == total_num
        assert x_lengths.size(0) == total_num
        return nll, x_lengths
    def forward(
        self,
        text: torch.Tensor,
@@ -146,15 +174,15 @@ class ESPnetPunctuationModel(AbsESPnetModel):
        ntokens = y_lengths.sum()
        loss = nll.sum() / ntokens
        stats = dict(loss=loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
        return loss, stats, weight
    def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
                      text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {}
    def inference(self,
                  text: torch.Tensor,
                  text_lengths: torch.Tensor,