mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
[Optimization] support bladedisc fp16 optimization (#1790)
This commit is contained in:
parent
22e51ec95f
commit
9a9b474e7d
@ -13,7 +13,8 @@ model = AutoModel(
|
||||
model="iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
|
||||
)
|
||||
|
||||
res = model.export(type="torchscript", quantize=False)
|
||||
res = model.export(type="torchscripts", quantize=False)
|
||||
# res = model.export(type="bladedisc", input=f"{model.model_path}/example/asr_example.wav")
|
||||
print(res)
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,11 @@
|
||||
import os
|
||||
import torch
|
||||
import functools
|
||||
|
||||
try:
|
||||
import torch_blade
|
||||
except Exception as e:
|
||||
print(f"failed to load torch_blade: {e}")
|
||||
|
||||
|
||||
def export(model, data_in=None, quantize: bool = False, opset_version: int = 14, type='onnx', **kwargs):
|
||||
@ -27,6 +33,15 @@ def export(model, data_in=None, quantize: bool = False, opset_version: int = 14,
|
||||
path=export_dir,
|
||||
device=device
|
||||
)
|
||||
elif type == "bladedisc":
|
||||
assert (
|
||||
torch.cuda.is_available()
|
||||
), "Currently bladedisc optimization for FunASR only supports GPU"
|
||||
# bladedisc only optimizes encoder/decoder modules
|
||||
if hasattr(m, "encoder") and hasattr(m, "decoder"):
|
||||
_bladedisc_opt_for_encdec(m, path=export_dir, enable_fp16=True)
|
||||
else:
|
||||
_torchscripts(m, path=export_dir, device="cuda")
|
||||
print("output dir: {}".format(export_dir))
|
||||
|
||||
return export_dir
|
||||
@ -92,3 +107,92 @@ def _torchscripts(model, path, device='cuda'):
|
||||
|
||||
model_script = torch.jit.trace(model, dummy_input)
|
||||
model_script.save(os.path.join(path, f'{model.export_name}.torchscripts'))
|
||||
|
||||
|
||||
def _bladedisc_opt(model, model_inputs, enable_fp16=True):
    """Optimize ``model`` with ``torch_blade.optimize`` and return the result.

    Args:
        model: Module to optimize; it is switched to eval mode first.
        model_inputs: Example inputs forwarded to ``torch_blade.optimize``
            as ``model_inputs``.
        enable_fp16: Value assigned to ``enable_fp16`` on the torch_blade
            config that is active during optimization.

    Returns:
        Whatever ``torch_blade.optimize`` returns for ``model``.

    Raises:
        Whatever importing ``torch_blade`` raises (typically ImportError)
        when the package is not usable.
    """
    # The module-level import only prints on failure, which would surface
    # here as an opaque NameError; importing again reports the real cause.
    import torch_blade

    model = model.eval()
    torch_config = torch_blade.config.Config()
    torch_config.enable_fp16 = enable_fp16
    # Optimization runs without autograd and with the config active.
    with torch.no_grad(), torch_config:
        return torch_blade.optimize(
            model,
            allow_tracing=True,
            model_inputs=model_inputs,
        )
|
||||
|
||||
|
||||
def _rescale_input_hook(m, x, scale):
|
||||
if len(x) > 1:
|
||||
return (x[0] / scale, *x[1:])
|
||||
else:
|
||||
return (x[0] / scale,)
|
||||
|
||||
|
||||
def _rescale_output_hook(m, x, y, scale):
|
||||
if isinstance(y, tuple):
|
||||
return (y[0] / scale, *y[1:])
|
||||
else:
|
||||
return y / scale
|
||||
|
||||
|
||||
def _rescale_encoder_model(model, input_data):
    """Rescale encoder activations in place to help avoid FP16 overflow.

    One forward pass of ``model`` on ``input_data`` records the largest
    absolute value seen in the first input of every module in
    ``model.encoder.model.encoders``. A factor is derived from that maximum
    and, when it is greater than 1, applied by:

    * a forward pre-hook on ``model.encoder.model.encoders0`` dividing its
      first input by the factor,
    * forward hooks on every submodule whose name ends with ``self_attn``
      dividing its (first) output by the factor,
    * dividing every state-dict tensor of submodules whose name ends with
      ``feed_forward.w_2`` by the factor.

    Args:
        model: Model exposing ``encoder.model.encoders`` and
            ``encoder.model.encoders0``; it is moved to CUDA.
        input_data: Positional inputs, unpacked into ``model(*input_data)``.

    Returns:
        None. ``model`` is modified in place.
    """
    # Running maximum of |input| over all hooked encoder modules.
    absmax = torch.tensor(0).cuda()

    def stat_input_hook(m, x, y):
        val = x[0] if isinstance(x, tuple) else x
        absmax.copy_(torch.max(absmax, val.detach().abs().max()))

    encoders = model.encoder.model.encoders
    hooks = [m.register_forward_hook(stat_input_hook) for m in encoders]
    try:
        model = model.cuda()
        model(*input_data)
    finally:
        # Never leave the statistics hooks attached, even if the pass fails.
        for h in hooks:
            h.remove()

    fp16_scale = int(2 * absmax // 65536)
    if fp16_scale <= 1:
        # A factor of 0 (absmax below 32768) would make every hook below
        # divide by zero, and a factor of 1 is a no-op, so leave the model
        # untouched in both cases.
        print(f"skip rescaling encoder modules: factor={fp16_scale}")
        return

    # Rescale encoder modules
    print(f"rescale encoder modules with factor={fp16_scale}")
    model.encoder.model.encoders0.register_forward_pre_hook(
        functools.partial(_rescale_input_hook, scale=fp16_scale),
    )
    for name, m in model.encoder.model.named_modules():
        if name.endswith("self_attn"):
            m.register_forward_hook(
                functools.partial(_rescale_output_hook, scale=fp16_scale)
            )
        if name.endswith("feed_forward.w_2"):
            state_dict = {k: v / fp16_scale for k, v in m.state_dict().items()}
            m.load_state_dict(state_dict)
|
||||
|
||||
|
||||
def _bladedisc_opt_for_encdec(model, path, enable_fp16):
    """Optimize the encoder and decoder of ``model`` and save a traced copy.

    The dummy inputs from ``model.export_dummy_inputs()`` are moved to CUDA
    and run through the model once to capture the positional inputs the
    decoder receives. ``model.encoder`` and ``model.decoder`` are then
    replaced by their ``_bladedisc_opt`` results, the whole model is traced
    with ``torch.jit.trace`` and saved as
    ``<path>/<model.export_name>.torchscripts``.

    Args:
        model: Model exposing ``encoder``, ``decoder``,
            ``export_dummy_inputs()`` and ``export_name``; moved to CUDA.
        path: Directory the traced model is written to.
        enable_fp16: When true, the encoder is rescaled first to prevent
            FP16 overflow; the flag is also forwarded to the optimizer.

    Returns:
        None.
    """
    # Get input data
    # TODO: better to use real data
    input_data = model.export_dummy_inputs()
    if isinstance(input_data, torch.Tensor):
        # NOTE(review): a bare tensor is later unpacked with ``*input_data``
        # and sliced with ``[:2]`` -- confirm this is intended for models
        # whose dummy input is a single tensor.
        input_data = input_data.cuda()
    else:
        input_data = tuple(i.cuda() for i in input_data)

    # Get input data for decoder module
    decoder_inputs = []

    def get_input_hook(m, x):
        decoder_inputs.extend(x)

    hook = model.decoder.register_forward_pre_hook(get_input_hook)
    try:
        model = model.cuda()
        model(*input_data)
    finally:
        # Detach the capture hook even if the forward pass fails.
        hook.remove()

    # Prevent FP16 overflow
    if enable_fp16:
        _rescale_encoder_model(model, input_data)

    # Export and optimize encoder/decoder modules. The flag is forwarded so
    # that enable_fp16=False does not silently optimize in FP16.
    model.encoder = _bladedisc_opt(
        model.encoder, input_data[:2], enable_fp16=enable_fp16
    )
    model.decoder = _bladedisc_opt(
        model.decoder, tuple(decoder_inputs), enable_fp16=enable_fp16
    )
    model_script = torch.jit.trace(model, input_data)
    model_script.save(os.path.join(path, f"{model.export_name}.torchscripts"))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user