mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
fsmn_kws_mt finetune and inference adapt to right modelscope hub (#2113)
Co-authored-by: pengteng.spt <pengteng.spt@alibaba-inc.com>
This commit is contained in:
parent
5b53b8fb7b
commit
a8f0aad81d
@ -49,8 +49,6 @@ def convert_to_kaldi(
|
||||
copyfile(network_file, os.path.join(model_dir, 'origin.torch.pt'))
|
||||
|
||||
model = FsmnKWSMTConvert(
|
||||
vocab_size=configs['encoder_conf']['output_dim'],
|
||||
vocab_size2=configs['encoder_conf']['output_dim2'],
|
||||
encoder='FSMNMTConvert',
|
||||
encoder_conf=configs['encoder_conf'],
|
||||
ctc_conf=configs['ctc_conf'],
|
||||
@ -82,8 +80,6 @@ def convert_to_pytorch(
|
||||
model_name="convert.torch.pt"
|
||||
):
|
||||
model = FsmnKWSMTConvert(
|
||||
vocab_size=configs['encoder_conf']['output_dim'],
|
||||
vocab_size2=configs['encoder_conf']['output_dim2'],
|
||||
encoder='FSMNMTConvert',
|
||||
encoder_conf=configs['encoder_conf'],
|
||||
ctc_conf=configs['ctc_conf'],
|
||||
|
||||
@ -5,16 +5,16 @@ workspace=`pwd`
|
||||
local_path_root=${workspace}/modelscope_models
|
||||
mkdir -p ${local_path_root}
|
||||
|
||||
local_path=${local_path_root}/speech_charctc_kws_phone-xiaoyun
|
||||
local_path=${local_path_root}/speech_charctc_kws_phone-xiaoyun_mt
|
||||
if [ ! -d "$local_path" ]; then
|
||||
git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git ${local_path}
|
||||
git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun_mt.git ${local_path}
|
||||
fi
|
||||
|
||||
export PATH=${local_path}/runtime:$PATH
|
||||
export LD_LIBRARY_PATH=${local_path}/runtime:$LD_LIBRARY_PATH
|
||||
|
||||
# finetune config file
|
||||
config=./conf/fsmn_4e_l10r2_280_200_fdim40_t2602_t4.yaml
|
||||
config=./conf/fsmn_4e_l10r2_250_128_fdim80_t2599_t4.yaml
|
||||
|
||||
# finetune output checkpoint
|
||||
torch_nnet=exp/finetune_outputs/model.pt.avg10
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
from funasr import AutoModel
|
||||
|
||||
model = AutoModel(
|
||||
model="iic/speech_charctc_kws_phone-xiaoyun",
|
||||
model="iic/speech_charctc_kws_phone-xiaoyun_mt",
|
||||
keywords="小云小云",
|
||||
output_dir="./outputs/debug",
|
||||
device='cpu'
|
||||
|
||||
@ -27,19 +27,19 @@ test_sets="test"
|
||||
# model_name from model_hub, or model_dir in local path
|
||||
|
||||
## option 1, download model automatically, unsupported currently
|
||||
model_name_or_model_dir="iic/speech_charctc_kws_phone-xiaoyun"
|
||||
model_name_or_model_dir="iic/speech_charctc_kws_phone-xiaoyun_mt"
|
||||
|
||||
## option 2, download model by git
|
||||
local_path_root=${workspace}/modelscope_models
|
||||
model_name_or_model_dir=${local_path_root}/${model_name_or_model_dir}
|
||||
if [ ! -d $model_name_or_model_dir ]; then
|
||||
mkdir -p ${model_name_or_model_dir}
|
||||
git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git ${model_name_or_model_dir}
|
||||
git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun_mt.git ${model_name_or_model_dir}
|
||||
fi
|
||||
|
||||
config=fsmn_4e_l10r2_250_128_fdim80_t2599_t4.yaml
|
||||
token_list=${model_name_or_model_dir}/funasr/tokens_2599.txt
|
||||
token_list2=${model_name_or_model_dir}/funasr/tokens_xiaoyun_char.txt
|
||||
token_list2=${model_name_or_model_dir}/funasr/tokens_xiaoyun.txt
|
||||
lexicon_list=${model_name_or_model_dir}/funasr/lexicon.txt
|
||||
cmvn_file=${model_name_or_model_dir}/funasr/am.mvn.dim80_l2r2
|
||||
init_param="${model_name_or_model_dir}/funasr/basetrain_fsmn_4e_l10r2_250_128_fdim80_t2599.pt"
|
||||
@ -141,10 +141,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
--config-path="${output_dir}" \
|
||||
--config-name="config.yaml" \
|
||||
++init_param="${output_dir}/${inference_checkpoint}" \
|
||||
++tokenizer_conf.token_list="${token_list}" \
|
||||
++tokenizer_conf.seg_dict="${lexicon_list}" \
|
||||
++tokenizer2_conf.token_list="${token_list2}" \
|
||||
++tokenizer2_conf.seg_dict="${lexicon_list}" \
|
||||
++token_lists='['''${token_list}''', '''${token_list2}''']' \
|
||||
++seg_dicts='['''${lexicon_list}''', '''${lexicon_list}''']' \
|
||||
++frontend_conf.cmvn_file="${cmvn_file}" \
|
||||
++keywords="\"$keywords_string"\" \
|
||||
++input="${_logdir}/keys.${JOB}.scp" \
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
|
||||
# method1, inference from model hub
|
||||
|
||||
model="iic/speech_charctc_kws_phone-xiaoyun"
|
||||
model="iic/speech_charctc_kws_phone-xiaoyun_mt"
|
||||
|
||||
# for more input type, please ref to readme.md
|
||||
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
|
||||
|
||||
@ -13,14 +13,14 @@ workspace=`pwd`
|
||||
# download model
|
||||
local_path_root=${workspace}/modelscope_models
|
||||
mkdir -p ${local_path_root}
|
||||
local_path=${local_path_root}/speech_charctc_kws_phone-xiaoyun
|
||||
git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git ${local_path}
|
||||
local_path=${local_path_root}/speech_charctc_kws_phone-xiaoyun_mt
|
||||
git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun_mt.git ${local_path}
|
||||
|
||||
device="cuda:0" # "cuda:0" for gpu0, "cuda:1" for gpu1, "cpu"
|
||||
|
||||
config="inference_fsmn_4e_l10r2_280_200_fdim40_t2602_t4.yaml"
|
||||
tokens="${local_path}/funasr/tokens_2602.txt"
|
||||
tokens2="${local_path}/funasr/tokens_xiaoyun_char.txt"
|
||||
tokens2="${local_path}/funasr/tokens_xiaoyun.txt"
|
||||
seg_dict="${local_path}/funasr/lexicon.txt"
|
||||
init_param="${local_path}/funasr/finetune_fsmn_4e_l10r2_280_200_fdim40_t2602_t4_xiaoyun_xiaoyun.pt"
|
||||
cmvn_file="${local_path}/funasr/am.mvn.dim40_l4r4"
|
||||
@ -34,10 +34,8 @@ python -m funasr.bin.inference \
|
||||
--config-name "${config}" \
|
||||
++init_param="${init_param}" \
|
||||
++frontend_conf.cmvn_file="${cmvn_file}" \
|
||||
++tokenizer_conf.token_list="${tokens}" \
|
||||
++tokenizer_conf.seg_dict="${seg_dict}" \
|
||||
++tokenizer2_conf.token_list="${tokens2}" \
|
||||
++tokenizer2_conf.seg_dict="${seg_dict}" \
|
||||
++token_lists='['''${tokens}''', '''${tokens2}''']' \
|
||||
++seg_dicts='['''${seg_dict}''', '''${seg_dict}''']' \
|
||||
++input="${input}" \
|
||||
++output_dir="${output_dir}" \
|
||||
++device="${device}" \
|
||||
|
||||
@ -199,6 +199,7 @@ class AutoModel:
|
||||
tokenizers_build = []
|
||||
vocab_sizes = []
|
||||
token_lists = []
|
||||
|
||||
### === only for kws ===
|
||||
token_list_files = kwargs.get("token_lists", [])
|
||||
seg_dicts = kwargs.get("seg_dicts", [])
|
||||
@ -213,9 +214,9 @@ class AutoModel:
|
||||
|
||||
### === only for kws ===
|
||||
if len(token_list_files) > 1:
|
||||
tokenizer_conf.token_list = token_list_files[i]
|
||||
tokenizer_conf["token_list"] = token_list_files[i]
|
||||
if len(seg_dicts) > 1:
|
||||
tokenizer_conf.seg_dict = seg_dicts[i]
|
||||
tokenizer_conf["seg_dict"] = seg_dicts[i]
|
||||
### === only for kws ===
|
||||
|
||||
tokenizer = tokenizer_class(**tokenizer_conf)
|
||||
|
||||
@ -162,6 +162,7 @@ def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg={}):
|
||||
if isinstance(file_path_metas, dict):
|
||||
if isinstance(cfg, list):
|
||||
cfg.append({})
|
||||
|
||||
for k, v in file_path_metas.items():
|
||||
if isinstance(v, str):
|
||||
p = os.path.join(model_or_path, v)
|
||||
@ -186,8 +187,8 @@ def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg={}):
|
||||
if k not in cfg:
|
||||
cfg[k] = []
|
||||
if isinstance(vv, str):
|
||||
p = os.path.join(model_or_path, v)
|
||||
file_path_metas[i] = p
|
||||
p = os.path.join(model_or_path, vv)
|
||||
# file_path_metas[i] = p
|
||||
if os.path.exists(p):
|
||||
if isinstance(cfg[k], dict):
|
||||
cfg[k] = p
|
||||
|
||||
@ -41,8 +41,7 @@ class FsmnKWSMT(torch.nn.Module):
|
||||
encoder_conf: Optional[Dict] = None,
|
||||
ctc_conf: Optional[Dict] = None,
|
||||
input_size: int = 360,
|
||||
vocab_size: int = -1,
|
||||
vocab_size2: int = -1,
|
||||
vocab_size: list = [],
|
||||
ignore_id: int = -1,
|
||||
blank_id: int = 0,
|
||||
**kwargs,
|
||||
@ -63,14 +62,13 @@ class FsmnKWSMT(torch.nn.Module):
|
||||
encoder_output_size2 = encoder.output_size2()
|
||||
|
||||
ctc = CTC(
|
||||
odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
|
||||
odim=vocab_size[0], encoder_output_size=encoder_output_size, **ctc_conf
|
||||
)
|
||||
ctc2 = CTC(
|
||||
odim=vocab_size2, encoder_output_size=encoder_output_size2, **ctc_conf
|
||||
odim=vocab_size[1], encoder_output_size=encoder_output_size2, **ctc_conf
|
||||
)
|
||||
|
||||
self.blank_id = blank_id
|
||||
self.vocab_size = vocab_size
|
||||
self.ignore_id = ignore_id
|
||||
|
||||
# self.frontend = frontend
|
||||
@ -208,7 +206,6 @@ class FsmnKWSMT(torch.nn.Module):
|
||||
data_lengths=None,
|
||||
key: list=None,
|
||||
tokenizer=None,
|
||||
tokenizer2=None,
|
||||
frontend=None,
|
||||
**kwargs,
|
||||
):
|
||||
@ -217,14 +214,14 @@ class FsmnKWSMT(torch.nn.Module):
|
||||
self.kws_decoder = KwsCtcPrefixDecoder(
|
||||
ctc=self.ctc,
|
||||
keywords=keywords,
|
||||
token_list=tokenizer.token_list,
|
||||
seg_dict=tokenizer.seg_dict,
|
||||
token_list=tokenizer[0].token_list,
|
||||
seg_dict=tokenizer[0].seg_dict,
|
||||
)
|
||||
self.kws_decoder2 = KwsCtcPrefixDecoder(
|
||||
ctc=self.ctc2,
|
||||
keywords=keywords,
|
||||
token_list=tokenizer2.token_list,
|
||||
seg_dict=tokenizer2.seg_dict,
|
||||
token_list=tokenizer[1].token_list,
|
||||
seg_dict=tokenizer[1].seg_dict,
|
||||
)
|
||||
|
||||
meta_data = {}
|
||||
@ -314,12 +311,9 @@ class FsmnKWSMTConvert(torch.nn.Module):
|
||||
self,
|
||||
encoder: str = None,
|
||||
encoder_conf: Optional[Dict] = None,
|
||||
ctc: str = None,
|
||||
ctc_conf: Optional[Dict] = None,
|
||||
ctc_weight: float = 1.0,
|
||||
input_size: int = 360,
|
||||
vocab_size: int = -1,
|
||||
vocab_size2: int = -1,
|
||||
blank_id: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
@ -328,18 +322,8 @@ class FsmnKWSMTConvert(torch.nn.Module):
|
||||
encoder_class = tables.encoder_classes.get(encoder)
|
||||
encoder = encoder_class(**encoder_conf)
|
||||
encoder_output_size = encoder.output_size()
|
||||
|
||||
if ctc_conf is None:
|
||||
ctc_conf = {}
|
||||
ctc = CTC(
|
||||
odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
|
||||
)
|
||||
|
||||
self.blank_id = blank_id
|
||||
self.vocab_size = vocab_size
|
||||
self.ctc_weight = ctc_weight
|
||||
self.encoder = encoder
|
||||
self.ctc = ctc
|
||||
|
||||
self.error_calculator = None
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user