diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py index 4ac0456b9..f81ff6454 100644 --- a/funasr/export/models/__init__.py +++ b/funasr/export/models/__init__.py @@ -4,10 +4,10 @@ from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifPara from funasr.models.e2e_vad import E2EVadModel from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export from funasr.models.target_delay_transformer import TargetDelayTransformer -from funasr.export.models.target_delay_transformer import TargetDelayTransformer as TargetDelayTransformer_export +from funasr.export.models.target_delay_transformer import CT_Transformer as CT_Transformer_export from funasr.train.abs_model import PunctuationModel from funasr.models.vad_realtime_transformer import VadRealtimeTransformer -from funasr.export.models.vad_realtime_transformer import VadRealtimeTransformer as VadRealtimeTransformer_export +from funasr.export.models.target_delay_transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export def get_model(model, export_config=None): if isinstance(model, BiCifParaformer): @@ -18,8 +18,8 @@ def get_model(model, export_config=None): return E2EVadModel_export(model, **export_config) elif isinstance(model, PunctuationModel): if isinstance(model.punc_model, TargetDelayTransformer): - return TargetDelayTransformer_export(model.punc_model, **export_config) + return CT_Transformer_export(model.punc_model, **export_config) elif isinstance(model.punc_model, VadRealtimeTransformer): - return VadRealtimeTransformer_export(model.punc_model, **export_config) + return CT_Transformer_VadRealtime_export(model.punc_model, **export_config) else: raise "Funasr does not support the given model type currently." diff --git a/funasr/export/models/target_delay_transformer.py b/funasr/export/models/target_delay_transformer.py index bfe3ec423..2780d8275 100644 --- a/funasr/export/models/target_delay_transformer.py +++ b/funasr/export/models/target_delay_transformer.py @@ -3,7 +3,12 @@ from typing import Tuple import torch import torch.nn as nn -class TargetDelayTransformer(nn.Module): +from funasr.models.encoder.sanm_encoder import SANMEncoder +from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export +from funasr.models.encoder.sanm_encoder import SANMVadEncoder +from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export + +class CT_Transformer(nn.Module): def __init__( self, @@ -23,16 +28,12 @@ class TargetDelayTransformer(nn.Module): self.num_embeddings = self.embed.num_embeddings self.model_name = model_name - # from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder - from funasr.models.encoder.sanm_encoder import SANMEncoder - from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export - if isinstance(model.encoder, SANMEncoder): self.encoder = SANMEncoder_export(model.encoder, onnx=onnx) else: assert False, "Only support samn encode." - def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]: + def forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]: """Compute loss value from buffer sequences. Args: @@ -40,7 +41,7 @@ class TargetDelayTransformer(nn.Module): hidden (torch.Tensor): Target ids. (batch, len) """ - x = self.embed(input) + x = self.embed(inputs) # mask = self._target_mask(input) h, _ = self.encoder(x, text_lengths) y = self.decoder(h) @@ -53,14 +54,14 @@ class TargetDelayTransformer(nn.Module): return (text_indexes, text_lengths) def get_input_names(self): - return ['input', 'text_lengths'] + return ['inputs', 'text_lengths'] def get_output_names(self): return ['logits'] def get_dynamic_axes(self): return { - 'input': { + 'inputs': { 0: 'batch_size', 1: 'feats_length' }, @@ -73,3 +74,81 @@ class TargetDelayTransformer(nn.Module): }, } + +class CT_Transformer_VadRealtime(nn.Module): + + def __init__( + self, + model, + max_seq_len=512, + model_name='punc_model', + **kwargs, + ): + super().__init__() + onnx = False + if "onnx" in kwargs: + onnx = kwargs["onnx"] + + self.embed = model.embed + if isinstance(model.encoder, SANMVadEncoder): + self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx) + else: + assert False, "Only support samn encode." + self.decoder = model.decoder + self.model_name = model_name + + + + def forward(self, inputs: torch.Tensor, + text_lengths: torch.Tensor, + vad_indexes: torch.Tensor, + sub_masks: torch.Tensor, + ) -> Tuple[torch.Tensor, None]: + """Compute loss value from buffer sequences. + + Args: + input (torch.Tensor): Input ids. (batch, len) + hidden (torch.Tensor): Target ids. (batch, len) + + """ + x = self.embed(inputs) + # mask = self._target_mask(input) + h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks) + y = self.decoder(h) + return y + + def with_vad(self): + return True + + def get_dummy_inputs(self): + length = 120 + text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)) + text_lengths = torch.tensor([length], dtype=torch.int32) + vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :] + sub_masks = torch.ones(length, length, dtype=torch.float32) + sub_masks = torch.tril(sub_masks).type(torch.float32) + return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :]) + + def get_input_names(self): + return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks'] + + def get_output_names(self): + return ['logits'] + + def get_dynamic_axes(self): + return { + 'inputs': { + 1: 'feats_length' + }, + 'vad_masks': { + 2: 'feats_length1', + 3: 'feats_length2' + }, + 'sub_masks': { + 2: 'feats_length1', + 3: 'feats_length2' + }, + 'logits': { + 1: 'logits_length' + }, + } diff --git a/funasr/export/models/vad_realtime_transformer.py b/funasr/export/models/vad_realtime_transformer.py deleted file mode 100644 index 24a8e7247..000000000 --- a/funasr/export/models/vad_realtime_transformer.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import Tuple - -import torch -import torch.nn as nn - -from funasr.models.encoder.sanm_encoder import SANMVadEncoder -from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export - -class VadRealtimeTransformer(nn.Module): - - def __init__( - self, - model, - max_seq_len=512, - model_name='punc_model', - **kwargs, - ): - super().__init__() - onnx = False - if "onnx" in kwargs: - onnx = kwargs["onnx"] - - self.embed = model.embed - if isinstance(model.encoder, SANMVadEncoder): - self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx) - else: - assert False, "Only support samn encode." - # self.encoder = model.encoder - self.decoder = model.decoder - self.model_name = model_name - - - - def forward(self, input: torch.Tensor, - text_lengths: torch.Tensor, - vad_indexes: torch.Tensor, - sub_masks: torch.Tensor, - ) -> Tuple[torch.Tensor, None]: - """Compute loss value from buffer sequences. - - Args: - input (torch.Tensor): Input ids. (batch, len) - hidden (torch.Tensor): Target ids. (batch, len) - - """ - x = self.embed(input) - # mask = self._target_mask(input) - h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks) - y = self.decoder(h) - return y - - def with_vad(self): - return True - - # def get_dummy_inputs(self): - # length = 120 - # text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)) - # text_lengths = torch.tensor([length], dtype=torch.int32) - # vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :] - # sub_masks = torch.ones(length, length, dtype=torch.float32) - # sub_masks = torch.tril(sub_masks).type(torch.float32) - # return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :]) - - def get_dummy_inputs(self, txt_dir=None): - from funasr.modules.mask import vad_mask - length = 10 - text_indexes = torch.tensor([[266757, 266757, 266757, 266757, 266757, 266757, 266757, 266757, 266757, 266757]], dtype=torch.int32) - text_lengths = torch.tensor([length], dtype=torch.int32) - vad_masks = vad_mask(10, 2, dtype=torch.float32)[None, None, :, :] - sub_masks = torch.ones(length, length, dtype=torch.float32) - sub_masks = torch.tril(sub_masks).type(torch.float32) - return (text_indexes, text_lengths, vad_masks, sub_masks[None, None, :, :]) - - def get_input_names(self): - return ['input', 'text_lengths', 'vad_masks', 'sub_masks'] - - def get_output_names(self): - return ['logits'] - - def get_dynamic_axes(self): - return { - 'input': { - 1: 'feats_length' - }, - 'vad_masks': { - 2: 'feats_length1', - 3: 'feats_length2' - }, - 'sub_masks': { - 2: 'feats_length1', - 3: 'feats_length2' - }, - 'logits': { - 1: 'logits_length' - }, - } diff --git a/funasr/export/test/test_onnx_punc.py b/funasr/export/test/test_onnx_punc.py index 62689a904..39f85f457 100644 --- a/funasr/export/test/test_onnx_punc.py +++ b/funasr/export/test/test_onnx_punc.py @@ -9,7 +9,7 @@ if __name__ == '__main__': output_name = [nd.name for nd in sess.get_outputs()] def _get_feed_dict(text_length): - return {'input': np.ones((1, text_length), dtype=np.int64), 'text_lengths': np.array([text_length,], dtype=np.int32)} + return {'inputs': np.ones((1, text_length), dtype=np.int64), 'text_lengths': np.array([text_length,], dtype=np.int32)} def _run(feed_dict): output = sess.run(output_name, input_feed=feed_dict) diff --git a/funasr/export/test/test_onnx_punc_vadrealtime.py b/funasr/export/test/test_onnx_punc_vadrealtime.py index 54f85f194..86be026dc 100644 --- a/funasr/export/test/test_onnx_punc_vadrealtime.py +++ b/funasr/export/test/test_onnx_punc_vadrealtime.py @@ -9,9 +9,9 @@ if __name__ == '__main__': output_name = [nd.name for nd in sess.get_outputs()] def _get_feed_dict(text_length): - return {'input': np.ones((1, text_length), dtype=np.int64), + return {'inputs': np.ones((1, text_length), dtype=np.int64), 'text_lengths': np.array([text_length,], dtype=np.int32), - 'vad_mask': np.ones((1, 1, text_length, text_length), dtype=np.float32), + 'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32), 'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32) }