From a030ff0f85fd6b1cc2a1d443d2fcfb11ccb1aa8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Wed, 29 Mar 2023 21:15:55 +0800 Subject: [PATCH] export --- funasr/export/models/__init__.py | 4 + funasr/export/models/encoder/sanm_encoder.py | 99 +++++++++++++ .../export/models/target_delay_transformer.py | 132 +++++++++--------- .../export/models/vad_realtime_transformer.py | 79 +++++++++++ 4 files changed, 248 insertions(+), 66 deletions(-) create mode 100644 funasr/export/models/vad_realtime_transformer.py diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py index a34133841..62ee72354 100644 --- a/funasr/export/models/__init__.py +++ b/funasr/export/models/__init__.py @@ -6,6 +6,8 @@ from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export from funasr.punctuation.target_delay_transformer import TargetDelayTransformer from funasr.export.models.target_delay_transformer import TargetDelayTransformer as TargetDelayTransformer_export from funasr.punctuation.espnet_model import ESPnetPunctuationModel +from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer +from funasr.export.models.vad_realtime_transformer import VadRealtimeTransformer as VadRealtimeTransformer_export def get_model(model, export_config=None): if isinstance(model, BiCifParaformer): @@ -17,5 +19,7 @@ def get_model(model, export_config=None): elif isinstance(model, ESPnetPunctuationModel): if isinstance(model.punc_model, TargetDelayTransformer): return TargetDelayTransformer_export(model.punc_model, **export_config) + elif isinstance(model.punc_model, VadRealtimeTransformer): + return VadRealtimeTransformer_export(model.punc_model, **export_config) else: raise "Funasr does not support the given model type currently." diff --git a/funasr/export/models/encoder/sanm_encoder.py b/funasr/export/models/encoder/sanm_encoder.py index 8a5053870..3b7b4143f 100644 --- a/funasr/export/models/encoder/sanm_encoder.py +++ b/funasr/export/models/encoder/sanm_encoder.py @@ -107,3 +107,102 @@ class SANMEncoder(nn.Module): } } + + +class SANMVadEncoder(nn.Module): + def __init__( + self, + model, + max_seq_len=512, + feats_dim=560, + model_name='encoder', + onnx: bool = True, + ): + super().__init__() + self.embed = model.embed + self.model = model + self.feats_dim = feats_dim + self._output_size = model._output_size + + if onnx: + self.make_pad_mask = MakePadMask(max_seq_len, flip=False) + else: + self.make_pad_mask = sequence_mask(max_seq_len, flip=False) + + if hasattr(model, 'encoders0'): + for i, d in enumerate(self.model.encoders0): + if isinstance(d.self_attn, MultiHeadedAttentionSANM): + d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn) + if isinstance(d.feed_forward, PositionwiseFeedForward): + d.feed_forward = PositionwiseFeedForward_export(d.feed_forward) + self.model.encoders0[i] = EncoderLayerSANM_export(d) + + for i, d in enumerate(self.model.encoders): + if isinstance(d.self_attn, MultiHeadedAttentionSANM): + d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn) + if isinstance(d.feed_forward, PositionwiseFeedForward): + d.feed_forward = PositionwiseFeedForward_export(d.feed_forward) + self.model.encoders[i] = EncoderLayerSANM_export(d) + + self.model_name = model_name + self.num_heads = model.encoders[0].self_attn.h + self.hidden_size = model.encoders[0].self_attn.linear_out.out_features + + def prepare_mask(self, mask): + mask_3d_btd = mask[:, :, None] + if len(mask.shape) == 2: + mask_4d_bhlt = 1 - mask[:, None, None, :] + elif len(mask.shape) == 3: + mask_4d_bhlt = 1 - mask[:, None, :] + mask_4d_bhlt = mask_4d_bhlt * -10000.0 + + return mask_3d_btd, mask_4d_bhlt + + def forward(self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + ): + speech = speech * self._output_size ** 0.5 + mask = self.make_pad_mask(speech_lengths) + mask = self.prepare_mask(mask) + if self.embed is None: + xs_pad = speech + else: + xs_pad = self.embed(speech) + + encoder_outs = self.model.encoders0(xs_pad, mask) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + + encoder_outs = self.model.encoders(xs_pad, mask) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + + xs_pad = self.model.after_norm(xs_pad) + + return xs_pad, speech_lengths + + def get_output_size(self): + return self.model.encoders[0].size + + def get_dummy_inputs(self): + feats = torch.randn(1, 100, self.feats_dim) + return (feats) + + def get_input_names(self): + return ['feats'] + + def get_output_names(self): + return ['encoder_out', 'encoder_out_lens', 'predictor_weight'] + + def get_dynamic_axes(self): + return { + 'feats': { + 1: 'feats_length' + }, + 'encoder_out': { + 1: 'enc_out_length' + }, + 'predictor_weight': { + 1: 'pre_out_length' + } + + } diff --git a/funasr/export/models/target_delay_transformer.py b/funasr/export/models/target_delay_transformer.py index 0a2586c93..fd90835c9 100644 --- a/funasr/export/models/target_delay_transformer.py +++ b/funasr/export/models/target_delay_transformer.py @@ -28,7 +28,7 @@ class TargetDelayTransformer(nn.Module): onnx = kwargs["onnx"] self.embed = model.embed self.decoder = model.decoder - self.model = model + # self.model = model self.feats_dim = self.embed.embedding_dim self.num_embeddings = self.embed.num_embeddings self.model_name = model_name @@ -46,71 +46,71 @@ class TargetDelayTransformer(nn.Module): from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export from funasr.punctuation.abs_model import AbsPunctuation - class TargetDelayTransformer(nn.Module): - - def __init__( - self, - model, - max_seq_len=512, - model_name='punc_model', - **kwargs, - ): - super().__init__() - onnx = False - if "onnx" in kwargs: - onnx = kwargs["onnx"] - self.embed = model.embed - self.decoder = model.decoder - self.model = model - self.feats_dim = self.embed.embedding_dim - self.num_embeddings = self.embed.num_embeddings - self.model_name = model_name - - if isinstance(model.encoder, SANMEncoder): - self.encoder = SANMEncoder_export(model.encoder, onnx=onnx) - else: - assert False, "Only support samn encode." - - def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]: - """Compute loss value from buffer sequences. - - Args: - input (torch.Tensor): Input ids. (batch, len) - hidden (torch.Tensor): Target ids. (batch, len) - - """ - x = self.embed(input) - # mask = self._target_mask(input) - h, _ = self.encoder(x, text_lengths) - y = self.decoder(h) - return y - - def get_dummy_inputs(self): - length = 120 - text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)) - text_lengths = torch.tensor([length - 20, length], dtype=torch.int32) - return (text_indexes, text_lengths) - - def get_input_names(self): - return ['input', 'text_lengths'] - - def get_output_names(self): - return ['logits'] - - def get_dynamic_axes(self): - return { - 'input': { - 0: 'batch_size', - 1: 'feats_length' - }, - 'text_lengths': { - 0: 'batch_size', - }, - 'logits': { - 0: 'batch_size', - 1: 'logits_length' - }, - } + # class TargetDelayTransformer(nn.Module): + # + # def __init__( + # self, + # model, + # max_seq_len=512, + # model_name='punc_model', + # **kwargs, + # ): + # super().__init__() + # onnx = False + # if "onnx" in kwargs: + # onnx = kwargs["onnx"] + # self.embed = model.embed + # self.decoder = model.decoder + # self.model = model + # self.feats_dim = self.embed.embedding_dim + # self.num_embeddings = self.embed.num_embeddings + # self.model_name = model_name + # + # if isinstance(model.encoder, SANMEncoder): + # self.encoder = SANMEncoder_export(model.encoder, onnx=onnx) + # else: + # assert False, "Only support samn encode." + # + # def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]: + # """Compute loss value from buffer sequences. + # + # Args: + # input (torch.Tensor): Input ids. (batch, len) + # hidden (torch.Tensor): Target ids. (batch, len) + # + # """ + # x = self.embed(input) + # # mask = self._target_mask(input) + # h, _ = self.encoder(x, text_lengths) + # y = self.decoder(h) + # return y + # + # def get_dummy_inputs(self): + # length = 120 + # text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)) + # text_lengths = torch.tensor([length - 20, length], dtype=torch.int32) + # return (text_indexes, text_lengths) + # + # def get_input_names(self): + # return ['input', 'text_lengths'] + # + # def get_output_names(self): + # return ['logits'] + # + # def get_dynamic_axes(self): + # return { + # 'input': { + # 0: 'batch_size', + # 1: 'feats_length' + # }, + # 'text_lengths': { + # 0: 'batch_size', + # }, + # 'logits': { + # 0: 'batch_size', + # 1: 'logits_length' + # }, + # } if isinstance(model.encoder, SANMEncoder): self.encoder = SANMEncoder_export(model.encoder, onnx=onnx) diff --git a/funasr/export/models/vad_realtime_transformer.py b/funasr/export/models/vad_realtime_transformer.py new file mode 100644 index 000000000..44583d853 --- /dev/null +++ b/funasr/export/models/vad_realtime_transformer.py @@ -0,0 +1,79 @@ +from typing import Any +from typing import List +from typing import Tuple + +import torch +import torch.nn as nn + +from funasr.modules.embedding import SinusoidalPositionEncoder +from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder +from funasr.punctuation.abs_model import AbsPunctuation +from funasr.punctuation.sanm_encoder import SANMVadEncoder +from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export + +class VadRealtimeTransformer(AbsPunctuation): + + def __init__( + self, + model, + max_seq_len=512, + model_name='punc_model', + **kwargs, + ): + super().__init__() + + + self.embed = model.embed + if isinstance(model.encoder, SANMVadEncoder): + self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx) + else: + assert False, "Only support samn encode." + # self.encoder = model.encoder + self.decoder = model.decoder + + + + def forward(self, input: torch.Tensor, text_lengths: torch.Tensor, + vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]: + """Compute loss value from buffer sequences. + + Args: + input (torch.Tensor): Input ids. (batch, len) + hidden (torch.Tensor): Target ids. (batch, len) + + """ + x = self.embed(input) + # mask = self._target_mask(input) + h, _, _ = self.encoder(x, text_lengths, vad_indexes) + y = self.decoder(h) + return y + + def with_vad(self): + return True + + def get_dummy_inputs(self): + length = 120 + text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)) + text_lengths = torch.tensor([length-20, length], dtype=torch.int32) + return (text_indexes, text_lengths) + + def get_input_names(self): + return ['input', 'text_lengths'] + + def get_output_names(self): + return ['logits'] + + def get_dynamic_axes(self): + return { + 'input': { + 0: 'batch_size', + 1: 'feats_length' + }, + 'text_lengths': { + 0: 'batch_size', + }, + 'logits': { + 0: 'batch_size', + 1: 'logits_length' + }, + }