mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update repo
This commit is contained in:
parent
3d70934e7f
commit
7d6177b43f
@ -24,7 +24,7 @@ import torch
|
|||||||
from packaging.version import parse as V
|
from packaging.version import parse as V
|
||||||
from typeguard import check_argument_types
|
from typeguard import check_argument_types
|
||||||
from typeguard import check_return_type
|
from typeguard import check_return_type
|
||||||
|
from funasr.build_utils.build_model_from_file import build_model_from_file
|
||||||
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
|
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
|
||||||
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
|
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
|
||||||
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
|
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
|
||||||
@ -35,9 +35,7 @@ from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransduc
|
|||||||
from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
|
from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
|
||||||
from funasr.modules.scorers.ctc import CTCPrefixScorer
|
from funasr.modules.scorers.ctc import CTCPrefixScorer
|
||||||
from funasr.modules.scorers.length_bonus import LengthBonus
|
from funasr.modules.scorers.length_bonus import LengthBonus
|
||||||
from funasr.tasks.asr import ASRTask
|
from funasr.build_utils.build_asr_model import frontend_choices
|
||||||
from funasr.tasks.asr import frontend_choices
|
|
||||||
from funasr.tasks.lm import LMTask
|
|
||||||
from funasr.text.build_tokenizer import build_tokenizer
|
from funasr.text.build_tokenizer import build_tokenizer
|
||||||
from funasr.text.token_id_converter import TokenIDConverter
|
from funasr.text.token_id_converter import TokenIDConverter
|
||||||
from funasr.torch_utils.device_funcs import to_device
|
from funasr.torch_utils.device_funcs import to_device
|
||||||
@ -84,15 +82,14 @@ class Speech2Text:
|
|||||||
|
|
||||||
# 1. Build ASR model
|
# 1. Build ASR model
|
||||||
scorers = {}
|
scorers = {}
|
||||||
asr_model, asr_train_args = ASRTask.build_model_from_file(
|
asr_model, asr_train_args = build_model_from_file(
|
||||||
asr_train_config, asr_model_file, cmvn_file, device
|
asr_train_config, asr_model_file, cmvn_file, device, mode="asr"
|
||||||
)
|
)
|
||||||
frontend = None
|
frontend = None
|
||||||
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
||||||
if asr_train_args.frontend == 'wav_frontend':
|
if asr_train_args.frontend == 'wav_frontend':
|
||||||
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
|
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
|
||||||
else:
|
else:
|
||||||
from funasr.tasks.asr import frontend_choices
|
|
||||||
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
|
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
|
||||||
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
|
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
|
||||||
|
|
||||||
@ -112,7 +109,7 @@ class Speech2Text:
|
|||||||
|
|
||||||
# 2. Build Language model
|
# 2. Build Language model
|
||||||
if lm_train_config is not None:
|
if lm_train_config is not None:
|
||||||
lm, lm_train_args = LMTask.build_model_from_file(
|
lm, lm_train_args = build_model_from_file(
|
||||||
lm_train_config, lm_file, None, device
|
lm_train_config, lm_file, None, device
|
||||||
)
|
)
|
||||||
scorers["lm"] = lm.lm
|
scorers["lm"] = lm.lm
|
||||||
@ -295,9 +292,8 @@ class Speech2TextParaformer:
|
|||||||
|
|
||||||
# 1. Build ASR model
|
# 1. Build ASR model
|
||||||
scorers = {}
|
scorers = {}
|
||||||
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
|
asr_model, asr_train_args = build_model_from_file(
|
||||||
asr_model, asr_train_args = ASRTask.build_model_from_file(
|
asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
|
||||||
asr_train_config, asr_model_file, cmvn_file, device
|
|
||||||
)
|
)
|
||||||
frontend = None
|
frontend = None
|
||||||
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
||||||
@ -319,7 +315,7 @@ class Speech2TextParaformer:
|
|||||||
|
|
||||||
# 2. Build Language model
|
# 2. Build Language model
|
||||||
if lm_train_config is not None:
|
if lm_train_config is not None:
|
||||||
lm, lm_train_args = LMTask.build_model_from_file(
|
lm, lm_train_args = build_model_from_file(
|
||||||
lm_train_config, lm_file, device
|
lm_train_config, lm_file, device
|
||||||
)
|
)
|
||||||
scorers["lm"] = lm.lm
|
scorers["lm"] = lm.lm
|
||||||
@ -616,9 +612,8 @@ class Speech2TextParaformerOnline:
|
|||||||
|
|
||||||
# 1. Build ASR model
|
# 1. Build ASR model
|
||||||
scorers = {}
|
scorers = {}
|
||||||
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
|
asr_model, asr_train_args = build_model_from_file(
|
||||||
asr_model, asr_train_args = ASRTask.build_model_from_file(
|
asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
|
||||||
asr_train_config, asr_model_file, cmvn_file, device
|
|
||||||
)
|
)
|
||||||
frontend = None
|
frontend = None
|
||||||
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
||||||
@ -640,7 +635,7 @@ class Speech2TextParaformerOnline:
|
|||||||
|
|
||||||
# 2. Build Language model
|
# 2. Build Language model
|
||||||
if lm_train_config is not None:
|
if lm_train_config is not None:
|
||||||
lm, lm_train_args = LMTask.build_model_from_file(
|
lm, lm_train_args = build_model_from_file(
|
||||||
lm_train_config, lm_file, device
|
lm_train_config, lm_file, device
|
||||||
)
|
)
|
||||||
scorers["lm"] = lm.lm
|
scorers["lm"] = lm.lm
|
||||||
@ -873,9 +868,8 @@ class Speech2TextUniASR:
|
|||||||
|
|
||||||
# 1. Build ASR model
|
# 1. Build ASR model
|
||||||
scorers = {}
|
scorers = {}
|
||||||
from funasr.tasks.asr import ASRTaskUniASR as ASRTask
|
asr_model, asr_train_args = build_model_from_file(
|
||||||
asr_model, asr_train_args = ASRTask.build_model_from_file(
|
asr_train_config, asr_model_file, cmvn_file, device, mode="uniasr"
|
||||||
asr_train_config, asr_model_file, cmvn_file, device
|
|
||||||
)
|
)
|
||||||
frontend = None
|
frontend = None
|
||||||
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
||||||
@ -901,8 +895,8 @@ class Speech2TextUniASR:
|
|||||||
|
|
||||||
# 2. Build Language model
|
# 2. Build Language model
|
||||||
if lm_train_config is not None:
|
if lm_train_config is not None:
|
||||||
lm, lm_train_args = LMTask.build_model_from_file(
|
lm, lm_train_args = build_model_from_file(
|
||||||
lm_train_config, lm_file, device
|
lm_train_config, lm_file, device, "lm"
|
||||||
)
|
)
|
||||||
scorers["lm"] = lm.lm
|
scorers["lm"] = lm.lm
|
||||||
|
|
||||||
@ -1104,9 +1098,8 @@ class Speech2TextMFCCA:
|
|||||||
assert check_argument_types()
|
assert check_argument_types()
|
||||||
|
|
||||||
# 1. Build ASR model
|
# 1. Build ASR model
|
||||||
from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
|
|
||||||
scorers = {}
|
scorers = {}
|
||||||
asr_model, asr_train_args = ASRTask.build_model_from_file(
|
asr_model, asr_train_args = build_model_from_file(
|
||||||
asr_train_config, asr_model_file, cmvn_file, device
|
asr_train_config, asr_model_file, cmvn_file, device
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1126,7 +1119,7 @@ class Speech2TextMFCCA:
|
|||||||
|
|
||||||
# 2. Build Language model
|
# 2. Build Language model
|
||||||
if lm_train_config is not None:
|
if lm_train_config is not None:
|
||||||
lm, lm_train_args = LMTask.build_model_from_file(
|
lm, lm_train_args = build_model_from_file(
|
||||||
lm_train_config, lm_file, device
|
lm_train_config, lm_file, device
|
||||||
)
|
)
|
||||||
lm.to(device)
|
lm.to(device)
|
||||||
@ -1315,8 +1308,7 @@ class Speech2TextTransducer:
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
assert check_argument_types()
|
assert check_argument_types()
|
||||||
from funasr.tasks.asr import ASRTransducerTask
|
asr_model, asr_train_args = build_model_from_file(
|
||||||
asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
|
|
||||||
asr_train_config, asr_model_file, cmvn_file, device
|
asr_train_config, asr_model_file, cmvn_file, device
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1350,7 +1342,7 @@ class Speech2TextTransducer:
|
|||||||
asr_model.to(dtype=getattr(torch, dtype)).eval()
|
asr_model.to(dtype=getattr(torch, dtype)).eval()
|
||||||
|
|
||||||
if lm_train_config is not None:
|
if lm_train_config is not None:
|
||||||
lm, lm_train_args = LMTask.build_model_from_file(
|
lm, lm_train_args = build_model_from_file(
|
||||||
lm_train_config, lm_file, device
|
lm_train_config, lm_file, device
|
||||||
)
|
)
|
||||||
lm_scorer = lm.lm
|
lm_scorer = lm.lm
|
||||||
@ -1638,9 +1630,8 @@ class Speech2TextSAASR:
|
|||||||
assert check_argument_types()
|
assert check_argument_types()
|
||||||
|
|
||||||
# 1. Build ASR model
|
# 1. Build ASR model
|
||||||
from funasr.tasks.sa_asr import ASRTask
|
|
||||||
scorers = {}
|
scorers = {}
|
||||||
asr_model, asr_train_args = ASRTask.build_model_from_file(
|
asr_model, asr_train_args = build_model_from_file(
|
||||||
asr_train_config, asr_model_file, cmvn_file, device
|
asr_train_config, asr_model_file, cmvn_file, device
|
||||||
)
|
)
|
||||||
frontend = None
|
frontend = None
|
||||||
@ -1667,7 +1658,7 @@ class Speech2TextSAASR:
|
|||||||
|
|
||||||
# 2. Build Language model
|
# 2. Build Language model
|
||||||
if lm_train_config is not None:
|
if lm_train_config is not None:
|
||||||
lm, lm_train_args = LMTask.build_model_from_file(
|
lm, lm_train_args = build_model_from_file(
|
||||||
lm_train_config, lm_file, None, device
|
lm_train_config, lm_file, None, device
|
||||||
)
|
)
|
||||||
scorers["lm"] = lm.lm
|
scorers["lm"] = lm.lm
|
||||||
|
|||||||
128
funasr/build_utils/build_model_from_file.py
Normal file
128
funasr/build_utils/build_model_from_file.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import yaml
|
||||||
|
from typeguard import check_argument_types
|
||||||
|
|
||||||
|
from funasr.build_utils.build_model import build_model
|
||||||
|
from funasr.models.base_model import FunASRModel
|
||||||
|
|
||||||
|
|
||||||
|
def build_model_from_file(
        config_file: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        device: str = "cpu",
        mode: str = "paraformer",
):
    """Build a model from its training config and load the saved weights.

    This method is used for inference or fine-tuning.

    Args:
        config_file: The yaml file saved when training. When omitted,
            ``config.yaml`` in the directory of ``model_file`` is used.
        model_file: The model file saved when training. A file whose name
            contains ``model.ckpt-`` or ``.bin`` goes through
            ``convert_tf2torch``; the converted state dict is cached next to
            it as a ``.pb`` file and that cache is read on later calls. Any
            other file is read directly with ``torch.load``.
        cmvn_file: When given, stored under the ``cmvn_file`` key of the
            loaded config before the model is built.
        device: Device type, "cpu", "cuda", or "cuda:N".
        mode: Forwarded to ``convert_tf2torch``; only consulted when a
            checkpoint actually has to be converted.

    Returns:
        A ``(model, args)`` tuple: the built model moved to ``device`` and
        the config as an ``argparse.Namespace``.

    Raises:
        AssertionError: If both ``config_file`` and ``model_file`` are None.
        RuntimeError: If the built model is not a ``FunASRModel``.

    """
    assert check_argument_types()
    if config_file is None:
        assert model_file is not None, (
            "The argument 'model_file' must be provided "
            "if the argument 'config_file' is not specified."
        )
        config_file = Path(model_file).parent / "config.yaml"
    else:
        config_file = Path(config_file)

    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)
    if cmvn_file is not None:
        args["cmvn_file"] = cmvn_file
    args = argparse.Namespace(**args)

    model = build_model(args)
    if not isinstance(model, FunASRModel):
        raise RuntimeError(
            f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
        )
    model.to(device)

    model_dict = dict()
    model_name_pth = None
    if model_file is not None:
        logging.info("model_file is {}".format(model_file))
        if device == "cuda":
            # Pin map_location to the concrete GPU index currently selected.
            device = f"cuda:{torch.cuda.current_device()}"
        model_dir = os.path.dirname(model_file)
        model_name = os.path.basename(model_file)
        if "model.ckpt-" in model_name or ".bin" in model_name:
            # Path of the cached, already-converted torch state dict.
            if ".bin" in model_name:
                model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
            else:
                model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
            if os.path.exists(model_name_pth):
                logging.info("model_file is load from pth: {}".format(model_name_pth))
                # NOTE: torch.load unpickles the file; only load checkpoints
                # from sources you trust.
                model_dict = torch.load(model_name_pth, map_location=device)
            else:
                model_dict = convert_tf2torch(model, model_file, mode)
        else:
            model_dict = torch.load(model_file, map_location=device)

    # Single load for every path; the converted-checkpoint path used to load
    # the same state dict twice.
    # NOTE(review): with model_file=None this still passes an empty state dict,
    # which torch's default strict loading rejects for a model that has
    # parameters -- confirm whether any caller relies on that.
    model.load_state_dict(model_dict)

    if model_name_pth is not None and not os.path.exists(model_name_pth):
        # First conversion of this checkpoint: cache the result for next time.
        torch.save(model_dict, model_name_pth)
        logging.info("model_file is saved to pth: {}".format(model_name_pth))

    return model, args
|
||||||
|
|
||||||
|
|
||||||
|
def convert_tf2torch(
        model,
        ckpt,
        mode,
):
    """Collect a torch state dict for ``model`` from the checkpoint ``ckpt``.

    The checkpoint variables are read with ``load_tf_dict`` and handed,
    together with ``model.state_dict()``, to the ``convert_tf2torch`` method
    of each relevant sub-module. The dicts they return are merged in call
    order into one result.

    Args:
        model: Model exposing the sub-modules listed below for ``mode``.
        ckpt: Checkpoint location passed to ``load_tf_dict``.
        mode: ``"uniasr"`` converts encoder, predictor, decoder, encoder2,
            predictor2, decoder2 and stride_conv. ``"paraformer"`` converts
            encoder, predictor and decoder, then calls
            ``model.clas_convert_tf2torch`` for the bias encoder.

    Returns:
        Dict holding the merged output of all per-module conversions.

    Raises:
        AssertionError: If ``mode`` is neither "paraformer" nor "uniasr".
    """
    assert mode == "paraformer" or mode == "uniasr"
    logging.info("start convert tf model to torch model")
    from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict

    tf_vars = load_tf_dict(ckpt)
    torch_vars = model.state_dict()

    is_uniasr = mode == "uniasr"
    if is_uniasr:
        part_names = (
            "encoder",
            "predictor",
            "decoder",
            "encoder2",
            "predictor2",
            "decoder2",
            "stride_conv",
        )
    else:
        part_names = ("encoder", "predictor", "decoder")

    converted = dict()
    # Order matters: later sub-modules overwrite earlier ones on a key clash.
    for part_name in part_names:
        part = getattr(model, part_name)
        converted.update(part.convert_tf2torch(tf_vars, torch_vars))

    if not is_uniasr:
        # bias_encoder
        converted.update(model.clas_convert_tf2torch(tf_vars, torch_vars))

    return converted
|
||||||
Loading…
Reference in New Issue
Block a user