From 69ccdd35cda4c8482e189fa350fbcb83997872f2 Mon Sep 17 00:00:00 2001 From: "wanchen.swc" Date: Mon, 6 Mar 2023 18:18:31 +0800 Subject: [PATCH 01/13] [Quantization] model quantization for inference --- funasr/export/export_model.py | 51 +++++++++++++++++-- funasr/export/models/modules/encoder_layer.py | 6 +-- funasr/export/models/modules/multihead_att.py | 28 ++++++---- 3 files changed, 67 insertions(+), 18 deletions(-) diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py index 3cbf6d293..1c677c929 100644 --- a/funasr/export/export_model.py +++ b/funasr/export/export_model.py @@ -15,7 +15,9 @@ import random # assert torch_version > 1.9 class ASRModelExportParaformer: - def __init__(self, cache_dir: Union[Path, str] = None, onnx: bool = True): + def __init__( + self, cache_dir: Union[Path, str] = None, onnx: bool = True, quant: bool = True + ): assert check_argument_types() self.set_all_random_seed(0) if cache_dir is None: @@ -28,6 +30,7 @@ class ASRModelExportParaformer: ) print("output dir: {}".format(self.cache_dir)) self.onnx = onnx + self.quant = quant def _export( @@ -56,6 +59,28 @@ class ASRModelExportParaformer: print("output dir: {}".format(export_dir)) + def _torch_quantize(self, model): + from torch_quant.module import ModuleFilter + from torch_quant.observer import HistogramObserver + from torch_quant.quantizer import Backend, Quantizer + from funasr.export.models.modules.decoder_layer import DecoderLayerSANM + from funasr.export.models.modules.encoder_layer import EncoderLayerSANM + module_filter = ModuleFilter(include_classes=[EncoderLayerSANM, DecoderLayerSANM]) + module_filter.exclude_op_types = [torch.nn.Conv1d] + quantizer = Quantizer( + module_filter=module_filter, + backend=Backend.FBGEMM, + act_ob_ctr=HistogramObserver, + ) + model.eval() + calib_model = quantizer.calib(model) + # run calibration data + # using dummy inputs for a example + dummy_input = model.get_dummy_inputs() + _ = calib_model(*dummy_input) + quant_model = quantizer.quantize(model) + return quant_model + def _export_torchscripts(self, model, verbose, path, enc_size=None): if enc_size: dummy_input = model.get_dummy_inputs(enc_size) @@ -66,6 +91,12 @@ class ASRModelExportParaformer: model_script = torch.jit.trace(model, dummy_input) model_script.save(os.path.join(path, f'{model.model_name}.torchscripts')) + if self.quant: + quant_model = self._torch_quantize(model) + model_script = torch.jit.trace(quant_model, dummy_input) + model_script.save(os.path.join(path, f'{model.model_name}_quant.torchscripts')) + + def set_all_random_seed(self, seed: int): random.seed(seed) np.random.seed(seed) @@ -107,11 +138,12 @@ class ASRModelExportParaformer: # model_script = torch.jit.script(model) model_script = model #torch.jit.trace(model) + model_path = os.path.join(path, f'{model.model_name}.onnx') torch.onnx.export( model_script, dummy_input, - os.path.join(path, f'{model.model_name}.onnx'), + model_path, verbose=verbose, opset_version=14, input_names=model.get_input_names(), @@ -119,6 +151,15 @@ class ASRModelExportParaformer: dynamic_axes=model.get_dynamic_axes() ) + if self.quant: + from onnxruntime.quantization import QuantType, quantize_dynamic + quant_model_path = os.path.join(path, f'{model.model_name}_quant.onnx') + quantize_dynamic( + model_input=model_path, + model_output=quant_model_path, + weight_type=QuantType.QUInt8, + ) + if __name__ == '__main__': import sys @@ -126,10 +167,12 @@ if __name__ == '__main__': model_path = sys.argv[1] output_dir = sys.argv[2] onnx = sys.argv[3] + quant = sys.argv[4] onnx = onnx.lower() onnx = onnx == 'true' + quant = quant == 'true' # model_path = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' # output_dir = "../export" - export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=onnx) + export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=onnx, quant=quant) export_model.export(model_path) - # export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') \ No newline at end of file + # export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') diff --git a/funasr/export/models/modules/encoder_layer.py b/funasr/export/models/modules/encoder_layer.py index 622b109d3..1da05f382 100644 --- a/funasr/export/models/modules/encoder_layer.py +++ b/funasr/export/models/modules/encoder_layer.py @@ -16,6 +16,7 @@ class EncoderLayerSANM(nn.Module): self.feed_forward = model.feed_forward self.norm1 = model.norm1 self.norm2 = model.norm2 + self.in_size = model.in_size self.size = model.size def forward(self, x, mask): @@ -23,13 +24,12 @@ class EncoderLayerSANM(nn.Module): residual = x x = self.norm1(x) x = self.self_attn(x, mask) - if x.size(2) == residual.size(2): + if self.in_size == self.size: x = x + residual residual = x x = self.norm2(x) x = self.feed_forward(x) - if x.size(2) == residual.size(2): - x = x + residual + x = x + residual return x, mask diff --git a/funasr/export/models/modules/multihead_att.py b/funasr/export/models/modules/multihead_att.py index 7d685f588..0a5667689 100644 --- a/funasr/export/models/modules/multihead_att.py +++ b/funasr/export/models/modules/multihead_att.py @@ -64,6 +64,21 @@ class MultiHeadedAttentionSANM(nn.Module): return self.linear_out(context_layer) # (batch, time1, d_model) +def preprocess_for_attn(x, mask, cache, pad_fn): + x = x * mask + x = x.transpose(1, 2) + if cache is None: + x = pad_fn(x) + else: + x = torch.cat((cache[:, :, 1:], x), dim=2) + cache = x + return x, cache + + +import torch.fx +torch.fx.wrap('preprocess_for_attn') + + class MultiHeadedAttentionSANMDecoder(nn.Module): def __init__(self, model): super().__init__() @@ -73,16 +88,7 @@ class MultiHeadedAttentionSANMDecoder(nn.Module): self.attn = None def forward(self, inputs, mask, cache=None): - # b, t, d = inputs.size() - # mask = torch.reshape(mask, (b, -1, 1)) - inputs = inputs * mask - - x = inputs.transpose(1, 2) - if cache is None: - x = self.pad_fn(x) - else: - x = torch.cat((cache[:, :, 1:], x), dim=2) - cache = x + x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn) x = self.fsmn_block(x) x = x.transpose(1, 2) @@ -232,4 +238,4 @@ class OnnxRelPosMultiHeadedAttention(OnnxMultiHeadedAttention): new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) return self.linear_out(context_layer) # (batch, time1, d_model) - \ No newline at end of file + From 525f5d77564f016acdd03ff71197f7a4a9177840 Mon Sep 17 00:00:00 2001 From: "wanchen.swc" Date: Fri, 10 Mar 2023 17:08:04 +0800 Subject: [PATCH 02/13] [Quantization] onnx quantization --- funasr/export/README.md | 16 ++++++++++------ funasr/export/export_model.py | 8 ++++++++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/funasr/export/README.md b/funasr/export/README.md index c44ad3382..a1ed892d5 100644 --- a/funasr/export/README.md +++ b/funasr/export/README.md @@ -11,31 +11,35 @@ The installation is the same as [funasr](../../README.md) `Tips`: torch>=1.11.0 ```shell - python -m funasr.export.export_model [model_name] [export_dir] [onnx] + python -m funasr.export.export_model [model_name] [export_dir] [onnx] [quant] ``` `model_name`: the model is to export. It could be the models from modelscope, or local finetuned model(named: model.pb). + `export_dir`: the dir where the onnx is export. - `onnx`: `true`, export onnx format model; `false`, export torchscripts format model. + + `onnx`: `true`, export onnx format model; `false`, export torchscripts format model. + + `quant`: `true`, export quantized model at the same time; `false`, export fp32 model only. ## For example ### Export onnx format model Export model from modelscope ```shell -python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true +python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true false ``` Export model from local path, the model'name must be `model.pb`. ```shell -python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true +python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true false ``` ### Export torchscripts format model Export model from modelscope ```shell -python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false +python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false false ``` Export model from local path, the model'name must be `model.pb`. ```shell -python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false +python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false false ``` diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py index 1c677c929..7370c3cff 100644 --- a/funasr/export/export_model.py +++ b/funasr/export/export_model.py @@ -153,11 +153,19 @@ class ASRModelExportParaformer: if self.quant: from onnxruntime.quantization import QuantType, quantize_dynamic + import onnx quant_model_path = os.path.join(path, f'{model.model_name}_quant.onnx') + onnx_model = onnx.load(model_path) + nodes = [n.name for n in onnx_model.graph.node] + nodes_to_exclude = [m for m in nodes if 'output' in m] quantize_dynamic( model_input=model_path, model_output=quant_model_path, + op_types_to_quantize=['MatMul'], + per_channel=True, + reduce_range=False, weight_type=QuantType.QUInt8, + nodes_to_exclude=nodes_to_exclude, ) From 4bdce2285b42b3ac445a60721b3e7e26f78f4ad3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Tue, 14 Mar 2023 19:46:14 +0800 Subject: [PATCH 03/13] rtf --- funasr/bin/vad_inference_launch.py | 3 +- funasr/runtime/python/utils/test_rtf.py | 47 ++++++++++++++++ funasr/runtime/python/utils/test_rtf.sh | 74 +++++++++++++++++++++++++ 3 files changed, 123 insertions(+), 1 deletion(-) create mode 100644 funasr/runtime/python/utils/test_rtf.py create mode 100644 funasr/runtime/python/utils/test_rtf.sh diff --git a/funasr/bin/vad_inference_launch.py b/funasr/bin/vad_inference_launch.py index 42c5c1e12..18eba33fb 100644 --- a/funasr/bin/vad_inference_launch.py +++ b/funasr/bin/vad_inference_launch.py @@ -110,7 +110,8 @@ def inference_launch(mode, **kwargs): if mode == "offline": from funasr.bin.vad_inference import inference_modelscope return inference_modelscope(**kwargs) - elif mode == "online": + # elif mode == "online": + if "param_dict" in kwargs and kwargs["param_dict"]["online"]: from funasr.bin.vad_inference_online import inference_modelscope return inference_modelscope(**kwargs) else: diff --git a/funasr/runtime/python/utils/test_rtf.py b/funasr/runtime/python/utils/test_rtf.py new file mode 100644 index 000000000..3394e8a04 --- /dev/null +++ b/funasr/runtime/python/utils/test_rtf.py @@ -0,0 +1,47 @@ + +import time +import sys +import librosa +backend=sys.argv[1] +model_dir=sys.argv[2] +wav_file=sys.argv[3] + +from torch_paraformer import Paraformer +if backend == "onnxruntime": + from rapid_paraformer import Paraformer + +model = Paraformer(model_dir, batch_size=1, device_id="-1") + +wav_file_f = open(wav_file, 'r') +wav_files = wav_file_f.readlines() + +# warm-up +total = 0.0 +num = 100 +wav_path = wav_files[0].split("\t")[1].strip() if "\t" in wav_files[0] else wav_files[0].split(" ")[1].strip() +for i in range(num): + beg_time = time.time() + result = model(wav_path) + end_time = time.time() + duration = end_time-beg_time + total += duration + print(result) + print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53)) + +# infer time +beg_time = time.time() +for i, wav_path_i in enumerate(wav_files): + wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip() + result = model(wav_path) +end_time = time.time() +duration = (end_time-beg_time)*1000 +print("total_time_comput_ms: {}".format(int(duration))) + +duration_time = 0.0 +for i, wav_path_i in enumerate(wav_files): + wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip() + waveform, _ = librosa.load(wav_path, sr=16000) + duration_time += len(waveform)/16.0 +print("total_time_wav_ms: {}".format(int(duration_time))) + +print("total_rtf: {:.5}".format(duration/duration_time)) \ No newline at end of file diff --git a/funasr/runtime/python/utils/test_rtf.sh b/funasr/runtime/python/utils/test_rtf.sh new file mode 100644 index 000000000..b1562b101 --- /dev/null +++ b/funasr/runtime/python/utils/test_rtf.sh @@ -0,0 +1,74 @@ + +nj=64 + +#:< ${local_scp_dir}/log.$JOB.txt + }& + +done +wait + + +rm -rf ${local_scp_dir}/total_time_comput.txt +rm -rf ${local_scp_dir}/total_time_wav.txt +rm -rf ${local_scp_dir}/total_rtf.txt +for JOB in $(seq ${nj}); do + { + cat ${local_scp_dir}/log.$JOB.txt | grep "total_time_comput" | awk -F ' ' '{print $2}' >> ${local_scp_dir}/total_time_comput.txt + cat ${local_scp_dir}/log.$JOB.txt | grep "total_time_wav" | awk -F ' ' '{print $2}' >> ${local_scp_dir}/total_time_wav.txt + cat ${local_scp_dir}/log.$JOB.txt | grep "total_rtf" | awk -F ' ' '{print $2}' >> ${local_scp_dir}/total_rtf.txt + } + +done + +total_time_comput=`cat ${local_scp_dir}/total_time_comput.txt | awk 'BEGIN {max = 0} {if ($1+0>max+0) max=$1 fi} END {print max}'` +total_time_wav=`cat ${local_scp_dir}/total_time_wav.txt | awk '{sum +=$1};END {print sum}'` +rtf=`awk 'BEGIN{printf "%.5f\n",'$total_time_comput'/'$total_time_wav'}'` +speed=`awk 'BEGIN{printf "%.2f\n",1/'$rtf'}'` + +echo "total_time_comput_ms: $total_time_comput" +echo "total_time_wav: $total_time_wav" +echo "total_rtf: $rtf, speech: $speed" \ No newline at end of file From 63d444e8eb57b772f77de766ed2257d1f6e3d687 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Tue, 14 Mar 2023 19:46:46 +0800 Subject: [PATCH 04/13] rtf --- funasr/runtime/python/utils/test_rtf.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funasr/runtime/python/utils/test_rtf.sh b/funasr/runtime/python/utils/test_rtf.sh index b1562b101..32166c1b0 100644 --- a/funasr/runtime/python/utils/test_rtf.sh +++ b/funasr/runtime/python/utils/test_rtf.sh @@ -25,7 +25,7 @@ model_dir="/nfs/zhifu.gzf/export/damo/amp_int8/onnx_dynamic" tag=${backend}_fp32 ! -scp=/nfs/haoneng.lhn/funasr_data/aishell-1/data/test/wav.scp +#scp=/nfs/haoneng.lhn/funasr_data/aishell-1/data/test/wav.scp scp="/nfs/zhifu.gzf/data_debug/test/wav_1500.scp" local_scp_dir=/nfs/zhifu.gzf/data_debug/test/${tag}/split$nj From 8a620a5a36df782e1f9e8cc56064d5dc6a1330b5 Mon Sep 17 00:00:00 2001 From: "wanchen.swc" Date: Wed, 15 Mar 2023 15:31:31 +0800 Subject: [PATCH 05/13] [Quantization] automatic mixed precision quantization --- funasr/export/README.md | 26 ++++++++++------ funasr/export/export_model.py | 57 ++++++++++++++++++++++------------- 2 files changed, 53 insertions(+), 30 deletions(-) diff --git a/funasr/export/README.md b/funasr/export/README.md index a1ed892d5..33ab22ea9 100644 --- a/funasr/export/README.md +++ b/funasr/export/README.md @@ -11,35 +11,43 @@ The installation is the same as [funasr](../../README.md) `Tips`: torch>=1.11.0 ```shell - python -m funasr.export.export_model [model_name] [export_dir] [onnx] [quant] + python -m funasr.export.export_model \ + --model-name [model_name] \ + --export-dir [export_dir] \ + --type [onnx, torch] \ + --quantize \ + --fallback-num [fallback_num] ``` - `model_name`: the model is to export. It could be the models from modelscope, or local finetuned model(named: model.pb). + `model-name`: the model is to export. It could be the models from modelscope, or local finetuned model(named: model.pb). - `export_dir`: the dir where the onnx is export. + `export-dir`: the dir where the onnx is export. - `onnx`: `true`, export onnx format model; `false`, export torchscripts format model. + `type`: `onnx` or `torch`, export onnx format model or torchscript format model. + + `quantize`: `true`, export quantized model at the same time; `false`, export fp32 model only. + + `fallback-num`: specify the number of fallback layers to perform automatic mixed precision quantization. - `quant`: `true`, export quantized model at the same time; `false`, export fp32 model only. ## For example ### Export onnx format model Export model from modelscope ```shell -python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true false +python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx ``` Export model from local path, the model'name must be `model.pb`. ```shell -python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true false +python -m funasr.export.export_model --model-name /mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx ``` ### Export torchscripts format model Export model from modelscope ```shell -python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false false +python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type torch ``` Export model from local path, the model'name must be `model.pb`. ```shell -python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false false +python -m funasr.export.export_model --model-name /mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type torch ``` diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py index 7370c3cff..beb1efe00 100644 --- a/funasr/export/export_model.py +++ b/funasr/export/export_model.py @@ -16,7 +16,11 @@ import random class ASRModelExportParaformer: def __init__( - self, cache_dir: Union[Path, str] = None, onnx: bool = True, quant: bool = True + self, + cache_dir: Union[Path, str] = None, + onnx: bool = True, + quant: bool = True, + fallback_num: int = 0, ): assert check_argument_types() self.set_all_random_seed(0) @@ -31,6 +35,7 @@ class ASRModelExportParaformer: print("output dir: {}".format(self.cache_dir)) self.onnx = onnx self.quant = quant + self.fallback_num = fallback_num def _export( @@ -60,8 +65,12 @@ class ASRModelExportParaformer: def _torch_quantize(self, model): + def _run_calibration_data(m): + # using dummy inputs for a example + dummy_input = model.get_dummy_inputs() + m(*dummy_input) + from torch_quant.module import ModuleFilter - from torch_quant.observer import HistogramObserver from torch_quant.quantizer import Backend, Quantizer from funasr.export.models.modules.decoder_layer import DecoderLayerSANM from funasr.export.models.modules.encoder_layer import EncoderLayerSANM @@ -70,17 +79,21 @@ class ASRModelExportParaformer: quantizer = Quantizer( module_filter=module_filter, backend=Backend.FBGEMM, - act_ob_ctr=HistogramObserver, ) model.eval() calib_model = quantizer.calib(model) - # run calibration data - # using dummy inputs for a example - dummy_input = model.get_dummy_inputs() - _ = calib_model(*dummy_input) + _run_calibration_data(calib_model) + if self.fallback_num > 0: + # perform automatic mixed precision quantization + amp_model = quantizer.amp(model) + _run_calibration_data(amp_model) + quantizer.fallback(amp_model, num=self.fallback_num) + print('Fallback layers:') + print('\n'.join(quantizer.module_filter.exclude_names)) quant_model = quantizer.quantize(model) return quant_model + def _export_torchscripts(self, model, verbose, path, enc_size=None): if enc_size: dummy_input = model.get_dummy_inputs(enc_size) @@ -170,17 +183,19 @@ class ASRModelExportParaformer: if __name__ == '__main__': - import sys - - model_path = sys.argv[1] - output_dir = sys.argv[2] - onnx = sys.argv[3] - quant = sys.argv[4] - onnx = onnx.lower() - onnx = onnx == 'true' - quant = quant == 'true' - # model_path = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' - # output_dir = "../export" - export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=onnx, quant=quant) - export_model.export(model_path) - # export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--model-name', type=str, required=True) + parser.add_argument('--export-dir', type=str, required=True) + parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]') + parser.add_argument('--quantize', action='store_true', help='export quantized model') + parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number') + args = parser.parse_args() + + export_model = ASRModelExportParaformer( + cache_dir=args.export_dir, + onnx=args.type == 'onnx', + quant=args.quantize, + fallback_num=args.fallback_num, + ) + export_model.export(args.model_name) From b7b45efc4ea164d90d443c116484adfb8b185648 Mon Sep 17 00:00:00 2001 From: "shixian.shi" Date: Wed, 15 Mar 2023 20:05:37 +0800 Subject: [PATCH 06/13] update paraformer_tiny export --- funasr/export/models/encoder/conformer_encoder.py | 1 - funasr/export/models/modules/decoder_layer.py | 1 + funasr/export/models/modules/encoder_layer.py | 4 ++-- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/funasr/export/models/encoder/conformer_encoder.py b/funasr/export/models/encoder/conformer_encoder.py index 9f2257462..0a3565304 100644 --- a/funasr/export/models/encoder/conformer_encoder.py +++ b/funasr/export/models/encoder/conformer_encoder.py @@ -61,7 +61,6 @@ class ConformerEncoder(nn.Module): speech: torch.Tensor, speech_lengths: torch.Tensor, ): - speech = speech * self._output_size ** 0.5 mask = self.make_pad_mask(speech_lengths) mask = self.prepare_mask(mask) if self.embed is None: diff --git a/funasr/export/models/modules/decoder_layer.py b/funasr/export/models/modules/decoder_layer.py index f5394523d..9a464a46c 100644 --- a/funasr/export/models/modules/decoder_layer.py +++ b/funasr/export/models/modules/decoder_layer.py @@ -54,6 +54,7 @@ class DecoderLayer(nn.Module): def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None): residual = tgt + tgt = self.norm1(tgt) tgt_q = tgt tgt_q_mask = tgt_mask x = residual + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask) diff --git a/funasr/export/models/modules/encoder_layer.py b/funasr/export/models/modules/encoder_layer.py index 1da05f382..7d0139793 100644 --- a/funasr/export/models/modules/encoder_layer.py +++ b/funasr/export/models/modules/encoder_layer.py @@ -61,7 +61,7 @@ class EncoderLayerConformer(nn.Module): if self.feed_forward_macaron is not None: residual = x x = self.norm_ff_macaron(x) - x = residual + self.feed_forward_macaron(x) + x = residual + self.feed_forward_macaron(x) * 0.5 residual = x x = self.norm_mha(x) @@ -81,7 +81,7 @@ class EncoderLayerConformer(nn.Module): residual = x x = self.norm_ff(x) - x = residual + self.feed_forward(x) + x = residual + self.feed_forward(x) * 0.5 x = self.norm_final(x) From 7bf2eec71e0c65f15628a105d11406a8a14ae178 Mon Sep 17 00:00:00 2001 From: "shixian.shi" Date: Wed, 15 Mar 2023 20:17:01 +0800 Subject: [PATCH 07/13] update paraformer_onnx --- .../rapid_paraformer/paraformer_onnx.py | 55 +++++++++++++------ 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py b/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py index 9b8a67bb0..091db0d25 100644 --- a/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py +++ b/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py @@ -23,6 +23,8 @@ class Paraformer(): def __init__(self, model_dir: Union[str, Path] = None, batch_size: int = 1, device_id: Union[str, int] = "-1", + plot_timestamp_to: str = "", + pred_bias: int = 1, ): if not Path(model_dir).exists(): @@ -41,14 +43,15 @@ class Paraformer(): ) self.ort_infer = OrtInferSession(model_file, device_id) self.batch_size = batch_size - self.plot = True + self.plot_timestamp_to = plot_timestamp_to + self.pred_bias = pred_bias def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List: waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq) waveform_nums = len(waveform_list) asr_res = [] for beg_idx in range(0, waveform_nums, self.batch_size): - res = {} + end_idx = min(waveform_nums, beg_idx + self.batch_size) feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx]) try: @@ -64,19 +67,41 @@ class Paraformer(): logging.warning("input wav is silence or noise") preds = [''] else: - preds, raw_token = self.decode(am_scores, valid_token_lens)[0] - res['preds'] = preds - if us_cif_peak is not None: - timestamp, timestamp_total = time_stamp_lfr6_onnx(us_cif_peak, copy.copy(raw_token)) - res['timestamp'] = timestamp - if self.plot: - self.plot_wave_timestamp(waveform_list[0], timestamp_total) - asr_res.append(res) + preds = self.decode(am_scores, valid_token_lens) + if us_cif_peak is None: + for pred in preds: + asr_res.append({'preds': pred}) + else: + for pred, us_cif_peak_ in zip(preds, us_cif_peak): + text, tokens = pred + timestamp, timestamp_total = time_stamp_lfr6_onnx(us_cif_peak_, copy.copy(tokens)) + if len(self.plot_timestamp_to): + self.plot_wave_timestamp(waveform_list[0], timestamp_total, self.plot_timestamp_to) + asr_res.append({'preds': text, 'timestamp': timestamp}) return asr_res - def plot_wave_timestamp(self, wav, text_timestamp): + def plot_wave_timestamp(self, wav, text_timestamp, dest): # TODO: Plot the wav and timestamp results with matplotlib - import pdb; pdb.set_trace() + import matplotlib + matplotlib.use('Agg') + matplotlib.rc("font", family='Alibaba PuHuiTi') # set it to a font that your system supports + import matplotlib.pyplot as plt + fig, ax1 = plt.subplots(figsize=(11, 3.5), dpi=320) + ax2 = ax1.twinx() + ax2.set_ylim([0, 2.0]) + # plot waveform + ax1.set_ylim([-0.3, 0.3]) + time = np.arange(wav.shape[0]) / 16000 + ax1.plot(time, wav/wav.max()*0.3, color='gray', alpha=0.4) + # plot lines and text + for (char, start, end) in text_timestamp: + ax1.vlines(start, -0.3, 0.3, ls='--') + ax1.vlines(end, -0.3, 0.3, ls='--') + x_adj = 0.045 if char != '' else 0.12 + ax1.text((start + end) * 0.5 - x_adj, 0, char) + # plt.legend() + plotname = "{}/timestamp.png".format(dest) + plt.savefig(plotname, bbox_inches='tight') def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List: @@ -150,9 +175,7 @@ class Paraformer(): # Change integer-ids to tokens token = self.converter.ids2tokens(token_int) - # token = token[:valid_token_num-1] + token = token[:valid_token_num-self.pred_bias] texts = sentence_postprocess(token) - text = texts[0] - # text = self.tokenizer.tokens2text(token) - return text, token + return texts From 01b4ff3bde05b7c5ea071af8867a331e9ae4bf53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Wed, 15 Mar 2023 20:36:58 +0800 Subject: [PATCH 08/13] calib set --- funasr/export/export_model.py | 52 +++++++++++++++++++++++++++++++-- funasr/export/utils/wav_load.py | 2 ++ 2 files changed, 52 insertions(+), 2 deletions(-) create mode 100644 funasr/export/utils/wav_load.py diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py index beb1efe00..7e65a8f76 100644 --- a/funasr/export/export_model.py +++ b/funasr/export/export_model.py @@ -21,6 +21,8 @@ class ASRModelExportParaformer: onnx: bool = True, quant: bool = True, fallback_num: int = 0, + audio_in: str = None, + calib_num: int = 200, ): assert check_argument_types() self.set_all_random_seed(0) @@ -36,6 +38,9 @@ class ASRModelExportParaformer: self.onnx = onnx self.quant = quant self.fallback_num = fallback_num + self.frontend = None + self.audio_in = audio_in + self.calib_num = calib_num def _export( @@ -67,8 +72,14 @@ class ASRModelExportParaformer: def _torch_quantize(self, model): def _run_calibration_data(m): # using dummy inputs for a example - dummy_input = model.get_dummy_inputs() - m(*dummy_input) + if self.audio_in is not None: + feats, feats_len = self.load_feats(self.audio_in) + for feat, len in zip(feats, feats_len): + m(feat, len) + else: + dummy_input = model.get_dummy_inputs() + m(*dummy_input) + from torch_quant.module import ModuleFilter from torch_quant.quantizer import Backend, Quantizer @@ -114,6 +125,39 @@ class ASRModelExportParaformer: random.seed(seed) np.random.seed(seed) torch.random.manual_seed(seed) + + def parse_audio_in(self, audio_in): + + wav_list, name_list = [], [] + if audio_in.endswith(".scp"): + f = open(audio_in, 'r') + lines = f.readlines()[:self.calib_num] + for line in lines: + name, path = line.strip().split() + name_list.append(name) + wav_list.append(path) + else: + wav_list = [audio_in,] + name_list = ["test",] + return wav_list, name_list + + def load_feats(self, audio_in: str = None): + import torchaudio + + wav_list, name_list = self.parse_audio_in(audio_in) + feats = [] + feats_len = [] + for line in wav_list: + name, path = line.strip().split() + waveform, sampling_rate = torchaudio.load(path) + if sampling_rate != self.frontend.fs: + waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate, + new_freq=self.frontend.fs)(waveform) + fbank, fbank_len = self.frontend(waveform, [waveform.size(1)]) + feats.append(fbank) + feats_len.append(fbank_len) + return feats, feats_len + def export(self, tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch', mode: str = 'paraformer', @@ -190,6 +234,8 @@ if __name__ == '__main__': parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]') parser.add_argument('--quantize', action='store_true', help='export quantized model') parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number') + parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]') + parser.add_argument('--calib_num', type=int, default=200, help='calib max num') args = parser.parse_args() export_model = ASRModelExportParaformer( @@ -197,5 +243,7 @@ if __name__ == '__main__': onnx=args.type == 'onnx', quant=args.quantize, fallback_num=args.fallback_num, + audio_in=args.audio_in, + calib_num=args.calib_num, ) export_model.export(args.model_name) diff --git a/funasr/export/utils/wav_load.py b/funasr/export/utils/wav_load.py new file mode 100644 index 000000000..b48e5a074 --- /dev/null +++ b/funasr/export/utils/wav_load.py @@ -0,0 +1,2 @@ +import os + From 18b6fb3b502ee1bf4c6b595a8e96cf2216393f80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Wed, 15 Mar 2023 20:38:20 +0800 Subject: [PATCH 09/13] calib set --- funasr/export/utils/wav_load.py | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 funasr/export/utils/wav_load.py diff --git a/funasr/export/utils/wav_load.py b/funasr/export/utils/wav_load.py deleted file mode 100644 index b48e5a074..000000000 --- a/funasr/export/utils/wav_load.py +++ /dev/null @@ -1,2 +0,0 @@ -import os - From 9f5c42c476463e801d3de4bfeb2aefdd04f5691d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Wed, 15 Mar 2023 20:57:58 +0800 Subject: [PATCH 10/13] calib set --- funasr/export/export_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py index 7e65a8f76..b827f16b3 100644 --- a/funasr/export/export_model.py +++ b/funasr/export/export_model.py @@ -148,7 +148,7 @@ class ASRModelExportParaformer: feats = [] feats_len = [] for line in wav_list: - name, path = line.strip().split() + path = line.strip() waveform, sampling_rate = torchaudio.load(path) if sampling_rate != self.frontend.fs: waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate, @@ -184,6 +184,7 @@ class ASRModelExportParaformer: model, asr_train_args = ASRTask.build_model_from_file( asr_train_config, asr_model_file, cmvn_file, 'cpu' ) + self.frontend = model.frontend self._export(model, tag_name) From 4ffd5655da0884095e68cc7207bf9bf5d409d2bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 16 Mar 2023 16:40:15 +0800 Subject: [PATCH 11/13] calib --- funasr/export/export_model.py | 3 ++- funasr/runtime/python/utils/test_rtf.sh | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py index b827f16b3..c5bcac1e0 100644 --- a/funasr/export/export_model.py +++ b/funasr/export/export_model.py @@ -74,7 +74,8 @@ class ASRModelExportParaformer: # using dummy inputs for a example if self.audio_in is not None: feats, feats_len = self.load_feats(self.audio_in) - for feat, len in zip(feats, feats_len): + for i, (feat, len) in enumerate(zip(feats, feats_len)): + print("debug, iter: {}".format(i)) m(feat, len) else: dummy_input = model.get_dummy_inputs() diff --git a/funasr/runtime/python/utils/test_rtf.sh b/funasr/runtime/python/utils/test_rtf.sh index 32166c1b0..fe13da7d8 100644 --- a/funasr/runtime/python/utils/test_rtf.sh +++ b/funasr/runtime/python/utils/test_rtf.sh @@ -39,7 +39,7 @@ for JOB in $(seq ${nj}); do split_scps="$split_scps $local_scp_dir/wav.$JOB.scp" done -perl egs/aishell/transformer/utils/split_scp.pl $scp ${split_scps} +perl ../../../egs/aishell/transformer/utils/split_scp.pl $scp ${split_scps} for JOB in $(seq ${nj}); do From 6ca0d1f54c8b698e3315edab8aa9ba7227d7c9e7 Mon Sep 17 00:00:00 2001 From: "wanchen.swc" Date: Thu, 16 Mar 2023 18:53:57 +0800 Subject: [PATCH 12/13] [Quantization] run calib without grad --- funasr/export/export_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py index c5bcac1e0..e57a39750 100644 --- a/funasr/export/export_model.py +++ b/funasr/export/export_model.py @@ -76,7 +76,8 @@ class ASRModelExportParaformer: feats, feats_len = self.load_feats(self.audio_in) for i, (feat, len) in enumerate(zip(feats, feats_len)): print("debug, iter: {}".format(i)) - m(feat, len) + with torch.no_grad(): + m(feat, len) else: dummy_input = model.get_dummy_inputs() m(*dummy_input) From 175860147c08c2a6940032039b7eb83abfcc3ca0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 16 Mar 2023 19:28:44 +0800 Subject: [PATCH 13/13] cab --- funasr/export/export_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py index e57a39750..9a1ef9604 100644 --- a/funasr/export/export_model.py +++ b/funasr/export/export_model.py @@ -75,7 +75,6 @@ class ASRModelExportParaformer: if self.audio_in is not None: feats, feats_len = self.load_feats(self.audio_in) for i, (feat, len) in enumerate(zip(feats, feats_len)): - print("debug, iter: {}".format(i)) with torch.no_grad(): m(feat, len) else: