mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)

commit 7d1efe158e (parent: d46a542fae)

    rnnt reorg
@@ -1,13 +1,13 @@
encoder_conf:
    main_conf:
        pos_wise_act_type: swish
        pos_enc_dropout_rate: 0.3
        pos_enc_dropout_rate: 0.5
        conv_mod_act_type: swish
        time_reduction_factor: 2
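        # hedged note: time_reduction_factor: 2 subsamples the encoder output in
        # time (every 2nd frame kept), halving the frame rate fed to the joint network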
        unified_model_training: true
        default_chunk_size: 16
        jitter_range: 4
        left_chunk_size: 1
        left_chunk_size: 0
    input_conf:
        block_type: conv2d
        conv_size: 512
@@ -18,9 +18,9 @@ encoder_conf:
        linear_size: 2048
        hidden_size: 512
        heads: 8
        dropout_rate: 0.3
        pos_wise_dropout_rate: 0.3
        att_dropout_rate: 0.3
        dropout_rate: 0.5
        pos_wise_dropout_rate: 0.5
        att_dropout_rate: 0.5
        conv_mod_kernel_size: 15
    num_blocks: 12

@@ -29,8 +29,8 @@ decoder: rnn
decoder_conf:
    embed_size: 512
    hidden_size: 512
    embed_dropout_rate: 0.2
    dropout_rate: 0.1
    embed_dropout_rate: 0.5
    dropout_rate: 0.5

joint_network_conf:
    joint_space_size: 512
@@ -41,14 +41,14 @@ model_conf:

# minibatch related
use_amp: true
batch_type: numel
batch_bins: 1600000
batch_type: unsorted
batch_size: 16
num_workers: 16

# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 80
max_epoch: 200
val_scheduler_criterion:
    - valid
    - loss
@@ -56,11 +56,11 @@ best_model_criterion:
- - valid
  - cer_transducer_chunk
  - min
keep_nbest_models: 5
keep_nbest_models: 10

optim: adam
optim_conf:
    lr: 0.0003
    lr: 0.001
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 25000
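    # hedged note: warmuplr ramps the learning rate linearly to its peak over
    # warmup_steps and then decays it roughly as step^-0.5 (Noam-style schedule)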
@@ -75,10 +75,12 @@ specaug_conf:
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
    - 40
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 40
    num_time_mask: 2
    - 50
    num_time_mask: 5
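    # hedged reading: masking ranges widen (30->40 freq, 40->50 time) and the number
    # of time masks grows (2->5), i.e. stronger SpecAugment to pair with the raised
    # dropout rates and the longer 200-epoch schedule above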

log_interval: 50

@@ -16,11 +16,11 @@ import torch
from packaging.version import parse as V
from typeguard import check_argument_types, check_return_type

from funasr.models_transducer.beam_search_transducer import (
from funasr.modules.beam_search.beam_search_transducer import (
    BeamSearchTransducer,
    Hypothesis,
)
from funasr.models_transducer.utils import TooShortUttError
from funasr.modules.nets_utils import TooShortUttError
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.tasks.asr_transducer import ASRTransducerTask
from funasr.tasks.lm import LMTask
@@ -500,7 +500,6 @@ def inference(

        _bs = len(next(iter(batch.values())))
        assert len(keys) == _bs, f"{len(keys)} != {_bs}"
<<<<<<< HEAD
        batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
        assert len(batch.keys()) == 1

@@ -541,59 +540,6 @@ def inference(

        if text is not None:
            ibest_writer["text"][key] = text
=======
        # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}

        logging.info("decoding, utt_id: {}".format(keys))
        # N-best list of (text, token, token_int, hyp_object)

        time_beg = time.time()
        results = speech2text(cache=cache, **batch)
        if len(results) < 1:
            hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
            results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
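            # Fallback for an empty decode: a placeholder "sil" hypothesis whose
            # trailing entries keep the results[0][-1] / results[0][-2] indexing
            # below valid.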
        time_end = time.time()
        forward_time = time_end - time_beg
        lfr_factor = results[0][-1]
        length = results[0][-2]
        forward_time_total += forward_time
        length_total += length
        rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(
            length, forward_time, 100 * forward_time / (length * lfr_factor)
        )
        logging.info(rtf_cur)
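        # RTF = decoding time / audio duration. Assuming a 10 ms frame shift,
        # length * lfr_factor frames span (length * lfr_factor) / 100 seconds,
        # which is where the factor of 100 above comes from.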

        for batch_id in range(_bs):
            result = [results[batch_id][:-2]]

            key = keys[batch_id]
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), result):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]

                # Write the result to each file
                ibest_writer["token"][key] = " ".join(token)
                # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                ibest_writer["score"][key] = str(hyp.score)
                ibest_writer["rtf"][key] = rtf_cur

                if text is not None:
                    text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    # asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = " ".join(word_lists)

                logging.info("decoding, utt: {}, predictions: {}".format(key, text))
        rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(
            length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)
        )
        logging.info(rtf_avg)
        if writer is not None:
            ibest_writer["rtf"]["rtf_avg"] = rtf_avg
        return asr_result_list

    return _forward
>>>>>>> main


def get_parser():

@@ -10,11 +10,11 @@ from typeguard import check_argument_types

from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
from funasr.models_transducer.encoder.encoder import Encoder
from funasr.models_transducer.joint_network import JointNetwork
from funasr.models_transducer.utils import get_transducer_task_io
from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
from funasr.models.joint_network import JointNetwork
from funasr.modules.nets_utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
@@ -28,7 +28,7 @@ else:
        yield


class ESPnetASRTransducerModel(AbsESPnetModel):
class TransducerModel(AbsESPnetModel):
    """ESPnet2ASRTransducerModel module definition.

    Args:
@@ -10,10 +10,10 @@ from typeguard import check_argument_types

from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.models_transducer.encoder.encoder import Encoder
from funasr.models_transducer.joint_network import JointNetwork
from funasr.models_transducer.utils import get_transducer_task_io
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
from funasr.models.joint_network import JointNetwork
from funasr.modules.nets_utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
@@ -23,7 +23,7 @@ from funasr.modules.nets_utils import th_accuracy
from funasr.losses.label_smoothing_loss import (  # noqa: H301
    LabelSmoothingLoss,
)
from funasr.models_transducer.error_calculator import ErrorCalculator
from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:
@@ -33,7 +33,7 @@ else:
        yield


class ESPnetASRUnifiedTransducerModel(AbsESPnetModel):
class UnifiedTransducerModel(AbsESPnetModel):
    """ESPnet2ASRTransducerModel module definition.

    Args:
@@ -289,7 +289,6 @@ class ESPnetASRUnifiedTransducerModel(AbsESPnetModel):

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

        return loss, stats, weight

    def collect_feats(
@@ -1,26 +1,23 @@
"""Encoder for Transducer model."""

from typing import Any, Dict, List, Tuple

import torch
from typeguard import check_argument_types

from funasr.models_transducer.encoder.building import (
from funasr.models.encoder.chunk_encoder_utils.building import (
    build_body_blocks,
    build_input_block,
    build_main_parameters,
    build_positional_encoding,
)
from funasr.models_transducer.encoder.validation import validate_architecture
from funasr.models_transducer.utils import (
from funasr.models.encoder.chunk_encoder_utils.validation import validate_architecture
from funasr.modules.nets_utils import (
    TooShortUttError,
    check_short_utt,
    make_chunk_mask,
    make_source_mask,
)


class Encoder(torch.nn.Module):
class ChunkEncoder(torch.nn.Module):
    """Encoder module definition.

    Args:
@@ -61,10 +58,9 @@ class Encoder(torch.nn.Module):

        self.unified_model_training = main_params["unified_model_training"]
        self.default_chunk_size = main_params["default_chunk_size"]
        self.jitter_range = main_params["jitter_range"]

        self.time_reduction_factor = main_params["time_reduction_factor"]
        self.jitter_range = main_params["jitter_range"]

        self.time_reduction_factor = main_params["time_reduction_factor"]
    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
        """Return the corresponding number of samples for a given chunk size, in frames.

@@ -79,7 +75,7 @@ class Encoder(torch.nn.Module):

        """
        return self.embed.get_size_before_subsampling(size) * hop_length

    def get_encoder_input_size(self, size: int) -> int:
        """Return the corresponding number of samples for a given chunk size, in frames.

@@ -157,7 +153,7 @@ class Encoder(torch.nn.Module):
            mask,
            chunk_mask=chunk_mask,
        )

        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x_utt = x_utt[:, ::self.time_reduction_factor, :]
@@ -194,14 +190,14 @@ class Encoder(torch.nn.Module):
            mask,
            chunk_mask=chunk_mask,
        )

        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:, ::self.time_reduction_factor, :]
            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
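            # After striding with x[:, ::r, :], a length-n sequence keeps
            # floor((n - 1) / r) + 1 frames; this line updates olens accordingly.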

        return x, olens


    def simu_chunk_forward(
        self,
        x: torch.Tensor,
@@ -290,7 +286,7 @@ class Encoder(torch.nn.Module):

        if right_context > 0:
            x = x[:, 0:-right_context, :]

        if self.time_reduction_factor > 1:
            x = x[:, ::self.time_reduction_factor, :]
        return x
@@ -5,7 +5,7 @@ from typing import Optional, Tuple, Union
import torch
import math

from funasr.models_transducer.utils import sub_factor_to_params, pad_to_len
from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len


class ConvInput(torch.nn.Module):
@@ -2,22 +2,22 @@

from typing import Any, Dict, List, Optional, Union

from funasr.models_transducer.activation import get_activation
from funasr.models_transducer.encoder.blocks.branchformer import Branchformer
from funasr.models_transducer.encoder.blocks.conformer import Conformer
from funasr.models_transducer.encoder.blocks.conv1d import Conv1d
from funasr.models_transducer.encoder.blocks.conv_input import ConvInput
from funasr.models_transducer.encoder.blocks.linear_input import LinearInput
from funasr.models_transducer.encoder.modules.attention import (  # noqa: H301
from funasr.modules.activation import get_activation
from funasr.models.encoder.chunk_encoder_blocks.branchformer import Branchformer
from funasr.models.encoder.chunk_encoder_blocks.conformer import Conformer
from funasr.models.encoder.chunk_encoder_blocks.conv1d import Conv1d
from funasr.models.encoder.chunk_encoder_blocks.conv_input import ConvInput
from funasr.models.encoder.chunk_encoder_blocks.linear_input import LinearInput
from funasr.models.encoder.chunk_encoder_modules.attention import (  # noqa: H301
    RelPositionMultiHeadedAttention,
)
from funasr.models_transducer.encoder.modules.convolution import (  # noqa: H301
from funasr.models.encoder.chunk_encoder_modules.convolution import (  # noqa: H301
    ConformerConvolution,
    ConvolutionalSpatialGatingUnit,
)
from funasr.models_transducer.encoder.modules.multi_blocks import MultiBlocks
from funasr.models_transducer.encoder.modules.normalization import get_normalization
from funasr.models_transducer.encoder.modules.positional_encoding import (  # noqa: H301
from funasr.models.encoder.chunk_encoder_modules.multi_blocks import MultiBlocks
from funasr.models.encoder.chunk_encoder_modules.normalization import get_normalization
from funasr.models.encoder.chunk_encoder_modules.positional_encoding import (  # noqa: H301
    RelPositionalEncoding,
)
from funasr.modules.positionwise_feed_forward import (
@@ -2,7 +2,7 @@

from typing import Any, Dict, List, Tuple

from funasr.models_transducer.utils import sub_factor_to_params
from funasr.modules.nets_utils import sub_factor_to_params


def validate_block_arguments(
@@ -2,7 +2,7 @@

import torch

from funasr.models_transducer.activation import get_activation
from funasr.modules.activation import get_activation


class JointNetwork(torch.nn.Module):
@@ -5,8 +5,8 @@ from typing import List, Optional, Tuple
import torch
from typeguard import check_argument_types

from funasr.models_transducer.beam_search_transducer import Hypothesis
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.modules.beam_search.beam_search_transducer import Hypothesis
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.specaug.specaug import SpecAug

class RNNDecoder(AbsDecoder):
@@ -5,8 +5,8 @@ from typing import List, Optional, Tuple
import torch
from typeguard import check_argument_types

from funasr.models_transducer.beam_search_transducer import Hypothesis
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.modules.beam_search.beam_search_transducer import Hypothesis
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.specaug.specaug import SpecAug

class StatelessDecoder(AbsDecoder):
@@ -26,7 +26,6 @@ class StatelessDecoder(AbsDecoder):
        embed_size: int = 256,
        embed_dropout_rate: float = 0.0,
        embed_pad: int = 0,
        use_embed_mask: bool = False,
    ) -> None:
        """Construct a StatelessDecoder object."""
        super().__init__()
@@ -42,14 +41,6 @@ class StatelessDecoder(AbsDecoder):
        self.device = next(self.parameters()).device
        self.score_cache = {}

        self.use_embed_mask = use_embed_mask
        if self.use_embed_mask:
            self._embed_mask = SpecAug(
                time_mask_width_range=3,
                num_time_mask=1,
                apply_freq_mask=False,
                apply_time_warp=False
            )
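            # Hedged reading: SpecAug is reused here purely as a time mask over the
            # label-embedding sequence, i.e. a dropout-like regularizer on the
            # decoder's label context (only active in training; see forward()).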


    def forward(
@@ -69,9 +60,6 @@ class StatelessDecoder(AbsDecoder):

        """
        dec_embed = self.embed_dropout(self.embed(labels))  # assumes the Dropout module is stored as self.embed_dropout
        if self.use_embed_mask and self.training:
            dec_embed = self._embed_mask(dec_embed, label_lens)[0]

        return dec_embed

    def score(
@@ -1,835 +0,0 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder


class EncoderLayerSANM(nn.Module):
    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayerSANM, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.concat_after:
            x_concat = torch.cat(
                (x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)),
                dim=-1,
            )
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
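    # Note: forward returns the full argument tuple so that funasr.modules.repeat
    # (a MultiSequential) can chain EncoderLayerSANM instances by feeding each
    # layer's outputs straight into the next layer's positional arguments.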


class SANMEncoder(AbsEncoder):
    """
    author: Speech Lab, Alibaba Group, China
    San-m: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        tf2torch_tensor_name_prefix_torch: str = "encoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
    ):
        assert check_argument_types()
        super().__init__()

        self.embed = SinusoidalPositionEncoder()
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        encoder_selfattn_layer = MultiHeadedAttentionSANM
        encoder_selfattn_layer_args0 = (
            attention_heads,
            input_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )

        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )
        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )

        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.
        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        xs_pad = xs_pad * self.output_size**0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        # xs_pad = self.dropout(xs_pad)
        encoder_outs = self.encoders0(xs_pad, masks)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad

                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)

                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens

    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## encoder
            # cicd
            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (768,256),(1,256,768)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (768,),(768,)
            "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 2, 0),
                 },  # (256,1,31),(1,31,256,1)
            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # ffn
            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
        }

        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):

        map_dict = self.gen_tf2torch_map_dict()

        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "encoders0":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")

                    name_q = name_q.replace("encoders0", "encoders")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape
                            )
                        )
                elif names[1] == "encoders":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 1
                    layeridx += layeridx_bias
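                    # encoders[i] corresponds to TF layer i + 1: the torch side keeps
                    # the first SANM layer separately in encoders0, while the TF
                    # checkpoint numbers all layers consecutively.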
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape
                            )
                        )

                elif names[1] == "after_norm":
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        )
                    )

        return var_dict_torch_update


class SANMEncoderChunkOpt(AbsEncoder):
    """
    author: Speech Lab, Alibaba Group, China
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        chunk_size: Union[int, Sequence[int]] = (16,),
        stride: Union[int, Sequence[int]] = (10,),
        pad_left: Union[int, Sequence[int]] = (0,),
        time_reduction_factor: int = 1,
        encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
        decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
        tf2torch_tensor_name_prefix_torch: str = "encoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
    ):
        assert check_argument_types()
        super().__init__()
        self.output_size = output_size

        self.embed = SinusoidalPositionEncoder()

        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        encoder_selfattn_layer = MultiHeadedAttentionSANM
        encoder_selfattn_layer_args0 = (
            attention_heads,
            input_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )

        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )
        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )

        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        shfit_fsmn = (kernel_size - 1) // 2
        self.overlap_chunk_cls = overlap_chunk(
            chunk_size=chunk_size,
            stride=stride,
            pad_left=pad_left,
            shfit_fsmn=shfit_fsmn,
            encoder_att_look_back_factor=encoder_att_look_back_factor,
            decoder_att_look_back_factor=decoder_att_look_back_factor,
        )
        self.time_reduction_factor = time_reduction_factor
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
        ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.
        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        xs_pad *= self.output_size ** 0.5
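        # Scales features by sqrt(d_model) before the sinusoidal positional encoding,
        # mirroring the standard Transformer input convention; note that the
        # in-place *= mutates the caller's tensor.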
        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        mask_shfit_chunk, mask_att_chunk_encoder = None, None
        if self.overlap_chunk_cls is not None:
            ilens = masks.squeeze(1).sum(1)
            chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
            xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
            masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
            mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(
                chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
            )
            mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(
                chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
            )

        encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad

                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)

                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)

        xs_pad, olens = self.overlap_chunk_cls.remove_chunk(xs_pad, olens, chunk_outs=None)

        if self.time_reduction_factor > 1:
            xs_pad = xs_pad[:, ::self.time_reduction_factor, :]
            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1

        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens

    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## encoder
            # cicd
            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (768,256),(1,256,768)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (768,),(768,)
            "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 2, 0),
                 },  # (256,1,31),(1,31,256,1)
            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # ffn
            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
        }

        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):

        map_dict = self.gen_tf2torch_map_dict()

        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "encoders0":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")

                    name_q = name_q.replace("encoders0", "encoders")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape
                            )
                        )
                elif names[1] == "encoders":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 1
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape
                            )
                        )

                elif names[1] == "after_norm":
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        )
                    )

        return var_dict_torch_update
@@ -1,169 +0,0 @@
"""Error Calculator module for Transducer."""

from typing import List, Optional, Tuple

import torch

from funasr.models_transducer.beam_search_transducer import BeamSearchTransducer
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.models_transducer.joint_network import JointNetwork


class ErrorCalculator:
    """Calculate CER and WER for transducer models.

    Args:
        decoder: Decoder module.
        joint_network: Joint Network module.
        token_list: List of token units.
        sym_space: Space symbol.
        sym_blank: Blank symbol.
        report_cer: Whether to compute CER.
        report_wer: Whether to compute WER.

    """

    def __init__(
        self,
        decoder: AbsDecoder,
        joint_network: JointNetwork,
        token_list: List[int],
        sym_space: str,
        sym_blank: str,
        report_cer: bool = False,
        report_wer: bool = False,
    ) -> None:
        """Construct an ErrorCalculatorTransducer object."""
        super().__init__()

        self.beam_search = BeamSearchTransducer(
            decoder=decoder,
            joint_network=joint_network,
            beam_size=1,
            search_type="default",
            score_norm=False,
        )

        self.decoder = decoder

        self.token_list = token_list
        self.space = sym_space
        self.blank = sym_blank

        self.report_cer = report_cer
        self.report_wer = report_wer

    def __call__(
        self, encoder_out: torch.Tensor, target: torch.Tensor
    ) -> Tuple[Optional[float], Optional[float]]:
        """Calculate sentence-level WER and/or CER scores for the Transducer model.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)

        Returns:
            : Sentence-level CER score.
            : Sentence-level WER score.

        """
        cer, wer = None, None

        batchsize = int(encoder_out.size(0))

        encoder_out = encoder_out.to(next(self.decoder.parameters()).device)

        batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
        pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]

        char_pred, char_target = self.convert_to_char(pred, target)

        if self.report_cer:
            cer = self.calculate_cer(char_pred, char_target)

        if self.report_wer:
            wer = self.calculate_wer(char_pred, char_target)

        return cer, wer

    def convert_to_char(
        self, pred: torch.Tensor, target: torch.Tensor
    ) -> Tuple[List, List]:
        """Convert label ID sequences to character sequences.

        Args:
            pred: Prediction label ID sequences. (B, U)
            target: Target label ID sequences. (B, L)

        Returns:
            char_pred: Prediction character sequences. (B, ?)
            char_target: Target character sequences. (B, ?)

        """
        char_pred, char_target = [], []

        for i, pred_i in enumerate(pred):
            char_pred_i = [self.token_list[int(h)] for h in pred_i]
            char_target_i = [self.token_list[int(r)] for r in target[i]]

            char_pred_i = "".join(char_pred_i).replace(self.space, " ")
            char_pred_i = char_pred_i.replace(self.blank, "")

            char_target_i = "".join(char_target_i).replace(self.space, " ")
            char_target_i = char_target_i.replace(self.blank, "")

            char_pred.append(char_pred_i)
            char_target.append(char_target_i)

        return char_pred, char_target

    def calculate_cer(
        self, char_pred: torch.Tensor, char_target: torch.Tensor
    ) -> float:
        """Calculate sentence-level CER score.

        Args:
            char_pred: Prediction character sequences. (B, ?)
            char_target: Target character sequences. (B, ?)

        Returns:
            : Average sentence-level CER score.

        """
        import editdistance

        distances, lens = [], []

        for i, char_pred_i in enumerate(char_pred):
            pred = char_pred_i.replace(" ", "")
            target = char_target[i].replace(" ", "")
            distances.append(editdistance.eval(pred, target))
            lens.append(len(target))

        return float(sum(distances)) / sum(lens)
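        # Edit distances are summed over all utterances before dividing by the
        # total reference length, i.e. a corpus-level (length-weighted) CER
        # rather than a mean of per-utterance CERs.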
|
||||
|
||||
def calculate_wer(
|
||||
self, char_pred: torch.Tensor, char_target: torch.Tensor
|
||||
) -> float:
|
||||
"""Calculate sentence-level WER score.
|
||||
|
||||
Args:
|
||||
char_pred: Prediction character sequences. (B, ?)
|
||||
char_target: Target character sequences. (B, ?)
|
||||
|
||||
Returns:
|
||||
: Average sentence-level WER score
|
||||
|
||||
"""
|
||||
import editdistance
|
||||
|
||||
distances, lens = [], []
|
||||
|
||||
for i, char_pred_i in enumerate(char_pred):
|
||||
pred = char_pred_i.replace("▁", " ").split()
|
||||
target = char_target[i].replace("▁", " ").split()
|
||||
|
||||
distances.append(editdistance.eval(pred, target))
|
||||
lens.append(len(target))
|
||||
|
||||
return float(sum(distances)) / sum(lens)
|
||||
@ -1,485 +0,0 @@
"""ESPnet2 ASR Transducer model."""

import logging
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union

import torch
from packaging.version import parse as V
from typeguard import check_argument_types

from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
from funasr.models_transducer.encoder.encoder import Encoder
from funasr.models_transducer.joint_network import JointNetwork
from funasr.models_transducer.utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel

if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:

    @contextmanager
    def autocast(enabled=True):
        yield


class UniASRTransducerModel(AbsESPnetModel):
    """UniASRTransducerModel module definition.

    Args:
        vocab_size: Size of complete vocabulary (w/ EOS and blank included).
        token_list: List of tokens.
        frontend: Frontend module.
        specaug: SpecAugment module.
        normalize: Normalization module.
        encoder: Encoder module.
        decoder: Decoder module.
        joint_network: Joint Network module.
        transducer_weight: Weight of the Transducer loss.
        fastemit_lambda: FastEmit lambda value.
        auxiliary_ctc_weight: Weight of auxiliary CTC loss.
        auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
        auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
        auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
        ignore_id: Initial padding ID.
        sym_space: Space symbol.
        sym_blank: Blank symbol.
        report_cer: Whether to report Character Error Rate during validation.
        report_wer: Whether to report Word Error Rate during validation.
        extract_feats_in_collect_stats: Whether to use extract_feats stats collection.

    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder,
        decoder: AbsDecoder,
        att_decoder: Optional[AbsAttDecoder],
        joint_network: JointNetwork,
        transducer_weight: float = 1.0,
        fastemit_lambda: float = 0.0,
        auxiliary_ctc_weight: float = 0.0,
        auxiliary_ctc_dropout_rate: float = 0.0,
        auxiliary_lm_loss_weight: float = 0.0,
        auxiliary_lm_loss_smoothing: float = 0.0,
        ignore_id: int = -1,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        report_cer: bool = True,
        report_wer: bool = True,
        extract_feats_in_collect_stats: bool = True,
    ) -> None:
        """Construct a UniASRTransducerModel object."""
        super().__init__()

        assert check_argument_types()

        # The following label IDs are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
        self.blank_id = 0
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.token_list = token_list.copy()

        self.sym_space = sym_space
        self.sym_blank = sym_blank

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize

        self.encoder = encoder
        self.decoder = decoder
        self.joint_network = joint_network

        self.criterion_transducer = None
        self.error_calculator = None

        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0

        if self.use_auxiliary_ctc:
            self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate

        if self.use_auxiliary_lm_loss:
            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing

        self.transducer_weight = transducer_weight
        self.fastemit_lambda = fastemit_lambda

        self.auxiliary_ctc_weight = auxiliary_ctc_weight
        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight

        self.report_cer = report_cer
        self.report_wer = report_wer

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        decoding_ind: int = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Forward architecture and compute loss(es).

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".

        Return:
            loss: Main loss value.
            stats: Task statistics.
            weight: Task weights.

        """
        assert text_lengths.dim() == 1, text_lengths.shape
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)

        batch_size = speech.shape[0]
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)

        # 2. Transducer-related I/O preparation
        decoder_in, target, t_len, u_len = get_transducer_task_io(
            text,
            encoder_out_lens,
            ignore_id=self.ignore_id,
        )

        # 3. Decoder
        self.decoder.set_device(encoder_out.device)
        decoder_out = self.decoder(decoder_in, u_len)

        # 4. Joint Network
        joint_out = self.joint_network(
            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
        )

        # 5. Losses
        loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
            encoder_out,
            joint_out,
            target,
            t_len,
            u_len,
        )

        loss_ctc, loss_lm = 0.0, 0.0

        if self.use_auxiliary_ctc:
            loss_ctc = self._calc_ctc_loss(
                encoder_out,
                target,
                t_len,
                u_len,
            )

        if self.use_auxiliary_lm_loss:
            loss_lm = self._calc_lm_loss(decoder_out, target)

        loss = (
            self.transducer_weight * loss_trans
            + self.auxiliary_ctc_weight * loss_ctc
            + self.auxiliary_lm_loss_weight * loss_lm
        )

        stats = dict(
            loss=loss.detach(),
            loss_transducer=loss_trans.detach(),
            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
            cer_transducer=cer_trans,
            wer_transducer=wer_trans,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

        return loss, stats, weight
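    # Sketch (illustrative, not from this commit): the objective above is a
    # plain weighted sum, e.g. with transducer_weight=1.0,
    # auxiliary_ctc_weight=0.3 and auxiliary_lm_loss_weight=0.2:
    #
    #     loss = 1.0 * loss_trans + 0.3 * loss_ctc + 0.2 * loss_lm
    #
    # The auxiliary CTC/LM branches only shape training; decoding uses the
    # transducer head alone.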
    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Collect features sequences and features lengths sequences.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".

        Return:
            {}: "feats": Features sequences. (B, T, D_feats),
                "feats_lengths": Features sequences lengths. (B,)

        """
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )

            feats, feats_lengths = speech, speech_lengths

        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        ind: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode speech sequences.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)

        Return:
            encoder_out: Encoder outputs. (B, T, D_enc)
            encoder_out_lens: Encoder outputs lengths. (B,)

        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # 4. Forward encoder
        encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths, ind=ind)

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        return encoder_out, encoder_out_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract features sequences and features sequences lengths.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)

        Return:
            feats: Features sequences. (B, T, D_feats)
            feats_lengths: Features sequences lengths. (B,)

        """
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            feats, feats_lengths = speech, speech_lengths

        return feats, feats_lengths

    def _calc_transducer_loss(
        self,
        encoder_out: torch.Tensor,
        joint_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
        """Compute Transducer loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            joint_out: Joint Network output sequences. (B, T, U, D_joint)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)

        Return:
            loss_transducer: Transducer loss value.
            cer_transducer: Character error rate for Transducer.
            wer_transducer: Word Error Rate for Transducer.

        """
        if self.criterion_transducer is None:
            try:
                # from warprnnt_pytorch import RNNTLoss
                # self.criterion_transducer = RNNTLoss(
                #     reduction="mean",
                #     fastemit_lambda=self.fastemit_lambda,
                # )
                from warp_rnnt import rnnt_loss as RNNTLoss

                self.criterion_transducer = RNNTLoss
            except ImportError:
                logging.error(
                    "warp-rnnt was not installed. "
                    "Please consult the installation documentation."
                )
                exit(1)

        # loss_transducer = self.criterion_transducer(
        #     joint_out,
        #     target,
        #     t_len,
        #     u_len,
        # )
        log_probs = torch.log_softmax(joint_out, dim=-1)

        loss_transducer = self.criterion_transducer(
            log_probs,
            target,
            t_len,
            u_len,
            reduction="mean",
            blank=self.blank_id,
            gather=True,
        )

        if not self.training and (self.report_cer or self.report_wer):
            if self.error_calculator is None:
                from espnet2.asr_transducer.error_calculator import ErrorCalculator

                self.error_calculator = ErrorCalculator(
                    self.decoder,
                    self.joint_network,
                    self.token_list,
                    self.sym_space,
                    self.sym_blank,
                    report_cer=self.report_cer,
                    report_wer=self.report_wer,
                )

            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)

            return loss_transducer, cer_transducer, wer_transducer

        return loss_transducer, None, None
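    # Shape sketch for the warp_rnnt call above (illustrative, assumes the
    # warp-rnnt package is installed): log-probs are (B, T, U + 1, V), already
    # log-softmaxed; the "+ 1" row matches the blank prepended to the decoder
    # input by get_transducer_task_io.
    #
    #     import torch
    #     from warp_rnnt import rnnt_loss
    #     B, T, U, V = 2, 10, 5, 30
    #     log_probs = torch.randn(B, T, U + 1, V).log_softmax(dim=-1)
    #     target = torch.randint(1, V, (B, U), dtype=torch.int32)
    #     t_len = torch.full((B,), T, dtype=torch.int32)
    #     u_len = torch.full((B,), U, dtype=torch.int32)
    #     loss = rnnt_loss(log_probs, target, t_len, u_len,
    #                      reduction="mean", blank=0, gather=True)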
    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)

        Return:
            loss_ctc: CTC loss value.

        """
        ctc_in = self.ctc_lin(
            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
        )
        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)

        target_mask = target != 0
        ctc_target = target[target_mask].cpu()

        with torch.backends.cudnn.flags(deterministic=True):
            loss_ctc = torch.nn.functional.ctc_loss(
                ctc_in,
                ctc_target,
                t_len,
                u_len,
                zero_infinity=True,
                reduction="sum",
            )
        loss_ctc /= target.size(0)

        return loss_ctc
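    # Shape note (illustrative): torch.nn.functional.ctc_loss expects
    # log-probs as (T, B, V) — hence the transpose(0, 1) above — and accepts
    # targets as one 1-D tensor of the concatenated, unpadded label sequences.
    #
    #     ctc_in = torch.randn(50, 2, 30).log_softmax(dim=-1)   # (T, B, V)
    #     ctc_target = torch.randint(1, 30, (2 * 7,))           # flattened targets
    #     t_len = torch.full((2,), 50, dtype=torch.int32)
    #     u_len = torch.full((2,), 7, dtype=torch.int32)
    #     loss = torch.nn.functional.ctc_loss(
    #         ctc_in, ctc_target, t_len, u_len, zero_infinity=True, reduction="sum"
    #     )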
    def _calc_lm_loss(
        self,
        decoder_out: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Compute LM loss.

        Args:
            decoder_out: Decoder output sequences. (B, U, D_dec)
            target: Target label ID sequences. (B, L)

        Return:
            loss_lm: LM loss value.

        """
        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
        lm_target = target.view(-1).type(torch.int64)

        with torch.no_grad():
            true_dist = lm_loss_in.clone()
            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))

            # Ignore blank ID (0)
            ignore = lm_target == 0
            lm_target = lm_target.masked_fill(ignore, 0)

            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))

        loss_lm = torch.nn.functional.kl_div(
            torch.log_softmax(lm_loss_in, dim=1),
            true_dist,
            reduction="none",
        )
        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
            0
        )

        return loss_lm
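    # Worked example (illustrative) of the smoothed target built in
    # _calc_lm_loss: with vocab_size=5 and lm_loss_smoothing=0.1, label 3
    # becomes the row [0.025, 0.025, 0.025, 0.9, 0.025] — mass 1 - 0.1 on the
    # label, the remaining 0.1 spread evenly over the other 4 entries.
    #
    #     true_dist = torch.full((1, 5), 0.1 / 4)
    #     true_dist.scatter_(1, torch.tensor([[3]]), 1 - 0.1)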
@ -1,200 +0,0 @@
"""Utility functions for Transducer models."""

from typing import List, Tuple

import torch


class TooShortUttError(Exception):
    """Raised when the utt is too short for subsampling.

    Args:
        message: Error message to display.
        actual_size: The size that cannot pass the subsampling.
        limit: The size limit for subsampling.

    """

    def __init__(self, message: str, actual_size: int, limit: int) -> None:
        """Construct a TooShortUttError module."""
        super().__init__(message)

        self.actual_size = actual_size
        self.limit = limit


def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
    """Check if the input is too short for subsampling.

    Args:
        sub_factor: Subsampling factor for Conv2DSubsampling.
        size: Input size.

    Returns:
        : Whether an error should be sent.
        : Size limit for specified subsampling factor.

    """
    if sub_factor == 2 and size < 3:
        return True, 7
    elif sub_factor == 4 and size < 7:
        return True, 7
    elif sub_factor == 6 and size < 11:
        return True, 11

    return False, -1
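# Usage sketch (illustrative): guard a forward pass against inputs the Conv2D
# front-end cannot subsample.
#
#     too_short, limit = check_short_utt(4, 5)  # -> (True, 7)
#     if too_short:
#         raise TooShortUttError(
#             f"input has 5 frames and is too short for subsampling "
#             f"(it needs more than {limit} frames)",
#             5,
#             limit,
#         )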
def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
    """Get conv2D second layer parameters for given subsampling factor.

    Args:
        sub_factor: Subsampling factor (1/X).
        input_size: Input size.

    Returns:
        : Kernel size for second convolution.
        : Stride for second convolution.
        : Conv2DSubsampling output size.

    """
    if sub_factor == 2:
        return 3, 1, (((input_size - 1) // 2 - 2))
    elif sub_factor == 4:
        return 3, 2, (((input_size - 1) // 2 - 1) // 2)
    elif sub_factor == 6:
        return 5, 3, (((input_size - 1) // 2 - 2) // 3)
    else:
        raise ValueError(
            "subsampling_factor parameter should be set to either 2, 4 or 6."
        )


def make_chunk_mask(
    size: int,
    chunk_size: int,
    left_chunk_size: int = 0,
    device: torch.device = None,
) -> torch.Tensor:
    """Create chunk mask for the subsequent steps (size, size).

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        size: Size of the source mask.
        chunk_size: Number of frames in chunk.
        left_chunk_size: Size of the left context in chunks (0 means full context).
        device: Device for the mask tensor.

    Returns:
        mask: Chunk mask. (size, size)

    """
    mask = torch.zeros(size, size, device=device, dtype=torch.bool)

    for i in range(size):
        if left_chunk_size <= 0:
            start = 0
        else:
            start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)

        end = min((i // chunk_size + 1) * chunk_size, size)
        mask[i, start:end] = True

    return ~mask
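# Worked example (illustrative): make_chunk_mask(4, chunk_size=2) with full
# left context. True marks masked positions, so each frame may attend up to
# the end of its own chunk and to everything before it.
#
#     tensor([[False, False,  True,  True],
#             [False, False,  True,  True],
#             [False, False, False, False],
#             [False, False, False, False]])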
def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Create source mask for given lengths.

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        lengths: Sequence lengths. (B,)

    Returns:
        : Mask for the sequence lengths. (B, max_len)

    """
    max_len = lengths.max()
    batch_size = lengths.size(0)

    expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths)

    return expanded_lengths >= lengths.unsqueeze(1)
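# Worked example (illustrative): padding positions come back as True.
#
#     >>> make_source_mask(torch.tensor([3, 1]))
#     tensor([[False, False, False],
#             [False,  True,  True]])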
def get_transducer_task_io(
    labels: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ignore_id: int = -1,
    blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get Transducer loss I/O.

    Args:
        labels: Label ID sequences. (B, L)
        encoder_out_lens: Encoder output lengths. (B,)
        ignore_id: Padding symbol ID.
        blank_id: Blank symbol ID.

    Returns:
        decoder_in: Decoder inputs. (B, U)
        target: Target label ID sequences. (B, U)
        t_len: Time lengths. (B,)
        u_len: Label lengths. (B,)

    """

    def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
        """Create padded batch of labels from a list of labels sequences.

        Args:
            labels: Labels sequences. [B x (?)]
            padding_value: Padding value.

        Returns:
            labels: Batch of padded labels sequences. (B,)

        """
        batch_size = len(labels)

        padded = (
            labels[0]
            .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
            .fill_(padding_value)
        )

        for i in range(batch_size):
            padded[i, : labels[i].size(0)] = labels[i]

        return padded

    device = labels.device

    labels_unpad = [y[y != ignore_id] for y in labels]
    blank = labels[0].new([blank_id])

    decoder_in = pad_list(
        [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
    ).to(device)

    target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)

    encoder_out_lens = list(map(int, encoder_out_lens))
    t_len = torch.IntTensor(encoder_out_lens).to(device)

    u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device)

    return decoder_in, target, t_len, u_len


def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
    """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
    if t.size(dim) == pad_len:
        return t
    else:
        pad_size = list(t.shape)
        pad_size[dim] = pad_len - t.size(dim)
        return torch.cat(
            [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
        )
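# Worked example (illustrative): the blank (0) is prepended to build the
# decoder input, while the target stays blank-free.
#
#     >>> labels = torch.tensor([[7, 8, -1], [9, -1, -1]])
#     >>> get_transducer_task_io(labels, torch.tensor([4, 3]))
#     decoder_in -> [[0, 7, 8], [0, 9, 0]]   # (B, U + 1), blank-padded
#     target     -> [[7, 8], [9, 0]]         # (B, U)
#     t_len      -> [4, 3]; u_len -> [2, 1]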
@ -6,8 +6,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch

from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.models_transducer.joint_network import JointNetwork
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.joint_network import JointNetwork


@dataclass
@ -6,6 +6,8 @@

"""Common functions for ASR."""

from typing import List, Optional, Tuple

import json
import logging
import sys
@ -13,7 +15,11 @@ import sys
from itertools import groupby
import numpy as np
import six
import torch

from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.joint_network import JointNetwork

def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
    """End detection.
@ -247,3 +253,148 @@ class ErrorCalculator(object):
            word_eds.append(editdistance.eval(hyp_words, ref_words))
            word_ref_lens.append(len(ref_words))
        return float(sum(word_eds)) / sum(word_ref_lens)

class ErrorCalculatorTransducer:
    """Calculate CER and WER for transducer models.
    Args:
        decoder: Decoder module.
        joint_network: Joint Network module.
        token_list: List of token units.
        sym_space: Space symbol.
        sym_blank: Blank symbol.
        report_cer: Whether to compute CER.
        report_wer: Whether to compute WER.
    """

    def __init__(
        self,
        decoder: AbsDecoder,
        joint_network: JointNetwork,
        token_list: List[int],
        sym_space: str,
        sym_blank: str,
        report_cer: bool = False,
        report_wer: bool = False,
    ) -> None:
        """Construct an ErrorCalculatorTransducer object."""
        super().__init__()

        self.beam_search = BeamSearchTransducer(
            decoder=decoder,
            joint_network=joint_network,
            beam_size=1,
            search_type="default",
            score_norm=False,
        )

        self.decoder = decoder

        self.token_list = token_list
        self.space = sym_space
        self.blank = sym_blank

        self.report_cer = report_cer
        self.report_wer = report_wer

    def __call__(
        self, encoder_out: torch.Tensor, target: torch.Tensor
    ) -> Tuple[Optional[float], Optional[float]]:
        """Calculate sentence-level WER or/and CER score for Transducer model.
        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)
        Returns:
            : Sentence-level CER score.
            : Sentence-level WER score.
        """
        cer, wer = None, None

        batchsize = int(encoder_out.size(0))

        encoder_out = encoder_out.to(next(self.decoder.parameters()).device)

        batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
        pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]

        char_pred, char_target = self.convert_to_char(pred, target)

        if self.report_cer:
            cer = self.calculate_cer(char_pred, char_target)

        if self.report_wer:
            wer = self.calculate_wer(char_pred, char_target)

        return cer, wer

    def convert_to_char(
        self, pred: torch.Tensor, target: torch.Tensor
    ) -> Tuple[List, List]:
        """Convert label ID sequences to character sequences.
        Args:
            pred: Prediction label ID sequences. (B, U)
            target: Target label ID sequences. (B, L)
        Returns:
            char_pred: Prediction character sequences. (B, ?)
            char_target: Target character sequences. (B, ?)
        """
        char_pred, char_target = [], []

        for i, pred_i in enumerate(pred):
            char_pred_i = [self.token_list[int(h)] for h in pred_i]
            char_target_i = [self.token_list[int(r)] for r in target[i]]

            char_pred_i = "".join(char_pred_i).replace(self.space, " ")
            char_pred_i = char_pred_i.replace(self.blank, "")

            char_target_i = "".join(char_target_i).replace(self.space, " ")
            char_target_i = char_target_i.replace(self.blank, "")

            char_pred.append(char_pred_i)
            char_target.append(char_target_i)

        return char_pred, char_target

    def calculate_cer(
        self, char_pred: torch.Tensor, char_target: torch.Tensor
    ) -> float:
        """Calculate sentence-level CER score.
        Args:
            char_pred: Prediction character sequences. (B, ?)
            char_target: Target character sequences. (B, ?)
        Returns:
            : Average sentence-level CER score.
        """
        import editdistance

        distances, lens = [], []

        for i, char_pred_i in enumerate(char_pred):
            pred = char_pred_i.replace(" ", "")
            target = char_target[i].replace(" ", "")
            distances.append(editdistance.eval(pred, target))
            lens.append(len(target))

        return float(sum(distances)) / sum(lens)

    def calculate_wer(
        self, char_pred: torch.Tensor, char_target: torch.Tensor
    ) -> float:
        """Calculate sentence-level WER score.
        Args:
            char_pred: Prediction character sequences. (B, ?)
            char_target: Target character sequences. (B, ?)
        Returns:
            : Average sentence-level WER score.
        """
        import editdistance

        distances, lens = [], []

        for i, char_pred_i in enumerate(char_pred):
            pred = char_pred_i.replace("▁", " ").split()
            target = char_target[i].replace("▁", " ").split()

            distances.append(editdistance.eval(pred, target))
            lens.append(len(target))

        return float(sum(distances)) / sum(lens)
@ -3,7 +3,7 @@
"""Network related utility tools."""

import logging
from typing import Dict
from typing import Dict, List, Tuple

import numpy as np
import torch
@ -506,3 +506,196 @@ def get_activation(act):
    }

    return activation_funcs[act]()

class TooShortUttError(Exception):
    """Raised when the utt is too short for subsampling.

    Args:
        message: Error message to display.
        actual_size: The size that cannot pass the subsampling.
        limit: The size limit for subsampling.

    """

    def __init__(self, message: str, actual_size: int, limit: int) -> None:
        """Construct a TooShortUttError module."""
        super().__init__(message)

        self.actual_size = actual_size
        self.limit = limit


def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
    """Check if the input is too short for subsampling.

    Args:
        sub_factor: Subsampling factor for Conv2DSubsampling.
        size: Input size.

    Returns:
        : Whether an error should be sent.
        : Size limit for specified subsampling factor.

    """
    if sub_factor == 2 and size < 3:
        return True, 7
    elif sub_factor == 4 and size < 7:
        return True, 7
    elif sub_factor == 6 and size < 11:
        return True, 11

    return False, -1


def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
    """Get conv2D second layer parameters for given subsampling factor.

    Args:
        sub_factor: Subsampling factor (1/X).
        input_size: Input size.

    Returns:
        : Kernel size for second convolution.
        : Stride for second convolution.
        : Conv2DSubsampling output size.

    """
    if sub_factor == 2:
        return 3, 1, (((input_size - 1) // 2 - 2))
    elif sub_factor == 4:
        return 3, 2, (((input_size - 1) // 2 - 1) // 2)
    elif sub_factor == 6:
        return 5, 3, (((input_size - 1) // 2 - 2) // 3)
    else:
        raise ValueError(
            "subsampling_factor parameter should be set to either 2, 4 or 6."
        )


def make_chunk_mask(
    size: int,
    chunk_size: int,
    left_chunk_size: int = 0,
    device: torch.device = None,
) -> torch.Tensor:
    """Create chunk mask for the subsequent steps (size, size).

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        size: Size of the source mask.
        chunk_size: Number of frames in chunk.
        left_chunk_size: Size of the left context in chunks (0 means full context).
        device: Device for the mask tensor.

    Returns:
        mask: Chunk mask. (size, size)

    """
    mask = torch.zeros(size, size, device=device, dtype=torch.bool)

    for i in range(size):
        if left_chunk_size <= 0:
            start = 0
        else:
            start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)

        end = min((i // chunk_size + 1) * chunk_size, size)
        mask[i, start:end] = True

    return ~mask

def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Create source mask for given lengths.

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        lengths: Sequence lengths. (B,)

    Returns:
        : Mask for the sequence lengths. (B, max_len)

    """
    max_len = lengths.max()
    batch_size = lengths.size(0)

    expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths)

    return expanded_lengths >= lengths.unsqueeze(1)


def get_transducer_task_io(
    labels: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ignore_id: int = -1,
    blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get Transducer loss I/O.

    Args:
        labels: Label ID sequences. (B, L)
        encoder_out_lens: Encoder output lengths. (B,)
        ignore_id: Padding symbol ID.
        blank_id: Blank symbol ID.

    Returns:
        decoder_in: Decoder inputs. (B, U)
        target: Target label ID sequences. (B, U)
        t_len: Time lengths. (B,)
        u_len: Label lengths. (B,)

    """

    def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
        """Create padded batch of labels from a list of labels sequences.

        Args:
            labels: Labels sequences. [B x (?)]
            padding_value: Padding value.

        Returns:
            labels: Batch of padded labels sequences. (B,)

        """
        batch_size = len(labels)

        padded = (
            labels[0]
            .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
            .fill_(padding_value)
        )

        for i in range(batch_size):
            padded[i, : labels[i].size(0)] = labels[i]

        return padded

    device = labels.device

    labels_unpad = [y[y != ignore_id] for y in labels]
    blank = labels[0].new([blank_id])

    decoder_in = pad_list(
        [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
    ).to(device)

    target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)

    encoder_out_lens = list(map(int, encoder_out_lens))
    t_len = torch.IntTensor(encoder_out_lens).to(device)

    u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device)

    return decoder_in, target, t_len, u_len

def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
    """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
    if t.size(dim) == pad_len:
        return t
    else:
        pad_size = list(t.shape)
        pad_size[dim] = pad_len - t.size(dim)
        return torch.cat(
            [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
        )
@ -21,15 +21,13 @@ from funasr.models.decoder.transformer_decoder import (
    LightweightConvolutionTransformerDecoder,
    TransformerDecoder,
)
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.models_transducer.decoder.rnn_decoder import RNNDecoder
from funasr.models_transducer.decoder.stateless_decoder import StatelessDecoder
from funasr.models_transducer.encoder.encoder import Encoder
from funasr.models_transducer.encoder.sanm_encoder import SANMEncoderChunkOpt
from funasr.models_transducer.espnet_transducer_model import ESPnetASRTransducerModel
from funasr.models_transducer.espnet_transducer_model_unified import ESPnetASRUnifiedTransducerModel
from funasr.models_transducer.espnet_transducer_model_uni_asr import UniASRTransducerModel
from funasr.models_transducer.joint_network import JointNetwork
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder
from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder
from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
from funasr.models.e2e_transducer import TransducerModel
from funasr.models.e2e_transducer_unified import UnifiedTransducerModel
from funasr.models.joint_network import JointNetwork
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
@ -75,7 +73,6 @@ encoder_choices = ClassChoices(
    "encoder",
    classes=dict(
        encoder=Encoder,
        sanm_chunk_opt=SANMEncoderChunkOpt,
    ),
    default="encoder",
)
@ -158,7 +155,7 @@ class ASRTransducerTask(AbsTask):
        group.add_argument(
            "--model_conf",
            action=NestedDictAction,
            default=get_default_kwargs(ESPnetASRTransducerModel),
            default=get_default_kwargs(TransducerModel),
            help="The keyword arguments for the model class.",
        )
        # group.add_argument(
@ -354,7 +351,7 @@ class ASRTransducerTask(AbsTask):
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace) -> ESPnetASRTransducerModel:
    def build_model(cls, args: argparse.Namespace) -> TransducerModel:
        """Required data depending on task mode.
        Args:
            cls: ASRTransducerTask object.
@ -440,22 +437,8 @@ class ASRTransducerTask(AbsTask):

        # 7. Build model

        if getattr(args, "encoder", None) is not None and args.encoder == 'sanm_chunk_opt':
            model = UniASRTransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,
                specaug=specaug,
                normalize=normalize,
                encoder=encoder,
                decoder=decoder,
                att_decoder=att_decoder,
                joint_network=joint_network,
                **args.model_conf,
            )

        elif encoder.unified_model_training:
            model = ESPnetASRUnifiedTransducerModel(
        if encoder.unified_model_training:
            model = UnifiedTransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,
@ -469,7 +452,7 @@ class ASRTransducerTask(AbsTask):
            )

        else:
            model = ESPnetASRTransducerModel(
            model = TransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,