This commit is contained in:
游雁 2023-03-29 21:15:55 +08:00
parent bf918fe311
commit a030ff0f85
4 changed files with 248 additions and 66 deletions

View File

@ -6,6 +6,8 @@ from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
from funasr.export.models.target_delay_transformer import TargetDelayTransformer as TargetDelayTransformer_export
from funasr.punctuation.espnet_model import ESPnetPunctuationModel
from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
from funasr.export.models.vad_realtime_transformer import VadRealtimeTransformer as VadRealtimeTransformer_export
def get_model(model, export_config=None):
if isinstance(model, BiCifParaformer):
@ -17,5 +19,7 @@ def get_model(model, export_config=None):
elif isinstance(model, ESPnetPunctuationModel):
if isinstance(model.punc_model, TargetDelayTransformer):
return TargetDelayTransformer_export(model.punc_model, **export_config)
elif isinstance(model.punc_model, VadRealtimeTransformer):
return VadRealtimeTransformer_export(model.punc_model, **export_config)
else:
raise "Funasr does not support the given model type currently."

View File

@ -107,3 +107,102 @@ class SANMEncoder(nn.Module):
}
}
class SANMVadEncoder(nn.Module):
def __init__(
self,
model,
max_seq_len=512,
feats_dim=560,
model_name='encoder',
onnx: bool = True,
):
super().__init__()
self.embed = model.embed
self.model = model
self.feats_dim = feats_dim
self._output_size = model._output_size
if onnx:
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        # Swap attention/FFN submodules for their ONNX-export counterparts
        if hasattr(model, 'encoders0'):
for i, d in enumerate(self.model.encoders0):
if isinstance(d.self_attn, MultiHeadedAttentionSANM):
d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
if isinstance(d.feed_forward, PositionwiseFeedForward):
d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
self.model.encoders0[i] = EncoderLayerSANM_export(d)
for i, d in enumerate(self.model.encoders):
if isinstance(d.self_attn, MultiHeadedAttentionSANM):
d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
if isinstance(d.feed_forward, PositionwiseFeedForward):
d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
self.model.encoders[i] = EncoderLayerSANM_export(d)
self.model_name = model_name
self.num_heads = model.encoders[0].self_attn.h
self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
    def prepare_mask(self, mask):
        # (B, T) pad mask -> (B, T, 1) multiplicative mask over feature dims
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            # (B, T) -> (B, 1, 1, T): broadcast over heads and query positions
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            # (B, L, T) -> (B, 1, L, T)
            mask_4d_bhlt = 1 - mask[:, None, :]
        # Additive bias: masked positions get -1e4 so softmax sends them to ~0
        mask_4d_bhlt = mask_4d_bhlt * -10000.0
        return mask_3d_btd, mask_4d_bhlt
    def forward(self,
                speech: torch.Tensor,
                speech_lengths: torch.Tensor,
                ):
        # Scale inputs by sqrt(d_model) before the encoder stack
        speech = speech * self._output_size ** 0.5
        mask = self.make_pad_mask(speech_lengths)
        mask = self.prepare_mask(mask)
        if self.embed is None:
            xs_pad = speech
        else:
            xs_pad = self.embed(speech)
        # encoders0 only exists on some SANM variants; mirror the guard in __init__
        if hasattr(self.model, 'encoders0'):
            encoder_outs = self.model.encoders0(xs_pad, mask)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        encoder_outs = self.model.encoders(xs_pad, mask)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        xs_pad = self.model.after_norm(xs_pad)
        return xs_pad, speech_lengths
def get_output_size(self):
return self.model.encoders[0].size
    def get_dummy_inputs(self):
        feats = torch.randn(1, 100, self.feats_dim)
        return (feats,)  # one-element tuple, as torch.onnx.export expects
def get_input_names(self):
return ['feats']
    def get_output_names(self):
        # forward() returns (xs_pad, speech_lengths); expose them under these names
        return ['encoder_out', 'encoder_out_lens']

    def get_dynamic_axes(self):
        return {
            'feats': {
                1: 'feats_length'
            },
            'encoder_out': {
                1: 'enc_out_length'
            }
        }
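To see what prepare_mask produces, here is a small self-contained sketch; the values are illustrative only:

import torch

# (B, T) pad mask for one sequence of length 2 padded to 3
mask = torch.tensor([[1.0, 1.0, 0.0]])
mask_3d_btd = mask[:, :, None]                          # (1, 3, 1)
mask_4d_bhlt = (1 - mask[:, None, None, :]) * -10000.0  # (1, 1, 1, 3)
# mask_4d_bhlt == [[[[0., 0., -10000.]]]]; added to attention scores,
# the padded position contributes ~0 weight after softmax.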

View File

@ -28,7 +28,7 @@ class TargetDelayTransformer(nn.Module):
onnx = kwargs["onnx"]
self.embed = model.embed
self.decoder = model.decoder
self.model = model
# self.model = model
self.feats_dim = self.embed.embedding_dim
self.num_embeddings = self.embed.num_embeddings
self.model_name = model_name
@ -46,71 +46,71 @@ class TargetDelayTransformer(nn.Module):
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.punctuation.abs_model import AbsPunctuation
class TargetDelayTransformer(nn.Module):
def __init__(
self,
model,
max_seq_len=512,
model_name='punc_model',
**kwargs,
):
super().__init__()
        onnx = kwargs.get("onnx", False)
self.embed = model.embed
self.decoder = model.decoder
self.model = model
self.feats_dim = self.embed.embedding_dim
self.num_embeddings = self.embed.num_embeddings
self.model_name = model_name
if isinstance(model.encoder, SANMEncoder):
self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        else:
            raise NotImplementedError("Only the SANM encoder is supported.")
    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Compute punctuation logits from token id sequences.

        Args:
            input (torch.Tensor): Input ids. (batch, len)
            text_lengths (torch.Tensor): Valid length of each sequence. (batch,)
        """
x = self.embed(input)
# mask = self._target_mask(input)
h, _ = self.encoder(x, text_lengths)
y = self.decoder(h)
return y
def get_dummy_inputs(self):
length = 120
text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
return (text_indexes, text_lengths)
def get_input_names(self):
return ['input', 'text_lengths']
def get_output_names(self):
return ['logits']
def get_dynamic_axes(self):
return {
'input': {
0: 'batch_size',
1: 'feats_length'
},
'text_lengths': {
0: 'batch_size',
},
'logits': {
0: 'batch_size',
1: 'logits_length'
},
}
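The get_dummy_inputs/get_input_names/get_output_names/get_dynamic_axes methods above form the export protocol these wrappers share. A minimal sketch of an exporter driven by that protocol; the function name, output path, and opset are assumptions:

import torch

def export_onnx(wrapper, output_path="punc_model.onnx"):  # hypothetical helper
    torch.onnx.export(
        wrapper,
        wrapper.get_dummy_inputs(),
        output_path,
        input_names=wrapper.get_input_names(),
        output_names=wrapper.get_output_names(),
        dynamic_axes=wrapper.get_dynamic_axes(),
        opset_version=13,  # assumed; FunASR may pin a different opset
    )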

View File

@ -0,0 +1,79 @@
from typing import Tuple

import torch

from funasr.punctuation.sanm_encoder import SANMVadEncoder
from funasr.punctuation.abs_model import AbsPunctuation
from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
class VadRealtimeTransformer(AbsPunctuation):
def __init__(
self,
model,
max_seq_len=512,
model_name='punc_model',
**kwargs,
):
super().__init__()
        self.embed = model.embed
        onnx = kwargs.get("onnx", False)  # was referenced below but never defined
        if isinstance(model.encoder, SANMVadEncoder):
            self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx)
        else:
            raise NotImplementedError("Only the SANM VAD encoder is supported.")
# self.encoder = model.encoder
self.decoder = model.decoder
    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor,
                vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Compute punctuation logits from token id sequences.

        Args:
            input (torch.Tensor): Input ids. (batch, len)
            text_lengths (torch.Tensor): Valid length of each sequence. (batch,)
            vad_indexes (torch.Tensor): VAD segment indexes. (batch, len)
        """
x = self.embed(input)
# mask = self._target_mask(input)
        # SANMVadEncoder_export.forward above takes (speech, speech_lengths) and
        # returns two values; vad_indexes stays in this signature for the exporter interface.
        h, _ = self.encoder(x, text_lengths)
y = self.decoder(h)
return y
def with_vad(self):
return True
def get_dummy_inputs(self):
length = 120
text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
return (text_indexes, text_lengths)
def get_input_names(self):
return ['input', 'text_lengths']
def get_output_names(self):
return ['logits']
def get_dynamic_axes(self):
return {
'input': {
0: 'batch_size',
1: 'feats_length'
},
'text_lengths': {
0: 'batch_size',
},
'logits': {
0: 'batch_size',
1: 'logits_length'
},
}
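A hedged smoke test for the exported graph using ONNX Runtime; the file name and the vocabulary size of 100 are assumptions:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("punc_model.onnx")  # assumed path
length = 120
logits = sess.run(
    ["logits"],
    {
        "input": np.random.randint(0, 100, size=(2, length), dtype=np.int64),
        "text_lengths": np.array([length - 20, length], dtype=np.int32),
    },
)[0]
print(logits.shape)  # (batch_size, logits_length, num_punctuation_classes)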