From 8a620a5a36df782e1f9e8cc56064d5dc6a1330b5 Mon Sep 17 00:00:00 2001 From: "wanchen.swc" Date: Wed, 15 Mar 2023 15:31:31 +0800 Subject: [PATCH] [Quantization] automatic mixed precision quantization --- funasr/export/README.md | 26 ++++++++++------ funasr/export/export_model.py | 57 ++++++++++++++++++++++------------- 2 files changed, 53 insertions(+), 30 deletions(-) diff --git a/funasr/export/README.md b/funasr/export/README.md index a1ed892d5..33ab22ea9 100644 --- a/funasr/export/README.md +++ b/funasr/export/README.md @@ -11,35 +11,43 @@ The installation is the same as [funasr](../../README.md) `Tips`: torch>=1.11.0 ```shell - python -m funasr.export.export_model [model_name] [export_dir] [onnx] [quant] + python -m funasr.export.export_model \ + --model-name [model_name] \ + --export-dir [export_dir] \ + --type [onnx, torch] \ + --quantize \ + --fallback-num [fallback_num] ``` - `model_name`: the model is to export. It could be the models from modelscope, or local finetuned model(named: model.pb). + `model-name`: the model is to export. It could be the models from modelscope, or local finetuned model(named: model.pb). - `export_dir`: the dir where the onnx is export. + `export-dir`: the dir where the onnx is export. - `onnx`: `true`, export onnx format model; `false`, export torchscripts format model. + `type`: `onnx` or `torch`, export onnx format model or torchscript format model. + + `quantize`: `true`, export quantized model at the same time; `false`, export fp32 model only. + + `fallback-num`: specify the number of fallback layers to perform automatic mixed precision quantization. - `quant`: `true`, export quantized model at the same time; `false`, export fp32 model only. ## For example ### Export onnx format model Export model from modelscope ```shell -python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true false +python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx ``` Export model from local path, the model'name must be `model.pb`. ```shell -python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true false +python -m funasr.export.export_model --model-name /mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx ``` ### Export torchscripts format model Export model from modelscope ```shell -python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false false +python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type torch ``` Export model from local path, the model'name must be `model.pb`. ```shell -python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false false +python -m funasr.export.export_model --model-name /mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type torch ``` diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py index 7370c3cff..beb1efe00 100644 --- a/funasr/export/export_model.py +++ b/funasr/export/export_model.py @@ -16,7 +16,11 @@ import random class ASRModelExportParaformer: def __init__( - self, cache_dir: Union[Path, str] = None, onnx: bool = True, quant: bool = True + self, + cache_dir: Union[Path, str] = None, + onnx: bool = True, + quant: bool = True, + fallback_num: int = 0, ): assert check_argument_types() self.set_all_random_seed(0) @@ -31,6 +35,7 @@ class ASRModelExportParaformer: print("output dir: {}".format(self.cache_dir)) self.onnx = onnx self.quant = quant + self.fallback_num = fallback_num def _export( @@ -60,8 +65,12 @@ class ASRModelExportParaformer: def _torch_quantize(self, model): + def _run_calibration_data(m): + # using dummy inputs for a example + dummy_input = model.get_dummy_inputs() + m(*dummy_input) + from torch_quant.module import ModuleFilter - from torch_quant.observer import HistogramObserver from torch_quant.quantizer import Backend, Quantizer from funasr.export.models.modules.decoder_layer import DecoderLayerSANM from funasr.export.models.modules.encoder_layer import EncoderLayerSANM @@ -70,17 +79,21 @@ class ASRModelExportParaformer: quantizer = Quantizer( module_filter=module_filter, backend=Backend.FBGEMM, - act_ob_ctr=HistogramObserver, ) model.eval() calib_model = quantizer.calib(model) - # run calibration data - # using dummy inputs for a example - dummy_input = model.get_dummy_inputs() - _ = calib_model(*dummy_input) + _run_calibration_data(calib_model) + if self.fallback_num > 0: + # perform automatic mixed precision quantization + amp_model = quantizer.amp(model) + _run_calibration_data(amp_model) + quantizer.fallback(amp_model, num=self.fallback_num) + print('Fallback layers:') + print('\n'.join(quantizer.module_filter.exclude_names)) quant_model = quantizer.quantize(model) return quant_model + def _export_torchscripts(self, model, verbose, path, enc_size=None): if enc_size: dummy_input = model.get_dummy_inputs(enc_size) @@ -170,17 +183,19 @@ class ASRModelExportParaformer: if __name__ == '__main__': - import sys - - model_path = sys.argv[1] - output_dir = sys.argv[2] - onnx = sys.argv[3] - quant = sys.argv[4] - onnx = onnx.lower() - onnx = onnx == 'true' - quant = quant == 'true' - # model_path = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' - # output_dir = "../export" - export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=onnx, quant=quant) - export_model.export(model_path) - # export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--model-name', type=str, required=True) + parser.add_argument('--export-dir', type=str, required=True) + parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]') + parser.add_argument('--quantize', action='store_true', help='export quantized model') + parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number') + args = parser.parse_args() + + export_model = ASRModelExportParaformer( + cache_dir=args.export_dir, + onnx=args.type == 'onnx', + quant=args.quantize, + fallback_num=args.fallback_num, + ) + export_model.export(args.model_name)