diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py index 3769b6cc5..709c5bfb4 100644 --- a/funasr/bin/asr_inference_paraformer.py +++ b/funasr/bin/asr_inference_paraformer.py @@ -3,6 +3,9 @@ import argparse import logging import sys import time +import copy +import os +import codecs from pathlib import Path from typing import Optional from typing import Sequence @@ -35,6 +38,8 @@ from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none from funasr.utils import asr_utils, wav_utils, postprocess_utils from funasr.models.frontend.wav_frontend import WavFrontend +from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer + header_colors = '\033[95m' end_colors = '\033[0m' @@ -78,6 +83,7 @@ class Speech2Text: penalty: float = 0.0, nbest: int = 1, frontend_conf: dict = None, + hotword_list_or_file: str = None, **kwargs, ): assert check_argument_types() @@ -168,6 +174,34 @@ class Speech2Text: self.asr_train_args = asr_train_args self.converter = converter self.tokenizer = tokenizer + + # 6. [Optional] Build hotword list from file or str + if hotword_list_or_file is None: + self.hotword_list = None + elif os.path.exists(hotword_list_or_file): + self.hotword_list = [] + hotword_str_list = [] + with codecs.open(hotword_list_or_file, 'r') as fin: + for line in fin.readlines(): + hw = line.strip() + hotword_str_list.append(hw) + self.hotword_list.append(self.converter.tokens2ids([i for i in hw])) + self.hotword_list.append([1]) + hotword_str_list.append('') + logging.info("Initialized hotword list from file: {}, hotword list: {}." + .format(hotword_list_or_file, hotword_str_list)) + else: + logging.info("Attempting to parse hotwords as str...") + self.hotword_list = [] + hotword_str_list = [] + for hw in hotword_list_or_file.strip().split(): + hotword_str_list.append(hw) + self.hotword_list.append(self.converter.tokens2ids([i for i in hw])) + self.hotword_list.append([1]) + hotword_str_list.append('') + logging.info("Hotword list: {}.".format(hotword_str_list)) + + is_use_lm = lm_weight != 0.0 and lm_file is not None if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm: beam_search = None @@ -229,8 +263,14 @@ class Speech2Text: pre_token_length = pre_token_length.round().long() if torch.max(pre_token_length) < 1: return [] - decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length) - decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] + if not isinstance(self.asr_model, ContextualParaformer): + if self.hotword_list: + logging.warning("Hotword is given but asr model is not a ContextualParaformer.") + decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length) + decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] + else: + decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list) + decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] results = [] b, n, d = decoder_out.size() @@ -388,6 +428,7 @@ def inference_modelscope( format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) + hotword_list_or_file = param_dict['hotword'] if ngpu >= 1 and torch.cuda.is_available(): device = "cuda" else: @@ -416,6 +457,7 @@ def inference_modelscope( ngram_weight=ngram_weight, penalty=penalty, nbest=nbest, + hotword_list_or_file=hotword_list_or_file, ) speech2text = 
Speech2Text(**speech2text_kwargs)
@@ -551,7 +593,12 @@ def get_parser():
         default=1,
         help="The number of workers used for DataLoader",
     )
-
+    parser.add_argument(
+        "--hotword",
+        type=str_or_none,
+        default=None,
+        help="hotword file path or hotwords separated by space"
+    )
     group = parser.add_argument_group("Input data related")
     group.add_argument(
         "--data_path_and_name_and_type",
@@ -679,8 +726,10 @@ def main(cmd=None):
     print(get_commandline_args(), file=sys.stderr)
     parser = get_parser()
     args = parser.parse_args(cmd)
+    param_dict = {'hotword': args.hotword}
     kwargs = vars(args)
     kwargs.pop("config", None)
+    kwargs['param_dict'] = param_dict
     inference(**kwargs)
diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index 1d09c790a..7d18e0218 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -14,6 +14,7 @@ from typing import Dict
 from typing import Any
 from typing import List
 import math
+import copy
 import numpy as np
 import torch
 from typeguard import check_argument_types
@@ -38,8 +39,9 @@ from funasr.utils.types import str_or_none
 from funasr.utils import asr_utils, wav_utils, postprocess_utils
 from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.tasks.vad import VADTask
-from funasr.utils.timestamp_tools import time_stamp_lfr6
+from funasr.utils.timestamp_tools import time_stamp_lfr6, time_stamp_lfr6_pl
 from funasr.bin.punctuation_infer import Text2Punc
+from funasr.models.e2e_asr_paraformer import BiCifParaformer
 
 header_colors = '\033[95m'
 end_colors = '\033[0m'
@@ -234,6 +236,10 @@ class Speech2Text:
         decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
         decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
 
+        if isinstance(self.asr_model, BiCifParaformer):
+            _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
+                                                                                   pre_token_length)  # upsampled alphas and CIF peaks for timestamp prediction
+
         results = []
         b, n, d = decoder_out.size()
         for i in range(b):
@@ -276,9 +282,12 @@ class Speech2Text:
             else:
                 text = None
 
-            time_stamp = time_stamp_lfr6(alphas[i:i+1,], enc_len[i:i+1,], token, begin_time, end_time)
-
-            results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor))
+            if isinstance(self.asr_model, BiCifParaformer):
+                timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
+                results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
+            else:
+                time_stamp = time_stamp_lfr6(alphas[i:i + 1, ], enc_len[i:i + 1, ], copy.copy(token), begin_time, end_time)
+                results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor))
 
         # assert check_return_type(results)
         return results
diff --git a/funasr/models/decoder/contextual_decoder.py b/funasr/models/decoder/contextual_decoder.py
new file mode 100644
index 000000000..32f550a71
--- /dev/null
+++ b/funasr/models/decoder/contextual_decoder.py
@@ -0,0 +1,776 @@
+from typing import List
+from typing import Tuple
+import logging
+import torch
+import torch.nn as nn
+import numpy as np
+
+from funasr.modules.streaming_utils import utils as myutils
+from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
+from typeguard import check_argument_types
+
+from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.modules.embedding import PositionalEncoding
+from funasr.modules.layer_norm
import LayerNorm +from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM +from funasr.modules.repeat import repeat +from funasr.models.decoder.sanm_decoder import DecoderLayerSANM, ParaformerSANMDecoder + + +class ContextualDecoderLayer(nn.Module): + def __init__( + self, + size, + self_attn, + src_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + ): + """Construct an DecoderLayer object.""" + super(ContextualDecoderLayer, self).__init__() + self.size = size + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(size) + if self_attn is not None: + self.norm2 = LayerNorm(size) + if src_attn is not None: + self.norm3 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear1 = nn.Linear(size + size, size) + self.concat_linear2 = nn.Linear(size + size, size) + + def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None,): + # tgt = self.dropout(tgt) + if isinstance(tgt, Tuple): + tgt, _ = tgt + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + tgt = self.feed_forward(tgt) + + x = tgt + if self.normalize_before: + tgt = self.norm2(tgt) + if self.training: + cache = None + x, cache = self.self_attn(tgt, tgt_mask, cache=cache) + x = residual + self.dropout(x) + x_self_attn = x + + residual = x + if self.normalize_before: + x = self.norm3(x) + x = self.src_attn(x, memory, memory_mask) + x_src_attn = x + + x = residual + self.dropout(x) + return x, tgt_mask, x_self_attn, x_src_attn + + +class ContexutalBiasDecoder(nn.Module): + def __init__( + self, + size, + src_attn, + dropout_rate, + normalize_before=True, + ): + """Construct an DecoderLayer object.""" + super(ContexutalBiasDecoder, self).__init__() + self.size = size + self.src_attn = src_attn + if src_attn is not None: + self.norm3 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.normalize_before = normalize_before + + def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None): + x = tgt + if self.src_attn is not None: + if self.normalize_before: + x = self.norm3(x) + x = self.dropout(self.src_attn(x, memory, memory_mask)) + return x, tgt_mask, memory, memory_mask, cache + + +class ContextualParaformerDecoder(ParaformerSANMDecoder): + """ + author: Speech Lab, Alibaba Group, China + Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition + https://arxiv.org/abs/2006.01713 + """ + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = "embed", + use_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + att_layer_num: int = 6, + kernel_size: int = 21, + sanm_shfit: int = 0, + ): + assert check_argument_types() + super().__init__( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + dropout_rate=dropout_rate, + positional_dropout_rate=positional_dropout_rate, + input_layer=input_layer, + use_output_layer=use_output_layer, + pos_enc_class=pos_enc_class, + normalize_before=normalize_before, + ) + + attention_dim = encoder_output_size + if 
input_layer == 'none': + self.embed = None + if input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(vocab_size, attention_dim), + # pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(vocab_size, attention_dim), + torch.nn.LayerNorm(attention_dim), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + else: + raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}") + + self.normalize_before = normalize_before + if self.normalize_before: + self.after_norm = LayerNorm(attention_dim) + if use_output_layer: + self.output_layer = torch.nn.Linear(attention_dim, vocab_size) + else: + self.output_layer = None + + self.att_layer_num = att_layer_num + self.num_blocks = num_blocks + if sanm_shfit is None: + sanm_shfit = (kernel_size - 1) // 2 + self.decoders = repeat( + att_layer_num - 1, + lambda lnum: DecoderLayerSANM( + attention_dim, + MultiHeadedAttentionSANMDecoder( + attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit + ), + MultiHeadedAttentionCrossAtt( + attention_heads, attention_dim, src_attention_dropout_rate + ), + PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + self.dropout = nn.Dropout(dropout_rate) + self.bias_decoder = ContexutalBiasDecoder( + size=attention_dim, + src_attn=MultiHeadedAttentionCrossAtt( + attention_heads, attention_dim, src_attention_dropout_rate + ), + dropout_rate=dropout_rate, + normalize_before=True, + ) + self.bias_output = torch.nn.Conv1d(attention_dim*2, attention_dim, 1, bias=False) + self.last_decoder = ContextualDecoderLayer( + attention_dim, + MultiHeadedAttentionSANMDecoder( + attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit + ), + MultiHeadedAttentionCrossAtt( + attention_heads, attention_dim, src_attention_dropout_rate + ), + PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ) + if num_blocks - att_layer_num <= 0: + self.decoders2 = None + else: + self.decoders2 = repeat( + num_blocks - att_layer_num, + lambda lnum: DecoderLayerSANM( + attention_dim, + MultiHeadedAttentionSANMDecoder( + attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0 + ), + None, + PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + + self.decoders3 = repeat( + 1, + lambda lnum: DecoderLayerSANM( + attention_dim, + None, + None, + PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + + def forward( + self, + hs_pad: torch.Tensor, + hlens: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + contextual_info: torch.Tensor, + return_hidden: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward decoder. 
+ + Args: + hs_pad: encoded memory, float32 (batch, maxlen_in, feat) + hlens: (batch) + ys_in_pad: + input token ids, int64 (batch, maxlen_out) + if input_layer == "embed" + input tensor (batch, maxlen_out, #mels) in the other cases + ys_in_lens: (batch) + Returns: + (tuple): tuple containing: + + x: decoded token score before softmax (batch, maxlen_out, token) + if use_output_layer is True, + olens: (batch, ) + """ + tgt = ys_in_pad + tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None] + + memory = hs_pad + memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :] + + x = tgt + x, tgt_mask, memory, memory_mask, _ = self.decoders( + x, tgt_mask, memory, memory_mask + ) + _, _, x_self_attn, x_src_attn = self.last_decoder( + x, tgt_mask, memory, memory_mask + ) + + # contextual paraformer related + contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0]) + contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :] + cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask) + + if self.bias_output is not None: + x = torch.cat([x_src_attn, cx], dim=2) + x = self.bias_output(x.transpose(1, 2)).transpose(1, 2) # 2D -> D + x = x_self_attn + self.dropout(x) + + if self.decoders2 is not None: + x, tgt_mask, memory, memory_mask, _ = self.decoders2( + x, tgt_mask, memory, memory_mask + ) + + x, tgt_mask, memory, memory_mask, _ = self.decoders3( + x, tgt_mask, memory, memory_mask + ) + if self.normalize_before: + x = self.after_norm(x) + olens = tgt_mask.sum(1) + if self.output_layer is not None and return_hidden is False: + x = self.output_layer(x) + return x, olens + + def gen_tf2torch_map_dict(self): + + tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch + tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf + map_dict_local = { + + ## decoder + # ffn + "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (1024,256),(1,256,1024) + "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (1024,),(1024,) + "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (1024,),(1024,) + "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (1024,),(1024,) + "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): + {"name": 
"{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (256,1024),(1,1024,256) + + # fsmn + "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format( + tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format( + tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format( + tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 2, 0), + }, # (256,1,31),(1,31,256,1) + # src att + "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (256,256),(1,256,256) + "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (1024,256),(1,256,1024) + "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (1024,),(1024,) + "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (256,256),(1,256,256) + "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + # dnn + "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch): + {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch): + {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): + {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf), + "squeeze": 
0, + "transpose": (1, 0), + }, # (1024,256),(1,256,1024) + "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): + {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (1024,),(1024,) + "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): + {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (1024,),(1024,) + "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): + {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (1024,),(1024,) + "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): + {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (256,1024),(1,1024,256) + + # embed_concat_ffn + "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch): + {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch): + {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): + {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (1024,256),(1,256,1024) + "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): + {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (1024,),(1024,) + "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): + {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (1024,),(1024,) + "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): + {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (1024,),(1024,) + "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): + {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (256,1024),(1,1024,256) + + # out norm + "{}.after_norm.weight".format(tensor_name_prefix_torch): + {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.after_norm.bias".format(tensor_name_prefix_torch): + {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + + # in embed + "{}.embed.0.weight".format(tensor_name_prefix_torch): + {"name": "{}/w_embs".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (4235,256),(4235,256) + + # out layer + "{}.output_layer.weight".format(tensor_name_prefix_torch): + {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)], + "squeeze": [None, None], + "transpose": [(1, 0), None], + }, # (4235,256),(256,4235) + "{}.output_layer.bias".format(tensor_name_prefix_torch): + {"name": 
["{}/dense/bias".format(tensor_name_prefix_tf), + "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"], + "squeeze": [None, None], + "transpose": [None, None], + }, # (4235,),(4235,) + + ## clas decoder + # src att + "{}.bias_decoder.norm3.weight".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.bias_decoder.norm3.bias".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.bias_decoder.src_attn.linear_q.weight".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (256,256),(1,256,256) + "{}.bias_decoder.src_attn.linear_q.bias".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.bias_decoder.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (1024,256),(1,256,1024) + "{}.bias_decoder.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (1024,),(1024,) + "{}.bias_decoder.src_attn.linear_out.weight".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (256,256),(1,256,256) + "{}.bias_decoder.src_attn.linear_out.bias".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + # dnn + "{}.bias_output.weight".format(tensor_name_prefix_torch): + {"name": "{}/decoder_fsmn_layer_15/conv1d/kernel".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": (2, 1, 0), + }, # (1024,256),(1,256,1024) + + } + return map_dict_local + + def convert_tf2torch(self, + var_dict_tf, + var_dict_torch, + ): + map_dict = self.gen_tf2torch_map_dict() + var_dict_torch_update = dict() + decoder_layeridx_sets = set() + for name in sorted(var_dict_torch.keys(), reverse=False): + names = name.split('.') + if names[0] == self.tf2torch_tensor_name_prefix_torch: + if names[1] == "decoders": + layeridx = int(names[2]) + name_q = name.replace(".{}.".format(layeridx), ".layeridx.") + layeridx_bias = 0 + layeridx += layeridx_bias + decoder_layeridx_sets.add(layeridx) + if name_q in map_dict.keys(): + name_v = map_dict[name_q]["name"] + name_tf = name_v.replace("layeridx", "{}".format(layeridx)) + data_tf = var_dict_tf[name_tf] + if map_dict[name_q]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) + if map_dict[name_q]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[ + name].size(), + data_tf.size()) + 
var_dict_torch_update[name] = data_tf + logging.info( + "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, + var_dict_tf[name_tf].shape)) + elif names[1] == "last_decoder": + layeridx = 15 + name_q = name.replace("last_decoder", "decoders.layeridx") + layeridx_bias = 0 + layeridx += layeridx_bias + decoder_layeridx_sets.add(layeridx) + if name_q in map_dict.keys(): + name_v = map_dict[name_q]["name"] + name_tf = name_v.replace("layeridx", "{}".format(layeridx)) + data_tf = var_dict_tf[name_tf] + if map_dict[name_q]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) + if map_dict[name_q]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[ + name].size(), + data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info( + "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, + var_dict_tf[name_tf].shape)) + + + elif names[1] == "decoders2": + layeridx = int(names[2]) + name_q = name.replace(".{}.".format(layeridx), ".layeridx.") + name_q = name_q.replace("decoders2", "decoders") + layeridx_bias = len(decoder_layeridx_sets) + + layeridx += layeridx_bias + if "decoders." in name: + decoder_layeridx_sets.add(layeridx) + if name_q in map_dict.keys(): + name_v = map_dict[name_q]["name"] + name_tf = name_v.replace("layeridx", "{}".format(layeridx)) + data_tf = var_dict_tf[name_tf] + if map_dict[name_q]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) + if map_dict[name_q]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[ + name].size(), + data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info( + "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, + var_dict_tf[name_tf].shape)) + + elif names[1] == "decoders3": + layeridx = int(names[2]) + name_q = name.replace(".{}.".format(layeridx), ".layeridx.") + + layeridx_bias = 0 + layeridx += layeridx_bias + if "decoders." 
in name: + decoder_layeridx_sets.add(layeridx) + if name_q in map_dict.keys(): + name_v = map_dict[name_q]["name"] + name_tf = name_v.replace("layeridx", "{}".format(layeridx)) + data_tf = var_dict_tf[name_tf] + if map_dict[name_q]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) + if map_dict[name_q]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[ + name].size(), + data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info( + "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, + var_dict_tf[name_tf].shape)) + elif names[1] == "bias_decoder": + name_q = name + + if name_q in map_dict.keys(): + name_v = map_dict[name_q]["name"] + name_tf = name_v + data_tf = var_dict_tf[name_tf] + if map_dict[name_q]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) + if map_dict[name_q]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[ + name].size(), + data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info( + "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, + var_dict_tf[name_tf].shape)) + + + elif names[1] == "embed" or names[1] == "output_layer" or names[1] == "bias_output": + name_tf = map_dict[name]["name"] + if isinstance(name_tf, list): + idx_list = 0 + if name_tf[idx_list] in var_dict_tf.keys(): + pass + else: + idx_list = 1 + data_tf = var_dict_tf[name_tf[idx_list]] + if map_dict[name]["squeeze"][idx_list] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list]) + if map_dict[name]["transpose"][idx_list] is not None: + data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[ + name].size(), + data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), + name_tf[idx_list], + var_dict_tf[name_tf[ + idx_list]].shape)) + + else: + data_tf = var_dict_tf[name_tf] + if map_dict[name]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"]) + if map_dict[name]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[ + name].size(), + data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info( + "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, + var_dict_tf[name_tf].shape)) + + elif names[1] == "after_norm": + name_tf = map_dict[name]["name"] + data_tf = var_dict_tf[name_tf] + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + var_dict_torch_update[name] = data_tf + logging.info( + "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, 
data_tf.size(), name_tf, + var_dict_tf[name_tf].shape)) + + elif names[1] == "embed_concat_ffn": + layeridx = int(names[2]) + name_q = name.replace(".{}.".format(layeridx), ".layeridx.") + + layeridx_bias = 0 + layeridx += layeridx_bias + if "decoders." in name: + decoder_layeridx_sets.add(layeridx) + if name_q in map_dict.keys(): + name_v = map_dict[name_q]["name"] + name_tf = name_v.replace("layeridx", "{}".format(layeridx)) + data_tf = var_dict_tf[name_tf] + if map_dict[name_q]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) + if map_dict[name_q]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[ + name].size(), + data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info( + "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, + var_dict_tf[name_tf].shape)) + + return var_dict_torch_update diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py index 759689629..5786bc46e 100644 --- a/funasr/models/e2e_asr_paraformer.py +++ b/funasr/models/e2e_asr_paraformer.py @@ -8,6 +8,8 @@ from typing import Tuple from typing import Union import torch +import random +import numpy as np from typeguard import check_argument_types from funasr.layers.abs_normalize import AbsNormalize @@ -24,7 +26,7 @@ from funasr.models.predictor.cif import mae_loss from funasr.models.preencoder.abs_preencoder import AbsPreEncoder from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.modules.add_sos_eos import add_sos_eos -from funasr.modules.nets_utils import make_pad_mask +from funasr.modules.nets_utils import make_pad_mask, pad_list from funasr.modules.nets_utils import th_accuracy from funasr.torch_utils.device_funcs import force_gatherable from funasr.train.abs_espnet_model import AbsESPnetModel @@ -824,7 +826,10 @@ class ParaformerBert(Paraformer): class BiCifParaformer(Paraformer): - """CTC-attention hybrid Encoder-Decoder model""" + """ + Paraformer model with an extra cif predictor + to conduct accurate timestamp prediction + """ def __init__( self, @@ -891,7 +896,7 @@ class BiCifParaformer(Paraformer): ) assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3" - def _calc_att_loss( + def _calc_pre2_loss( self, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, @@ -903,47 +908,12 @@ class BiCifParaformer(Paraformer): if self.predictor_bias == 1: _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_pad_lens = ys_pad_lens + self.predictor_bias - pre_acoustic_embeds, pre_token_length, _, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, - ignore_id=self.ignore_id) + _, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id) - # 0. 
sampler - decoder_out_1st = None - if self.sampling_ratio > 0.0: - if self.step_cur < 2: - logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)) - sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, - pre_acoustic_embeds) - else: - if self.step_cur < 2: - logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)) - sematic_embeds = pre_acoustic_embeds + # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length) + loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2) - # 1. Forward decoder - decoder_outs = self.decoder( - encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens - ) - decoder_out, _ = decoder_outs[0], decoder_outs[1] - - if decoder_out_1st is None: - decoder_out_1st = decoder_out - # 2. Compute attention loss - loss_att = self.criterion_att(decoder_out, ys_pad) - acc_att = th_accuracy( - decoder_out_1st.view(-1, self.vocab_size), - ys_pad, - ignore_label=self.ignore_id, - ) - loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length) - loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length2) - - # Compute cer/wer using attention-decoder - if self.training or self.error_calculator is None: - cer_att, wer_att = None, None - else: - ys_hat = decoder_out_1st.argmax(dim=-1) - cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) - - return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_pre2 + return loss_pre2 def calc_predictor(self, encoder_out, encoder_out_lens): @@ -956,11 +926,155 @@ class BiCifParaformer(Paraformer): def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num): encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( encoder_out.device) - ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out, None, encoder_out_mask, token_num=token_num, - ignore_id=self.ignore_id) - import pdb; pdb.set_trace() + ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out, + encoder_out_mask, + token_num) return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Frontend + Encoder + Decoder + Calc loss + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert ( + speech.shape[0] + == speech_lengths.shape[0] + == text.shape[0] + == text_lengths.shape[0] + ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) + batch_size = speech.shape[0] + self.step_cur += 1 + # for data-parallel + text = text[:, : text_lengths.max()] + speech = speech[:, :speech_lengths.max()] + + # 1. 
Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + + stats = dict() + + loss_pre2 = self._calc_pre2_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + loss = loss_pre2 + + stats["loss_pre2"] = loss_pre2.detach().cpu() + stats["loss"] = torch.clone(loss.detach()) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + +class ContextualParaformer(Paraformer): + """ + Paraformer model with contextual hotword + """ + + def __init__( + self, + vocab_size: int, + token_list: Union[Tuple[str, ...], List[str]], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], + preencoder: Optional[AbsPreEncoder], + encoder: AbsEncoder, + postencoder: Optional[AbsPostEncoder], + decoder: AbsDecoder, + ctc: CTC, + ctc_weight: float = 0.5, + interctc_weight: float = 0.0, + ignore_id: int = -1, + blank_id: int = 0, + sos: int = 1, + eos: int = 2, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + report_cer: bool = True, + report_wer: bool = True, + sym_space: str = "", + sym_blank: str = "", + extract_feats_in_collect_stats: bool = True, + predictor=None, + predictor_weight: float = 0.0, + predictor_bias: int = 0, + sampling_ratio: float = 0.2, + min_hw_length: int = 2, + max_hw_length: int = 4, + sample_rate: float = 0.6, + batch_rate: float = 0.5, + double_rate: float = -1.0, + target_buffer_length: int = -1, + inner_dim: int = 256, + bias_encoder_type: str = 'lstm', + label_bracket: bool = False, + ): + assert check_argument_types() + assert 0.0 <= ctc_weight <= 1.0, ctc_weight + assert 0.0 <= interctc_weight < 1.0, interctc_weight + + super().__init__( + vocab_size=vocab_size, + token_list=token_list, + frontend=frontend, + specaug=specaug, + normalize=normalize, + preencoder=preencoder, + encoder=encoder, + postencoder=postencoder, + decoder=decoder, + ctc=ctc, + ctc_weight=ctc_weight, + interctc_weight=interctc_weight, + ignore_id=ignore_id, + blank_id=blank_id, + sos=sos, + eos=eos, + lsm_weight=lsm_weight, + length_normalized_loss=length_normalized_loss, + report_cer=report_cer, + report_wer=report_wer, + sym_space=sym_space, + sym_blank=sym_blank, + extract_feats_in_collect_stats=extract_feats_in_collect_stats, + predictor=predictor, + predictor_weight=predictor_weight, + predictor_bias=predictor_bias, + sampling_ratio=sampling_ratio, + ) + + if bias_encoder_type == 'lstm': + logging.warning("enable bias encoder sampling and contextual training") + self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=0) + self.bias_embed = torch.nn.Embedding(vocab_size, inner_dim) + else: + logging.error("Unsupport bias encoder type") + + self.min_hw_length = min_hw_length + self.max_hw_length = max_hw_length + self.sample_rate = sample_rate + self.batch_rate = batch_rate + self.target_buffer_length = target_buffer_length + self.double_rate = double_rate + + if self.target_buffer_length > 0: + self.hotword_buffer = None + self.length_record = [] + self.current_buffer_length = 0 + def forward( self, speech: torch.Tensor, @@ -1038,17 +1152,17 @@ class BiCifParaformer(Paraformer): # 2b. Attention decoder branch if self.ctc_weight != 1.0: - loss_att, acc_att, cer_att, wer_att, loss_pre, loss_pre2 = self._calc_att_loss( + loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss( encoder_out, encoder_out_lens, text, text_lengths ) # 3. 
CTC-Att loss definition if self.ctc_weight == 0.0: - loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight + loss = loss_att + loss_pre * self.predictor_weight elif self.ctc_weight == 1.0: loss = loss_ctc else: - loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight + loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight # Collect Attn branch stats stats["loss_att"] = loss_att.detach() if loss_att is not None else None @@ -1056,10 +1170,292 @@ class BiCifParaformer(Paraformer): stats["cer"] = cer_att stats["wer"] = wer_att stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None - stats["loss_pre2"] = loss_pre2.detach().cpu() if loss_pre is not None else None stats["loss"] = torch.clone(loss.detach()) # force_gatherable: to-device and to-tensor if scalar for DataParallel loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) - return loss, stats, weight \ No newline at end of file + return loss, stats, weight + + def _sample_hot_word(self, ys_pad, ys_pad_lens): + hw_list = [torch.Tensor([0]).long().to(ys_pad.device)] + hw_lengths = [0] # this length is actually for indice, so -1 + for i, length in enumerate(ys_pad_lens): + if length < 2: + continue + if length > self.min_hw_length + self.max_hw_length + 2 and random.random() < self.double_rate: + # sample double hotword + _max_hw_length = min(self.max_hw_length, length // 2) + # first hotword + start1 = random.randint(0, length // 3) + end1 = random.randint(start1 + self.min_hw_length - 1, start1 + _max_hw_length - 1) + hw_tokens1 = ys_pad[i][start1:end1 + 1] + hw_lengths.append(len(hw_tokens1) - 1) + hw_list.append(hw_tokens1) + # second hotword + start2 = random.randint(end1 + 1, length - self.min_hw_length) + end2 = random.randint(min(length - 1, start2 + self.min_hw_length - 1), + min(length - 1, start2 + self.max_hw_length - 1)) + hw_tokens2 = ys_pad[i][start2:end2 + 1] + hw_lengths.append(len(hw_tokens2) - 1) + hw_list.append(hw_tokens2) + continue + if random.random() < self.sample_rate: + if length == 2: + hw_tokens = ys_pad[i][:2] + hw_lengths.append(1) + hw_list.append(hw_tokens) + else: + start = random.randint(0, length - self.min_hw_length) + end = random.randint(min(length - 1, start + self.min_hw_length - 1), + min(length - 1, start + self.max_hw_length - 1)) + 1 + # print(start, end) + hw_tokens = ys_pad[i][start:end] + hw_lengths.append(len(hw_tokens) - 1) + hw_list.append(hw_tokens) + # padding + hw_list_pad = pad_list(hw_list, 0) + hw_embed = self.decoder.embed(hw_list_pad) + hw_embed, (_, _) = self.bias_encoder(hw_embed) + _ind = np.arange(0, len(hw_list)).tolist() + # update self.hotword_buffer, throw a part if oversize + selected = hw_embed[_ind, hw_lengths] + if self.target_buffer_length > 0: + _b = selected.shape[0] + if self.hotword_buffer is None: + self.hotword_buffer = selected + self.length_record.append(selected.shape[0]) + self.current_buffer_length = _b + elif self.current_buffer_length + _b < self.target_buffer_length: + self.hotword_buffer = torch.cat([self.hotword_buffer.detach(), selected], dim=0) + self.current_buffer_length += _b + selected = self.hotword_buffer + else: + self.hotword_buffer = torch.cat([self.hotword_buffer.detach(), selected], dim=0) + random_throw = random.randint(self.target_buffer_length // 2, self.target_buffer_length) + 10 + self.hotword_buffer = self.hotword_buffer[-1 
* random_throw:] + selected = self.hotword_buffer + self.current_buffer_length = selected.shape[0] + return selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device) + + def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info): + + tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device) + ys_pad = ys_pad * tgt_mask[:, :, 0] + if self.share_embedding: + ys_pad_embed = self.decoder.output_layer.weight[ys_pad] + else: + ys_pad_embed = self.decoder.embed(ys_pad) + with torch.no_grad(): + decoder_outs = self.decoder( + encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info + ) + decoder_out, _ = decoder_outs[0], decoder_outs[1] + pred_tokens = decoder_out.argmax(-1) + nonpad_positions = ys_pad.ne(self.ignore_id) + seq_lens = (nonpad_positions).sum(1) + same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1) + input_mask = torch.ones_like(nonpad_positions) + bsz, seq_len = ys_pad.size() + for li in range(bsz): + target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long() + if target_num > 0: + input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0) + input_mask = input_mask.eq(1) + input_mask = input_mask.masked_fill(~nonpad_positions, False) + input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device) + + sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill( + input_mask_expand_dim, 0) + return sematic_embeds * tgt_mask, decoder_out * tgt_mask + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( + encoder_out.device) + if self.predictor_bias == 1: + _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + ys_pad_lens = ys_pad_lens + self.predictor_bias + pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad, + encoder_out_mask, + ignore_id=self.ignore_id) + + # sample hot word + contextual_info = self._sample_hot_word(ys_pad, ys_pad_lens) + + # 0. sampler + decoder_out_1st = None + if self.sampling_ratio > 0.0: + if self.step_cur < 2: + logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)) + sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, + pre_acoustic_embeds, contextual_info) + else: + if self.step_cur < 2: + logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)) + sematic_embeds = pre_acoustic_embeds + + # 1. Forward decoder + decoder_outs = self.decoder( + encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info + ) + decoder_out, _ = decoder_outs[0], decoder_outs[1] + + if decoder_out_1st is None: + decoder_out_1st = decoder_out + # 2. 
Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_pad) + acc_att = th_accuracy( + decoder_out_1st.view(-1, self.vocab_size), + ys_pad, + ignore_label=self.ignore_id, + ) + loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length) + + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out_1st.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) + + return loss_att, acc_att, cer_att, wer_att, loss_pre + + def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None): + if hw_list is None: + # default hotword list + hw_list = [torch.Tensor([self.sos]).long().to(encoder_out.device)] # empty hotword list + hw_list_pad = pad_list(hw_list, 0) + hw_embed = self.bias_embed(hw_list_pad) + _, (h_n, _) = self.bias_encoder(hw_embed) + contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1) + else: + hw_lengths = [len(i) for i in hw_list] + hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device) + hw_embed = self.bias_embed(hw_list_pad) + hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True, + enforce_sorted=False) + _, (h_n, _) = self.bias_encoder(hw_embed) + # hw_embed, _ = torch.nn.utils.rnn.pad_packed_sequence(hw_embed, batch_first=True) + contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1) + decoder_outs = self.decoder( + encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info + ) + decoder_out = decoder_outs[0] + decoder_out = torch.log_softmax(decoder_out, dim=-1) + return decoder_out, ys_pad_lens + + def gen_clas_tf2torch_map_dict(self): + tensor_name_prefix_torch = "bias_encoder" + tensor_name_prefix_tf = "seq2seq/clas_charrnn" + + tensor_name_prefix_torch_emb = "bias_embed" + tensor_name_prefix_tf_emb = "seq2seq" + + map_dict_local = { + # in lstm + "{}.weight_ih_l0".format(tensor_name_prefix_torch): + {"name": "{}/rnn/lstm_cell/kernel".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": (1, 0), + "slice": (0, 512), + "unit_k": 512, + }, # (1024, 2048),(2048,512) + "{}.weight_hh_l0".format(tensor_name_prefix_torch): + {"name": "{}/rnn/lstm_cell/kernel".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": (1, 0), + "slice": (512, 1024), + "unit_k": 512, + }, # (1024, 2048),(2048,512) + "{}.bias_ih_l0".format(tensor_name_prefix_torch): + {"name": "{}/rnn/lstm_cell/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + "scale": 0.5, + "unit_b": 512, + }, # (2048,),(2048,) + "{}.bias_hh_l0".format(tensor_name_prefix_torch): + {"name": "{}/rnn/lstm_cell/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + "scale": 0.5, + "unit_b": 512, + }, # (2048,),(2048,) + + # in embed + "{}.weight".format(tensor_name_prefix_torch_emb): + {"name": "{}/contextual_encoder/w_char_embs".format(tensor_name_prefix_tf_emb), + "squeeze": None, + "transpose": None, + }, # (4235,256),(4235,256) + } + return map_dict_local + + def clas_convert_tf2torch(self, + var_dict_tf, + var_dict_torch): + map_dict = self.gen_clas_tf2torch_map_dict() + var_dict_torch_update = dict() + for name in sorted(var_dict_torch.keys(), reverse=False): + names = name.split('.') + if names[0] == "bias_encoder": + name_q = name + if name_q in map_dict.keys(): + name_v = map_dict[name_q]["name"] + name_tf = 
name_v + data_tf = var_dict_tf[name_tf] + if map_dict[name_q].get("unit_k") is not None: + dim = map_dict[name_q]["unit_k"] + i = data_tf[:, 0:dim].copy() + f = data_tf[:, dim:2 * dim].copy() + o = data_tf[:, 2 * dim:3 * dim].copy() + g = data_tf[:, 3 * dim:4 * dim].copy() + data_tf = np.concatenate([i, o, f, g], axis=1) + if map_dict[name_q].get("unit_b") is not None: + dim = map_dict[name_q]["unit_b"] + i = data_tf[0:dim].copy() + f = data_tf[dim:2 * dim].copy() + o = data_tf[2 * dim:3 * dim].copy() + g = data_tf[3 * dim:4 * dim].copy() + data_tf = np.concatenate([i, o, f, g], axis=0) + if map_dict[name_q]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) + if map_dict[name_q].get("slice") is not None: + data_tf = data_tf[map_dict[name_q]["slice"][0]:map_dict[name_q]["slice"][1]] + if map_dict[name_q].get("scale") is not None: + data_tf = data_tf * map_dict[name_q]["scale"] + if map_dict[name_q]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[ + name].size(), + data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info( + "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, + var_dict_tf[name_tf].shape)) + elif names[0] == "bias_embed": + name_tf = map_dict[name]["name"] + data_tf = var_dict_tf[name_tf] + if map_dict[name]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"]) + if map_dict[name]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[ + name].size(), + data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info( + "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, + var_dict_tf[name_tf].shape)) + + return var_dict_torch_update \ No newline at end of file diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py index c34759d0d..561537323 100644 --- a/funasr/models/predictor/cif.py +++ b/funasr/models/predictor/cif.py @@ -544,9 +544,8 @@ class CifPredictorV3(nn.Module): token_num_int = torch.max(token_num).type(torch.int32).item() acoustic_embeds = acoustic_embeds[:, :token_num_int, :] return acoustic_embeds, token_num, alphas, cif_peak, token_num2 - - def get_upsample_timestamp(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None, - target_label_length=None, token_num=None): + + def get_upsample_timestamp(self, hidden, mask=None, token_num=None): h = hidden b = hidden.shape[0] context = h.transpose(1, 2) diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py index 1b7f152a8..e62a74820 100644 --- a/funasr/tasks/asr.py +++ b/funasr/tasks/asr.py @@ -37,8 +37,9 @@ from funasr.models.decoder.transformer_decoder import ( ) from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN from funasr.models.decoder.transformer_decoder import TransformerDecoder +from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder from funasr.models.e2e_asr import ESPnetASRModel -from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer +from 
funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
 from funasr.models.e2e_uni_asr import UniASR
 from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.encoder.conformer_encoder import ConformerEncoder
@@ -117,6 +118,7 @@ model_choices = ClassChoices(
         paraformer=Paraformer,
         paraformer_bert=ParaformerBert,
         bicif_paraformer=BiCifParaformer,
+        contextual_paraformer=ContextualParaformer,
     ),
     type_check=AbsESPnetModel,
     default="asr",
@@ -177,6 +179,7 @@ decoder_choices = ClassChoices(
         fsmn_scama_opt=FsmnDecoderSCAMAOpt,
         paraformer_decoder_sanm=ParaformerSANMDecoder,
         paraformer_decoder_san=ParaformerDecoderSAN,
+        contextual_paraformer_decoder=ContextualParaformerDecoder,
     ),
     type_check=AbsDecoder,
     default="rnn",
@@ -1098,5 +1101,8 @@ class ASRTaskParaformer(ASRTask):
         # decoder
         var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
         var_dict_torch_update.update(var_dict_torch_update_local)
+        # bias_encoder
+        var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
 
         return var_dict_torch_update
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 3afaa4049..33d1255cc 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -86,14 +86,51 @@ def time_stamp_lfr6(alphas: torch.Tensor, speech_lengths: torch.Tensor, raw_text
     else:
         return time_stamp_list
-
-def time_stamp_lfr6_advance(tst: List, text: str):
-    # advanced timestamp prediction for BiCIF_Paraformer using upsampled alphas
-    ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = tst
-    if text.endswith('</s>'):
-        text = text[:-4]
+def time_stamp_lfr6_pl(us_alphas, us_cif_peak, char_list, begin_time=0.0, end_time=None):
+    START_END_THRESHOLD = 5
+    TIME_RATE = 10.0 * 6 / 1000 / 3  # 3x upsampled LFR-6 frames -> 20 ms per upsampled frame
+    if len(us_alphas.shape) == 3:
+        alphas, cif_peak = us_alphas[0], us_cif_peak[0]  # support inference batch_size=1 only
     else:
-        text = text[:-1]
-        logging.warning("found text does not end with </s>")
-    assert int(ds_alphas.sum() + 1e-4) - 1 == len(text)
-
+        alphas, cif_peak = us_alphas, us_cif_peak
+    num_frames = cif_peak.shape[0]
+    if char_list[-1] == '</s>':
+        char_list = char_list[:-1]
+    # char_list = [i for i in text]
+    timestamp_list = []
+    # for a bicif model trained on large data, cif2 actually fires when a character starts,
+    # so treat the frames between two peaks as the duration of the former token
+    fire_place = torch.where(cif_peak > 1.0 - 1e-4)[0].cpu().numpy() - 1.5
+    num_peak = len(fire_place)
+    assert num_peak == len(char_list) + 1  # number of peaks is supposed to be number of tokens + 1
+    # begin silence
+    if fire_place[0] > START_END_THRESHOLD:
+        char_list.insert(0, '<sil>')
+        timestamp_list.append([0.0, fire_place[0] * TIME_RATE])
+    # token timestamps
+    for i in range(len(fire_place) - 1):
+        # the peak is always a little ahead of the start time
+        # timestamp_list.append([(fire_place[i] - 1.2) * TIME_RATE, fire_place[i + 1] * TIME_RATE])
+        timestamp_list.append([fire_place[i] * TIME_RATE, fire_place[i + 1] * TIME_RATE])
+    # tail token and end silence: if the trailing zero-weight frames last long,
+    # split the remaining duration between the last token and a trailing silence
+    if num_frames - fire_place[-1] > START_END_THRESHOLD:
+        _end = (num_frames + fire_place[-1]) / 2
+        timestamp_list[-1][1] = _end * TIME_RATE
+        timestamp_list.append([_end * TIME_RATE, num_frames * TIME_RATE])
+        char_list.append("<sil>")
+    else:
+        timestamp_list[-1][1] = num_frames * TIME_RATE
+    if begin_time:  # add offset time in model with vad
+        for i in range(len(timestamp_list)):
+            timestamp_list[i][0] = timestamp_list[i][0] + begin_time / 1000.0
+            timestamp_list[i][1] = timestamp_list[i][1] + begin_time / 1000.0
+    res_txt = ""
+    for char, timestamp in zip(char_list, timestamp_list):
+        res_txt += "{} {} {};".format(char, timestamp[0], timestamp[1])
+    res = []
+    for char, timestamp in zip(char_list, timestamp_list):
+        if char != '<sil>':
+            res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
+    return res
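
The sketches below are illustrative and not part of the patch. First, the hotword plumbing added to Speech2Text: --hotword may point at a file with one hotword per line, or carry the hotwords themselves separated by spaces; either way each hotword is split into characters and mapped to token ids, and a trailing [1] entry (the default sos id, matching the [self.sos] "empty hotword" used in ContextualParaformer.cal_decoder_with_predictor) serves as a no-bias option. A minimal sketch with a toy vocabulary standing in for the real TokenIDConverter; the ids and tokens are made up:

# toy converter standing in for Speech2Text's self.converter
class ToyConverter:
    def __init__(self, tokens):
        self.t2i = {t: i for i, t in enumerate(tokens)}

    def tokens2ids(self, tokens):
        return [self.t2i[t] for t in tokens]

converter = ToyConverter(["<blank>", "<s>", "</s>", "魔", "搭", "达", "摩", "院"])

# equivalent of: --hotword "魔搭 达摩院"   (or --hotword hotwords.txt, one hotword per line)
hotword_list = []
for hw in "魔搭 达摩院".strip().split():
    hotword_list.append(converter.tokens2ids(list(hw)))
hotword_list.append([1])   # trailing <s>/sos entry, as in the patch (exact placement assumed)

print(hotword_list)        # [[3, 4], [5, 6, 7], [1]]

This list has the shape consumed by the hw_list argument of cal_decoder_with_predictor.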
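At decode time, cal_decoder_with_predictor turns such an id list into the contextual_info tensor via the embedding plus single-layer LSTM bias encoder configured in ContextualParaformer.__init__. A sketch with assumed sizes; torch.nn.utils.rnn.pad_sequence stands in here for FunASR's pad_list helper:

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

vocab_size, inner_dim = 8404, 256             # hypothetical sizes
bias_embed = torch.nn.Embedding(vocab_size, inner_dim)
bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True)

hw_list = [[3, 4], [5, 6, 7], [1]]            # hotword ids plus the trailing sos entry
hw_lengths = [len(hw) for hw in hw_list]
hw_pad = pad_sequence([torch.tensor(hw) for hw in hw_list], batch_first=True)
packed = pack_padded_sequence(bias_embed(hw_pad), hw_lengths, batch_first=True, enforce_sorted=False)
_, (h_n, _) = bias_encoder(packed)            # last hidden state summarises each hotword
contextual_info = h_n.squeeze(0)              # (num_hotwords, inner_dim)
print(contextual_info.shape)                  # torch.Size([3, 256])

In the patch, this (num_hotwords, inner_dim) matrix is then repeated along the batch dimension (h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1)) so every utterance attends over the same hotword inventory.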
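Inside ContextualParaformerDecoder.forward, the hotword information enters through one extra cross-attention (ContexutalBiasDecoder) whose output is fused with the ordinary acoustic cross-attention before the remaining FSMN blocks. A shape-only sketch of that fusion step, with hypothetical sizes and random tensors in place of the real attention outputs:

import torch

attention_dim, batch, length = 256, 2, 5      # hypothetical sizes
bias_output = torch.nn.Conv1d(attention_dim * 2, attention_dim, 1, bias=False)

x_self_attn = torch.randn(batch, length, attention_dim)    # self-attention stream
x_src_attn = torch.randn(batch, length, attention_dim)     # cross-attention over acoustic memory
cx = torch.randn(batch, length, attention_dim)             # cross-attention over hotword embeddings

x = torch.cat([x_src_attn, cx], dim=2)                     # (batch, length, 2 * attention_dim)
x = bias_output(x.transpose(1, 2)).transpose(1, 2)         # 1x1 conv projects back to attention_dim
x = x_self_attn + x                                        # residual add (dropout omitted in this sketch)
print(x.shape)                                             # torch.Size([2, 5, 256])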
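Finally, the timestamp path: BiCifParaformer exposes 3x-upsampled CIF alphas and peaks through calc_predictor_timestamp, and time_stamp_lfr6_pl converts peak positions to seconds with TIME_RATE = 10 ms * 6 (LFR stacking) / 3 (upsampling) = 20 ms per upsampled frame, treating the span between consecutive peaks as the duration of the earlier token. A tiny worked example with made-up peak positions:

TIME_RATE = 10.0 * 6 / 1000 / 3          # 0.02 s per upsampled frame
fire_place = [3.5, 25.5, 60.5]           # hypothetical peak positions (len(tokens) + 1 peaks)
char_list = ["你", "好"]
timestamp_list = [[fire_place[i] * TIME_RATE, fire_place[i + 1] * TIME_RATE]
                  for i in range(len(fire_place) - 1)]
print(list(zip(char_list, timestamp_list)))
# roughly [('你', [0.07, 0.51]), ('好', [0.51, 1.21])]; in the patch, begin_time (ms) from
# the VAD segment is added on top and the result is finally reported in milliseconds.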