diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py index 27a65af9b..0012377a5 100644 --- a/funasr/export/models/__init__.py +++ b/funasr/export/models/__init__.py @@ -1,10 +1,13 @@ -from funasr.models.e2e_asr_paraformer import Paraformer +from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export +from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export from funasr.models.e2e_uni_asr import UniASR -def get_model(model, export_config=None): - if isinstance(model, Paraformer): +def get_model(model, export_config=None): + if isinstance(model, BiCifParaformer): + return BiCifParaformer_export(model, **export_config) + elif isinstance(model, Paraformer): return Paraformer_export(model, **export_config) else: - raise "The model is not exist!" \ No newline at end of file + raise "Funasr does not support the given model type currently." \ No newline at end of file diff --git a/funasr/export/models/decoder/transformer_decoder.py b/funasr/export/models/decoder/transformer_decoder.py new file mode 100644 index 000000000..d70a3c750 --- /dev/null +++ b/funasr/export/models/decoder/transformer_decoder.py @@ -0,0 +1,143 @@ +import os +from funasr.export import models + +import torch +import torch.nn as nn + + +from funasr.export.utils.torch_function import MakePadMask +from funasr.export.utils.torch_function import sequence_mask + +from funasr.modules.attention import MultiHeadedAttentionSANMDecoder +from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export +from funasr.modules.attention import MultiHeadedAttentionCrossAtt, MultiHeadedAttention +from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export +from funasr.export.models.modules.multihead_att import OnnxMultiHeadedAttention +from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM +from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export +from funasr.export.models.modules.decoder_layer import DecoderLayer as DecoderLayer_export + + +class ParaformerDecoderSAN(nn.Module): + def __init__(self, model, + max_seq_len=512, + model_name='decoder', + onnx: bool = True,): + super().__init__() + # self.embed = model.embed #Embedding(model.embed, max_seq_len) + self.model = model + if onnx: + self.make_pad_mask = MakePadMask(max_seq_len, flip=False) + else: + self.make_pad_mask = sequence_mask(max_seq_len, flip=False) + + for i, d in enumerate(self.model.decoders): + if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM): + d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward) + if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder): + d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn) + # if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt): + # d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn) + if isinstance(d.src_attn, MultiHeadedAttention): + d.src_attn = OnnxMultiHeadedAttention(d.src_attn) + self.model.decoders[i] = DecoderLayer_export(d) + + self.output_layer = model.output_layer + self.after_norm = model.after_norm + self.model_name = model_name + + + def prepare_mask(self, mask): + mask_3d_btd = mask[:, :, None] + if len(mask.shape) == 2: + mask_4d_bhlt = 1 - mask[:, None, None, :] + elif len(mask.shape) == 3: + mask_4d_bhlt = 1 - mask[:, None, :] + mask_4d_bhlt = mask_4d_bhlt * -10000.0 + + return mask_3d_btd, mask_4d_bhlt + + def forward( + self, + hs_pad: torch.Tensor, + hlens: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + ): + + tgt = ys_in_pad + tgt_mask = self.make_pad_mask(ys_in_lens) + tgt_mask, _ = self.prepare_mask(tgt_mask) + # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None] + + memory = hs_pad + memory_mask = self.make_pad_mask(hlens) + _, memory_mask = self.prepare_mask(memory_mask) + # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :] + + x = tgt + x, tgt_mask, memory, memory_mask = self.model.decoders( + x, tgt_mask, memory, memory_mask + ) + x = self.after_norm(x) + x = self.output_layer(x) + + return x, ys_in_lens + + + def get_dummy_inputs(self, enc_size): + tgt = torch.LongTensor([0]).unsqueeze(0) + memory = torch.randn(1, 100, enc_size) + pre_acoustic_embeds = torch.randn(1, 1, enc_size) + cache_num = len(self.model.decoders) + len(self.model.decoders2) + cache = [ + torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size)) + for _ in range(cache_num) + ] + return (tgt, memory, pre_acoustic_embeds, cache) + + def is_optimizable(self): + return True + + def get_input_names(self): + cache_num = len(self.model.decoders) + len(self.model.decoders2) + return ['tgt', 'memory', 'pre_acoustic_embeds'] \ + + ['cache_%d' % i for i in range(cache_num)] + + def get_output_names(self): + cache_num = len(self.model.decoders) + len(self.model.decoders2) + return ['y'] \ + + ['out_cache_%d' % i for i in range(cache_num)] + + def get_dynamic_axes(self): + ret = { + 'tgt': { + 0: 'tgt_batch', + 1: 'tgt_length' + }, + 'memory': { + 0: 'memory_batch', + 1: 'memory_length' + }, + 'pre_acoustic_embeds': { + 0: 'acoustic_embeds_batch', + 1: 'acoustic_embeds_length', + } + } + cache_num = len(self.model.decoders) + len(self.model.decoders2) + ret.update({ + 'cache_%d' % d: { + 0: 'cache_%d_batch' % d, + 2: 'cache_%d_length' % d + } + for d in range(cache_num) + }) + return ret + + def get_model_config(self, path): + return { + "dec_type": "XformerDecoder", + "model_path": os.path.join(path, f'{self.model_name}.onnx'), + "n_layers": len(self.model.decoders) + len(self.model.decoders2), + "odim": self.model.decoders[0].size + } \ No newline at end of file diff --git a/funasr/export/models/e2e_asr_paraformer.py b/funasr/export/models/e2e_asr_paraformer.py index 5424a0a94..0db61e0c5 100644 --- a/funasr/export/models/e2e_asr_paraformer.py +++ b/funasr/export/models/e2e_asr_paraformer.py @@ -1,17 +1,21 @@ import logging - - import torch import torch.nn as nn from funasr.export.utils.torch_function import MakePadMask from funasr.export.utils.torch_function import sequence_mask from funasr.models.encoder.sanm_encoder import SANMEncoder +from funasr.models.encoder.conformer_encoder import ConformerEncoder from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export -from funasr.models.predictor.cif import CifPredictorV2 +from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export +from funasr.models.predictor.cif import CifPredictorV2, CifPredictorV3 from funasr.export.models.predictor.cif import CifPredictorV2 as CifPredictorV2_export +from funasr.export.models.predictor.cif import CifPredictorV3 as CifPredictorV3_export from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder +from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export +from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export + class Paraformer(nn.Module): """ @@ -34,10 +38,14 @@ class Paraformer(nn.Module): onnx = kwargs["onnx"] if isinstance(model.encoder, SANMEncoder): self.encoder = SANMEncoder_export(model.encoder, onnx=onnx) + elif isinstance(model.encoder, ConformerEncoder): + self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx) if isinstance(model.predictor, CifPredictorV2): self.predictor = CifPredictorV2_export(model.predictor) if isinstance(model.decoder, ParaformerSANMDecoder): self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx) + elif isinstance(model.decoder, ParaformerDecoderSAN): + self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx) self.feats_dim = feats_dim self.model_name = model_name @@ -99,4 +107,113 @@ class Paraformer(nn.Module): 0: 'batch_size', 1: 'logits_length' }, + } + + +class BiCifParaformer(nn.Module): + """ + Author: Speech Lab, Alibaba Group, China + Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition + https://arxiv.org/abs/2206.08317 + """ + + def __init__( + self, + model, + max_seq_len=512, + feats_dim=560, + model_name='model', + **kwargs, + ): + super().__init__() + onnx = False + if "onnx" in kwargs: + onnx = kwargs["onnx"] + if isinstance(model.encoder, SANMEncoder): + self.encoder = SANMEncoder_export(model.encoder, onnx=onnx) + elif isinstance(model.encoder, ConformerEncoder): + self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx) + else: + logging.warning("Unsupported encoder type to export.") + if isinstance(model.predictor, CifPredictorV3): + self.predictor = CifPredictorV3_export(model.predictor) + else: + logging.warning("Wrong predictor type to export.") + if isinstance(model.decoder, ParaformerSANMDecoder): + self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx) + elif isinstance(model.decoder, ParaformerDecoderSAN): + self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx) + else: + logging.warning("Unsupported decoder type to export.") + + self.feats_dim = feats_dim + self.model_name = model_name + + if onnx: + self.make_pad_mask = MakePadMask(max_seq_len, flip=False) + else: + self.make_pad_mask = sequence_mask(max_seq_len, flip=False) + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + ): + # a. To device + batch = {"speech": speech, "speech_lengths": speech_lengths} + # batch = to_device(batch, device=self.device) + + enc, enc_len = self.encoder(**batch) + mask = self.make_pad_mask(enc_len)[:, None, :] + pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask) + pre_token_length = pre_token_length.round().type(torch.int32) + + decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length) + decoder_out = torch.log_softmax(decoder_out, dim=-1) + + # get predicted timestamps + us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length) + + return decoder_out, pre_token_length, us_alphas, us_cif_peak + + def get_dummy_inputs(self): + speech = torch.randn(2, 30, self.feats_dim) + speech_lengths = torch.tensor([6, 30], dtype=torch.int32) + return (speech, speech_lengths) + + def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"): + import numpy as np + fbank = np.loadtxt(txt_file) + fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32) + speech = torch.from_numpy(fbank[None, :, :].astype(np.float32)) + speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32)) + return (speech, speech_lengths) + + def get_input_names(self): + return ['speech', 'speech_lengths'] + + def get_output_names(self): + return ['logits', 'token_num', 'us_alphas', 'us_cif_peak'] + + def get_dynamic_axes(self): + return { + 'speech': { + 0: 'batch_size', + 1: 'feats_length' + }, + 'speech_lengths': { + 0: 'batch_size', + }, + 'logits': { + 0: 'batch_size', + 1: 'logits_length' + }, + 'us_alphas': { + 0: 'batch_size', + 1: 'alphas_length' + }, + 'us_cif_peak': { + 0: 'batch_size', + 1: 'alphas_length' + }, } \ No newline at end of file diff --git a/funasr/export/models/encoder/conformer_encoder.py b/funasr/export/models/encoder/conformer_encoder.py new file mode 100644 index 000000000..9f2257462 --- /dev/null +++ b/funasr/export/models/encoder/conformer_encoder.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn + +from funasr.export.utils.torch_function import MakePadMask +from funasr.export.utils.torch_function import sequence_mask +from funasr.modules.attention import MultiHeadedAttentionSANM +from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANM as MultiHeadedAttentionSANM_export +from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export +from funasr.export.models.modules.encoder_layer import EncoderLayerConformer as EncoderLayerConformer_export +from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward +from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export +from funasr.export.models.encoder.sanm_encoder import SANMEncoder +from funasr.modules.attention import RelPositionMultiHeadedAttention +# from funasr.export.models.modules.multihead_att import RelPositionMultiHeadedAttention as RelPositionMultiHeadedAttention_export +from funasr.export.models.modules.multihead_att import OnnxRelPosMultiHeadedAttention as RelPositionMultiHeadedAttention_export + + +class ConformerEncoder(nn.Module): + def __init__( + self, + model, + max_seq_len=512, + feats_dim=560, + model_name='encoder', + onnx: bool = True, + ): + super().__init__() + self.embed = model.embed + self.model = model + self.feats_dim = feats_dim + self._output_size = model._output_size + + if onnx: + self.make_pad_mask = MakePadMask(max_seq_len, flip=False) + else: + self.make_pad_mask = sequence_mask(max_seq_len, flip=False) + + for i, d in enumerate(self.model.encoders): + if isinstance(d.self_attn, MultiHeadedAttentionSANM): + d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn) + if isinstance(d.self_attn, RelPositionMultiHeadedAttention): + d.self_attn = RelPositionMultiHeadedAttention_export(d.self_attn) + if isinstance(d.feed_forward, PositionwiseFeedForward): + d.feed_forward = PositionwiseFeedForward_export(d.feed_forward) + self.model.encoders[i] = EncoderLayerConformer_export(d) + + self.model_name = model_name + self.num_heads = model.encoders[0].self_attn.h + self.hidden_size = model.encoders[0].self_attn.linear_out.out_features + + + def prepare_mask(self, mask): + if len(mask.shape) == 2: + mask = 1 - mask[:, None, None, :] + elif len(mask.shape) == 3: + mask = 1 - mask[:, None, :] + + return mask * -10000.0 + + def forward(self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + ): + speech = speech * self._output_size ** 0.5 + mask = self.make_pad_mask(speech_lengths) + mask = self.prepare_mask(mask) + if self.embed is None: + xs_pad = speech + else: + xs_pad = self.embed(speech) + + encoder_outs = self.model.encoders(xs_pad, mask) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + + if isinstance(xs_pad, tuple): + xs_pad = xs_pad[0] + xs_pad = self.model.after_norm(xs_pad) + + return xs_pad, speech_lengths + + def get_output_size(self): + return self.model.encoders[0].size + + def get_dummy_inputs(self): + feats = torch.randn(1, 100, self.feats_dim) + return (feats) + + def get_input_names(self): + return ['feats'] + + def get_output_names(self): + return ['encoder_out', 'encoder_out_lens', 'predictor_weight'] + + def get_dynamic_axes(self): + return { + 'feats': { + 1: 'feats_length' + }, + 'encoder_out': { + 1: 'enc_out_length' + }, + 'predictor_weight':{ + 1: 'pre_out_length' + } + + } diff --git a/funasr/export/models/modules/decoder_layer.py b/funasr/export/models/modules/decoder_layer.py index bc306b1fd..f5394523d 100644 --- a/funasr/export/models/modules/decoder_layer.py +++ b/funasr/export/models/modules/decoder_layer.py @@ -41,3 +41,30 @@ class DecoderLayerSANM(nn.Module): return x, tgt_mask, memory, memory_mask, cache + +class DecoderLayer(nn.Module): + def __init__(self, model): + super().__init__() + self.self_attn = model.self_attn + self.src_attn = model.src_attn + self.feed_forward = model.feed_forward + self.norm1 = model.norm1 + self.norm2 = model.norm2 + self.norm3 = model.norm3 + + def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None): + residual = tgt + tgt_q = tgt + tgt_q_mask = tgt_mask + x = residual + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask) + + residual = x + x = self.norm2(x) + + x = residual + self.src_attn(x, memory, memory, memory_mask) + + residual = x + x = self.norm3(x) + x = residual + self.feed_forward(x) + + return x, tgt_mask, memory, memory_mask diff --git a/funasr/export/models/modules/encoder_layer.py b/funasr/export/models/modules/encoder_layer.py index 800a4f784..622b109d3 100644 --- a/funasr/export/models/modules/encoder_layer.py +++ b/funasr/export/models/modules/encoder_layer.py @@ -34,4 +34,58 @@ class EncoderLayerSANM(nn.Module): return x, mask +class EncoderLayerConformer(nn.Module): + def __init__( + self, + model, + ): + """Construct an EncoderLayer object.""" + super().__init__() + self.self_attn = model.self_attn + self.feed_forward = model.feed_forward + self.feed_forward_macaron = model.feed_forward_macaron + self.conv_module = model.conv_module + self.norm_ff = model.norm_ff + self.norm_mha = model.norm_mha + self.norm_ff_macaron = model.norm_ff_macaron + self.norm_conv = model.norm_conv + self.norm_final = model.norm_final + self.size = model.size + def forward(self, x, mask): + if isinstance(x, tuple): + x, pos_emb = x[0], x[1] + else: + x, pos_emb = x, None + + if self.feed_forward_macaron is not None: + residual = x + x = self.norm_ff_macaron(x) + x = residual + self.feed_forward_macaron(x) + + residual = x + x = self.norm_mha(x) + + x_q = x + + if pos_emb is not None: + x_att = self.self_attn(x_q, x, x, pos_emb, mask) + else: + x_att = self.self_attn(x_q, x, x, mask) + x = residual + x_att + + if self.conv_module is not None: + residual = x + x = self.norm_conv(x) + x = residual + self.conv_module(x) + + residual = x + x = self.norm_ff(x) + x = residual + self.feed_forward(x) + + x = self.norm_final(x) + + if pos_emb is not None: + return (x, pos_emb), mask + + return x, mask diff --git a/funasr/export/models/modules/multihead_att.py b/funasr/export/models/modules/multihead_att.py index 377b979d4..7d685f588 100644 --- a/funasr/export/models/modules/multihead_att.py +++ b/funasr/export/models/modules/multihead_att.py @@ -4,6 +4,7 @@ import math import torch import torch.nn as nn + class MultiHeadedAttentionSANM(nn.Module): def __init__(self, model): super().__init__() @@ -32,7 +33,6 @@ class MultiHeadedAttentionSANM(nn.Module): return x.permute(0, 2, 1, 3) def forward_qkv(self, x): - q_k_v = self.linear_q_k_v(x) q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1) q_h = self.transpose_for_scores(q) @@ -41,7 +41,6 @@ class MultiHeadedAttentionSANM(nn.Module): return q_h, k_h, v_h, v def forward_fsmn(self, inputs, mask): - # b, t, d = inputs.size() # mask = torch.reshape(mask, (b, -1, 1)) inputs = inputs * mask @@ -53,7 +52,6 @@ class MultiHeadedAttentionSANM(nn.Module): x = x * mask return x - def forward_attention(self, value, scores, mask): scores = scores + mask @@ -65,6 +63,7 @@ class MultiHeadedAttentionSANM(nn.Module): context_layer = context_layer.view(new_context_layer_shape) return self.linear_out(context_layer) # (batch, time1, d_model) + class MultiHeadedAttentionSANMDecoder(nn.Module): def __init__(self, model): super().__init__() @@ -74,7 +73,6 @@ class MultiHeadedAttentionSANMDecoder(nn.Module): self.attn = None def forward(self, inputs, mask, cache=None): - # b, t, d = inputs.size() # mask = torch.reshape(mask, (b, -1, 1)) inputs = inputs * mask @@ -92,6 +90,7 @@ class MultiHeadedAttentionSANMDecoder(nn.Module): x = x * mask return x, cache + class MultiHeadedAttentionCrossAtt(nn.Module): def __init__(self, model): super().__init__() @@ -133,3 +132,104 @@ class MultiHeadedAttentionCrossAtt(nn.Module): new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) return self.linear_out(context_layer) # (batch, time1, d_model) + + +class OnnxMultiHeadedAttention(nn.Module): + def __init__(self, model): + super().__init__() + self.d_k = model.d_k + self.h = model.h + self.linear_q = model.linear_q + self.linear_k = model.linear_k + self.linear_v = model.linear_v + self.linear_out = model.linear_out + self.attn = None + self.all_head_size = self.h * self.d_k + + def forward(self, query, key, value, mask): + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.h, self.d_k) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward_qkv(self, query, key, value): + q = self.linear_q(query) + k = self.linear_k(key) + v = self.linear_v(value) + q = self.transpose_for_scores(q) + k = self.transpose_for_scores(k) + v = self.transpose_for_scores(v) + return q, k, v + + def forward_attention(self, value, scores, mask): + scores = scores + mask + + self.attn = torch.softmax(scores, dim=-1) + context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + return self.linear_out(context_layer) # (batch, time1, d_model) + + +class OnnxRelPosMultiHeadedAttention(OnnxMultiHeadedAttention): + def __init__(self, model): + super().__init__(model) + self.linear_pos = model.linear_pos + self.pos_bias_u = model.pos_bias_u + self.pos_bias_v = model.pos_bias_v + + def forward(self, query, key, value, pos_emb, mask): + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + p = self.transpose_for_scores(self.linear_pos(pos_emb)) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time1) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k + ) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask) + + def rel_shift(self, x): + zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x)[ + :, :, :, : x.size(-1) // 2 + 1 + ] # only keep the positions from 0 to time2 + return x + + def forward_attention(self, value, scores, mask): + scores = scores + mask + + self.attn = torch.softmax(scores, dim=-1) + context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + return self.linear_out(context_layer) # (batch, time1, d_model) + \ No newline at end of file diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py index 6f4601dfa..5ea4a34e2 100644 --- a/funasr/export/models/predictor/cif.py +++ b/funasr/export/models/predictor/cif.py @@ -1,9 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- + import torch from torch import nn -import logging -import numpy as np def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None): @@ -175,3 +174,115 @@ def cif(hidden, alphas, threshold: float): max_label_len = frame_len frame_fires = frame_fires[:, :max_label_len, :] return frame_fires, fires + + +class CifPredictorV3(nn.Module): + def __init__(self, model): + super().__init__() + + self.pad = model.pad + self.cif_conv1d = model.cif_conv1d + self.cif_output = model.cif_output + self.threshold = model.threshold + self.smooth_factor = model.smooth_factor + self.noise_threshold = model.noise_threshold + self.tail_threshold = model.tail_threshold + + self.upsample_times = model.upsample_times + self.upsample_cnn = model.upsample_cnn + self.blstm = model.blstm + self.cif_output2 = model.cif_output2 + self.smooth_factor2 = model.smooth_factor2 + self.noise_threshold2 = model.noise_threshold2 + + def forward(self, hidden: torch.Tensor, + mask: torch.Tensor, + ): + h = hidden + context = h.transpose(1, 2) + queries = self.pad(context) + output = torch.relu(self.cif_conv1d(queries)) + output = output.transpose(1, 2) + + output = self.cif_output(output) + alphas = torch.sigmoid(output) + alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold) + mask = mask.transpose(-1, -2).float() + alphas = alphas * mask + alphas = alphas.squeeze(-1) + token_num = alphas.sum(-1) + + mask = mask.squeeze(-1) + hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask) + acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold) + + return acoustic_embeds, token_num, alphas, cif_peak + + def get_upsample_timestmap(self, hidden, mask=None, token_num=None): + h = hidden + b = hidden.shape[0] + context = h.transpose(1, 2) + + # generate alphas2 + _output = context + output2 = self.upsample_cnn(_output) + output2 = output2.transpose(1, 2) + output2, (_, _) = self.blstm(output2) + alphas2 = torch.sigmoid(self.cif_output2(output2)) + alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2) + + mask = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1) + mask = mask.unsqueeze(-1) + alphas2 = alphas2 * mask + alphas2 = alphas2.squeeze(-1) + _token_num = alphas2.sum(-1) + alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1)) + # upsampled alphas and cif_peak + us_alphas = alphas2 + us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4) + return us_alphas, us_cif_peak + + def tail_process_fn(self, hidden, alphas, token_num=None, mask=None): + b, t, d = hidden.size() + tail_threshold = self.tail_threshold + + zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device) + ones_t = torch.ones_like(zeros_t) + + mask_1 = torch.cat([mask, zeros_t], dim=1) + mask_2 = torch.cat([ones_t, mask], dim=1) + mask = mask_2 - mask_1 + tail_threshold = mask * tail_threshold + alphas = torch.cat([alphas, zeros_t], dim=1) + alphas = torch.add(alphas, tail_threshold) + + zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device) + hidden = torch.cat([hidden, zeros], dim=1) + token_num = alphas.sum(dim=-1) + token_num_floor = torch.floor(token_num) + + return hidden, alphas, token_num_floor + + +@torch.jit.script +def cif_wo_hidden(alphas, threshold: float): + batch_size, len_time = alphas.size() + + # loop varss + integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=alphas.device) + # intermediate vars along time + list_fires = [] + + for t in range(len_time): + alpha = alphas[:, t] + + integrate += alpha + list_fires.append(integrate) + + fire_place = integrate >= threshold + integrate = torch.where(fire_place, + integrate - torch.ones([batch_size], device=alphas.device), + integrate) + + fires = torch.stack(list_fires, 1) + return fires \ No newline at end of file diff --git a/funasr/runtime/python/onnxruntime/demo.py b/funasr/runtime/python/onnxruntime/demo.py index 9c7f2f450..b4a03f3d0 100644 --- a/funasr/runtime/python/onnxruntime/demo.py +++ b/funasr/runtime/python/onnxruntime/demo.py @@ -1,8 +1,10 @@ from rapid_paraformer import Paraformer +from rapid_paraformer import BiCifParaformer -model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" -model = Paraformer(model_dir, batch_size=1) +model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch" +# model = Paraformer(model_dir, batch_size=1) +model = BiCifParaformer(model_dir, batch_size=1) wav_path = ['/Users/shixian/code/funasr2/export/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/example/asr_example.wav'] diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/__init__.py b/funasr/runtime/python/onnxruntime/rapid_paraformer/__init__.py index f1b5c29b8..64e0a16f2 100644 --- a/funasr/runtime/python/onnxruntime/rapid_paraformer/__init__.py +++ b/funasr/runtime/python/onnxruntime/rapid_paraformer/__init__.py @@ -2,3 +2,4 @@ # @Author: SWHL # @Contact: liekkaskono@163.com from .paraformer_onnx import Paraformer +from .paraformer_onnx import BiCifParaformer diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py b/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py index a786ef0c6..d77bcf724 100644 --- a/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py +++ b/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py @@ -5,6 +5,7 @@ import os.path from pathlib import Path from typing import List, Union, Tuple +import copy import librosa import numpy as np @@ -13,6 +14,7 @@ from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError, read_yaml) from .utils.postprocess_utils import sentence_postprocess from .utils.frontend import WavFrontend +from funasr.utils.timestamp_tools import time_stamp_lfr6_pl logging = get_logger() @@ -134,8 +136,67 @@ class Paraformer(): # Change integer-ids to tokens token = self.converter.ids2tokens(token_int) - token = token[:valid_token_num-1] + # token = token[:valid_token_num-1] texts = sentence_postprocess(token) text = texts[0] # text = self.tokenizer.tokens2text(token) return text + + +class BiCifParaformer(Paraformer): + def infer(self, feats: np.ndarray, + feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + am_scores, token_nums, us_alphas, us_cif_peak = self.ort_infer([feats, feats_len]) + return am_scores, token_nums, us_alphas, us_cif_peak + def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List: + waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq) + waveform_nums = len(waveform_list) + + asr_res = [] + for beg_idx in range(0, waveform_nums, self.batch_size): + res = {} + end_idx = min(waveform_nums, beg_idx + self.batch_size) + feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx]) + am_scores, valid_token_lens, us_alphas, us_cif_peak = self.infer(feats, feats_len) + + try: + am_scores, valid_token_lens, us_alphas, us_cif_peak = self.infer(feats, feats_len) + except ONNXRuntimeError: + #logging.warning(traceback.format_exc()) + logging.warning("input wav is silence or noise") + preds = [''] + else: + token = self.decode(am_scores, valid_token_lens) + timestamp = time_stamp_lfr6_pl(us_alphas, us_cif_peak, copy.copy(token[0]), log=False) + texts = sentence_postprocess(token[0], timestamp) + # texts = sentence_postprocess(token[0]) + text = texts[0] + res['text'] = text + res['timestamp'] = timestamp + asr_res.append(res) + + return asr_res + + def decode_one(self, + am_score: np.ndarray, + valid_token_num: int) -> List[str]: + yseq = am_score.argmax(axis=-1) + score = am_score.max(axis=-1) + score = np.sum(score, axis=-1) + + # pad with mask tokens to ensure compatibility with sos/eos tokens + # asr_model.sos:1 asr_model.eos:2 + yseq = np.array([1] + yseq.tolist() + [2]) + hyp = Hypothesis(yseq=yseq, score=score) + + # remove sos/eos and get results + last_pos = -1 + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x not in (0, 2), token_int)) + + # Change integer-ids to tokens + token = self.converter.ids2tokens(token_int) + # token = token[:valid_token_num-1] + return token \ No newline at end of file diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py index f6a6e989b..b82c74ad7 100644 --- a/funasr/utils/timestamp_tools.py +++ b/funasr/utils/timestamp_tools.py @@ -4,6 +4,7 @@ import logging import numpy as np from typing import Any, List, Tuple, Union + def time_stamp_lfr6_pl(us_alphas, us_cif_peak, char_list, begin_time=0.0, end_time=None): if not len(char_list): return []