export model

This commit is contained in:
游雁 2023-02-07 22:51:39 +08:00
parent 88c4f4a25d
commit 87bff7ae59
5 changed files with 32 additions and 10 deletions

View File

@@ -1,7 +1,12 @@
environment: ubuntu20.04-py37-torch1.11.0-tf1.15.5-1.2.0
Export ONNX files from modelscope
## Install modelscope and funasr
Installation is the same as for [funasr](../../README.md).
## Export ONNX format model
Export model from modelscope
```python
from funasr.export.export_model import ASRModelExportParaformer
output_dir = "../export"
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True)
export_model.export_from_modelscope('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```
Export ONNX files from local path
Export model from local path
```python
from funasr.export.export_model import ASRModelExportParaformer
output_dir = "../export"
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True)
export_model.export_from_local('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```
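After either export, the resulting file can be sanity-checked with onnxruntime. A minimal sketch, assuming the exporter writes a `model.onnx` file inside the model directory under `output_dir` (the exact filename and layout are version-dependent):
```python
import onnxruntime as ort

# Hypothetical output location; adjust to the file the exporter actually wrote.
onnx_path = "../export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.onnx"

session = ort.InferenceSession(onnx_path)
# Listing the graph inputs/outputs confirms the file loads and shows the expected signature.
print([(i.name, i.shape) for i in session.get_inputs()])
print([(o.name, o.shape) for o in session.get_outputs()])
```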
## Export TorchScript format model
Export model from modelscope
```python
from funasr.export.export_model import ASRModelExportParaformer
output_dir = "../export"
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False)
export_model.export_from_modelscope('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```
Export model from local path
```python
from funasr.export.export_model import ASRModelExportParaformer
output_dir = "../export"
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False)
export_model.export_from_local('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```
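The TorchScript file can be loaded back with `torch.jit.load` to verify it. A minimal sketch, assuming a hypothetical output filename under the same directory (the actual name depends on the FunASR version):
```python
import torch

# Hypothetical output location; adjust to the file the exporter actually wrote.
ts_path = "../export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.torchscripts"

scripted = torch.jit.load(ts_path, map_location="cpu")
scripted.eval()
print(scripted)  # inspect the loaded module structure
```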

View File

@@ -20,7 +20,7 @@ class ASRModelExportParaformer:
        self.cache_dir = Path(cache_dir)
        self.export_config = dict(
            feats_dim=560,
            onnx=onnx,
            onnx=False,
        )
        logging.info("output dir: {}".format(self.cache_dir))
        self.onnx = onnx

View File

@@ -63,12 +63,9 @@ class Paraformer(nn.Module):
        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        sample_ids = decoder_out.argmax(dim=-1)
        # sample_ids = decoder_out.argmax(dim=-1)
        return decoder_out, sample_ids
    # def get_output_size(self):
    #     return self.model.encoders[0].size
        return decoder_out, pre_token_length

    def get_dummy_inputs(self):
        speech = torch.randn(2, 30, self.feats_dim)
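This change makes the exported `forward` return the log-probabilities together with the predicted token lengths rather than pre-computed `sample_ids`, leaving greedy decoding to the consumer of the exported graph. A sketch of that consumer-side step, assuming `decoder_out` has shape `(batch, time, vocab)` and the second output gives each utterance's valid token count:
```python
import numpy as np

def greedy_decode(decoder_out: np.ndarray, token_nums: np.ndarray):
    """Argmax over the vocabulary, then trim each hypothesis to its valid length."""
    sample_ids = decoder_out.argmax(axis=-1)  # (batch, time)
    return [ids[: int(n)] for ids, n in zip(sample_ids, token_nums)]

# Random stand-ins for real model outputs; 8404 matches the vocab size in the model name.
decoder_out = np.random.randn(2, 30, 8404)
token_nums = np.array([12, 7])
print([h.shape for h in greedy_decode(decoder_out, token_nums)])
```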

View File

@@ -22,6 +22,7 @@ class SANMEncoder(nn.Module):
        self.embed = model.embed
        self.model = model
        self.feats_dim = feats_dim
        self._output_size = model._output_size
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
@@ -62,7 +63,7 @@
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        speech = speech * self._output_size ** 0.5
        mask = self.make_pad_mask(speech_lengths)
        mask = self.prepare_mask(mask)
        if self.embed is None:

View File

@@ -293,7 +293,7 @@ class SANMEncoder(AbsEncoder):
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        xs_pad *= self.output_size()**0.5
        xs_pad = xs_pad * self.output_size()**0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (
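The switch from the in-place `*=` to an out-of-place multiply matters for export: autograd rejects in-place updates on leaf tensors that require grad, and in-place ops can also mutate the caller's input during tracing. A minimal demonstration of the failure mode:
```python
import torch

x = torch.ones(3, requires_grad=True)
try:
    x *= 2.0  # in-place op on a leaf tensor that requires grad is rejected by autograd
except RuntimeError as err:
    print("in-place multiply failed:", err)

y = x * 2.0  # out-of-place multiply creates a new tensor and keeps the graph intact
y.sum().backward()
print(x.grad)  # gradients flow back to x as expected
```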