mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge pull request #673 from alibaba-damo-academy/dev_clas
contextual paraformer related update: infer and finetune
This commit is contained in:
commit
c20c871e9f
@ -3,6 +3,10 @@ from modelscope.utils.constant import Tasks
|
||||
|
||||
param_dict = dict()
|
||||
param_dict['hotword'] = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/hotword.txt"
|
||||
param_dict['clas_scale'] = 1.00 # 1.50 # set it larger if you want high recall (sacrifice general accuracy)
|
||||
# 13% relative recall raise over internal hotword test set (45%->51%)
|
||||
# CER might raise when utterance contains no hotword
|
||||
|
||||
inference_pipeline = pipeline(
|
||||
task=Tasks.auto_speech_recognition,
|
||||
model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
|
||||
|
||||
@ -280,6 +280,7 @@ class Speech2TextParaformer:
|
||||
nbest: int = 1,
|
||||
frontend_conf: dict = None,
|
||||
hotword_list_or_file: str = None,
|
||||
clas_scale: float = 1.0,
|
||||
decoding_ind: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
@ -376,6 +377,7 @@ class Speech2TextParaformer:
|
||||
# 6. [Optional] Build hotword list from str, local file or url
|
||||
self.hotword_list = None
|
||||
self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
|
||||
self.clas_scale = clas_scale
|
||||
|
||||
is_use_lm = lm_weight != 0.0 and lm_file is not None
|
||||
if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
|
||||
@ -439,16 +441,20 @@ class Speech2TextParaformer:
|
||||
pre_token_length = pre_token_length.round().long()
|
||||
if torch.max(pre_token_length) < 1:
|
||||
return []
|
||||
if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model,
|
||||
NeatContextualParaformer):
|
||||
if not isinstance(self.asr_model, ContextualParaformer) and \
|
||||
not isinstance(self.asr_model, NeatContextualParaformer):
|
||||
if self.hotword_list:
|
||||
logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
|
||||
decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
|
||||
pre_token_length)
|
||||
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
|
||||
else:
|
||||
decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
|
||||
pre_token_length, hw_list=self.hotword_list)
|
||||
decoder_outs = self.asr_model.cal_decoder_with_predictor(enc,
|
||||
enc_len,
|
||||
pre_acoustic_embeds,
|
||||
pre_token_length,
|
||||
hw_list=self.hotword_list,
|
||||
clas_scale=self.clas_scale)
|
||||
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
|
||||
|
||||
if isinstance(self.asr_model, BiCifParaformer):
|
||||
|
||||
@ -257,6 +257,7 @@ def inference_paraformer(
|
||||
export_mode = param_dict.get("export_mode", False)
|
||||
else:
|
||||
hotword_list_or_file = None
|
||||
clas_scale = param_dict.get('clas_scale', 1.0)
|
||||
|
||||
if kwargs.get("device", None) == "cpu":
|
||||
ngpu = 0
|
||||
@ -289,6 +290,7 @@ def inference_paraformer(
|
||||
penalty=penalty,
|
||||
nbest=nbest,
|
||||
hotword_list_or_file=hotword_list_or_file,
|
||||
clas_scale=clas_scale,
|
||||
)
|
||||
|
||||
speech2text = Speech2TextParaformer(**speech2text_kwargs)
|
||||
|
||||
@ -85,7 +85,9 @@ def build_trainer(modelscope_dict,
|
||||
finetune_configs = yaml.safe_load(f)
|
||||
# set data_types
|
||||
if dataset_type == "large":
|
||||
finetune_configs["dataset_conf"]["data_types"] = "sound,text"
|
||||
# finetune_configs["dataset_conf"]["data_types"] = "sound,text"
|
||||
if 'data_types' not in finetune_configs['dataset_conf']:
|
||||
finetune_configs["dataset_conf"]["data_types"] = "sound,text"
|
||||
finetune_configs = update_dct(configs, finetune_configs)
|
||||
for key, value in finetune_configs.items():
|
||||
if hasattr(args, key):
|
||||
|
||||
@ -202,14 +202,7 @@ def Dataset(data_list_file,
|
||||
data_types = conf.get("data_types", "kaldi_ark,text")
|
||||
|
||||
pre_hwfile = conf.get("pre_hwlist", None)
|
||||
pre_prob = conf.get("pre_prob", 0) # unused yet
|
||||
|
||||
hw_config = {"sample_rate": conf.get("sample_rate", 0.6),
|
||||
"double_rate": conf.get("double_rate", 0.1),
|
||||
"hotword_min_length": conf.get("hotword_min_length", 2),
|
||||
"hotword_max_length": conf.get("hotword_max_length", 8),
|
||||
"pre_prob": conf.get("pre_prob", 0.0)}
|
||||
|
||||
# pre_prob = conf.get("pre_prob", 0) # unused yet
|
||||
if pre_hwfile is not None:
|
||||
pre_hwlist = []
|
||||
with open(pre_hwfile, 'r') as fin:
|
||||
@ -218,6 +211,15 @@ def Dataset(data_list_file,
|
||||
else:
|
||||
pre_hwlist = None
|
||||
|
||||
hw_config = {"sample_rate": conf.get("sample_rate", 0.6),
|
||||
"double_rate": conf.get("double_rate", 0.1),
|
||||
"hotword_min_length": conf.get("hotword_min_length", 2),
|
||||
"hotword_max_length": conf.get("hotword_max_length", 8),
|
||||
"pre_prob": conf.get("pre_prob", 0.0),
|
||||
"pre_hwlist": pre_hwlist}
|
||||
|
||||
|
||||
|
||||
dataset = AudioDataset(scp_lists,
|
||||
data_names,
|
||||
data_types,
|
||||
|
||||
@ -6,7 +6,8 @@ def sample_hotword(length,
|
||||
sample_rate,
|
||||
double_rate,
|
||||
pre_prob,
|
||||
pre_index=None):
|
||||
pre_index=None,
|
||||
pre_hwlist=None):
|
||||
if length < hotword_min_length:
|
||||
return [-1]
|
||||
if random.random() < sample_rate:
|
||||
|
||||
@ -54,7 +54,17 @@ def tokenize(data,
|
||||
|
||||
length = len(text)
|
||||
if 'hw_tag' in data:
|
||||
hotword_indxs = sample_hotword(length, **hw_config)
|
||||
if hw_config['pre_hwlist'] is not None and hw_config['pre_prob'] > 0:
|
||||
# enable preset hotword detect in sampling
|
||||
pre_index = None
|
||||
for hw in hw_config['pre_hwlist']:
|
||||
hw = " ".join(seg_tokenize(hw, seg_dict))
|
||||
_find = " ".join(text).find(hw)
|
||||
if _find != -1:
|
||||
# _find = text[:_find].count(" ") # bpe sometimes
|
||||
pre_index = [_find, _find + max(hw.count(" "), 1)]
|
||||
break
|
||||
hotword_indxs = sample_hotword(length, **hw_config, pre_index=pre_index)
|
||||
data['hotword_indxs'] = hotword_indxs
|
||||
del data['hw_tag']
|
||||
for i in range(length):
|
||||
|
||||
@ -244,6 +244,7 @@ class ContextualParaformerDecoder(ParaformerSANMDecoder):
|
||||
ys_in_pad: torch.Tensor,
|
||||
ys_in_lens: torch.Tensor,
|
||||
contextual_info: torch.Tensor,
|
||||
clas_scale: float = 1.0,
|
||||
return_hidden: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward decoder.
|
||||
@ -283,7 +284,7 @@ class ContextualParaformerDecoder(ParaformerSANMDecoder):
|
||||
cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)
|
||||
|
||||
if self.bias_output is not None:
|
||||
x = torch.cat([x_src_attn, cx], dim=2)
|
||||
x = torch.cat([x_src_attn, cx*clas_scale], dim=2)
|
||||
x = self.bias_output(x.transpose(1, 2)).transpose(1, 2) # 2D -> D
|
||||
x = x_self_attn + self.dropout(x)
|
||||
|
||||
|
||||
@ -341,7 +341,7 @@ class NeatContextualParaformer(Paraformer):
|
||||
input_mask_expand_dim, 0)
|
||||
return sematic_embeds * tgt_mask, decoder_out * tgt_mask
|
||||
|
||||
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None):
|
||||
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None, clas_scale=1.0):
|
||||
if hw_list is None:
|
||||
hw_list = [torch.Tensor([1]).long().to(encoder_out.device)] # empty hotword list
|
||||
hw_list_pad = pad_list(hw_list, 0)
|
||||
@ -363,7 +363,7 @@ class NeatContextualParaformer(Paraformer):
|
||||
hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
|
||||
|
||||
decoder_outs = self.decoder(
|
||||
encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed
|
||||
encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
|
||||
)
|
||||
decoder_out = decoder_outs[0]
|
||||
decoder_out = torch.log_softmax(decoder_out, dim=-1)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user