Add a `clas_scale` option for contextual (hotword) Paraformer decoding.

The value is read from `param_dict` in `inference_paraformer`, stored on
`Speech2TextParaformer`, passed through `cal_decoder_with_predictor`, and used
in `ContextualParaformerDecoder` to multiply the bias-decoder output `cx`
before it is concatenated with `x_src_attn`. The default of 1.0 leaves the
previous behaviour unchanged.

NOTE(review): the whitespace of context/removed lines was reconstructed from a
flattened copy of this patch; verify it against the target tree before applying.

diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index a537a73fe..0ce8dd849 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -285,6 +285,7 @@ class Speech2TextParaformer:
             nbest: int = 1,
             frontend_conf: dict = None,
             hotword_list_or_file: str = None,
+            clas_scale: float = 1.0,
             decoding_ind: int = 0,
             **kwargs,
     ):
@@ -382,6 +383,7 @@ class Speech2TextParaformer:
         # 6. [Optional] Build hotword list from str, local file or url
         self.hotword_list = None
         self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
+        self.clas_scale = clas_scale
 
         is_use_lm = lm_weight != 0.0 and lm_file is not None
         if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
@@ -446,16 +448,20 @@ class Speech2TextParaformer:
         pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
             return []
-        if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model,
-                                                                                   NeatContextualParaformer):
+        if not isinstance(self.asr_model, ContextualParaformer) and \
+                not isinstance(self.asr_model, NeatContextualParaformer):
             if self.hotword_list:
                 logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
             decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
                                                                      pre_token_length)
             decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
         else:
-            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
-                                                                     pre_token_length, hw_list=self.hotword_list)
+            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc,
+                                                                     enc_len,
+                                                                     pre_acoustic_embeds,
+                                                                     pre_token_length,
+                                                                     hw_list=self.hotword_list,
+                                                                     clas_scale=self.clas_scale)
             decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
 
         if isinstance(self.asr_model, BiCifParaformer):
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 5d1b80488..026874e28 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -260,6 +260,7 @@ def inference_paraformer(
         export_mode = param_dict.get("export_mode", False)
     else:
         hotword_list_or_file = None
+    clas_scale = param_dict.get('clas_scale', 1.0) if param_dict is not None else 1.0
 
     if kwargs.get("device", None) == "cpu":
         ngpu = 0
@@ -292,6 +293,7 @@ def inference_paraformer(
         penalty=penalty,
         nbest=nbest,
         hotword_list_or_file=hotword_list_or_file,
+        clas_scale=clas_scale,
     )
 
     speech2text = Speech2TextParaformer(**speech2text_kwargs)
diff --git a/funasr/models/decoder/contextual_decoder.py b/funasr/models/decoder/contextual_decoder.py
index 78105ab31..18d486136 100644
--- a/funasr/models/decoder/contextual_decoder.py
+++ b/funasr/models/decoder/contextual_decoder.py
@@ -246,6 +246,7 @@ class ContextualParaformerDecoder(ParaformerSANMDecoder):
         ys_in_pad: torch.Tensor,
         ys_in_lens: torch.Tensor,
         contextual_info: torch.Tensor,
+        clas_scale: float = 1.0,
         return_hidden: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward decoder.
@@ -285,7 +286,7 @@ class ContextualParaformerDecoder(ParaformerSANMDecoder):
         cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)
 
         if self.bias_output is not None:
-            x = torch.cat([x_src_attn, cx], dim=2)
+            x = torch.cat([x_src_attn, cx * clas_scale], dim=2)
             x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
             x = x_self_attn + self.dropout(x)
 
diff --git a/funasr/models/e2e_asr_contextual_paraformer.py b/funasr/models/e2e_asr_contextual_paraformer.py
index cfb500815..b8ae951fc 100644
--- a/funasr/models/e2e_asr_contextual_paraformer.py
+++ b/funasr/models/e2e_asr_contextual_paraformer.py
@@ -343,7 +343,7 @@ class NeatContextualParaformer(Paraformer):
                 input_mask_expand_dim, 0)
         return sematic_embeds * tgt_mask, decoder_out * tgt_mask
 
-    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None):
+    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None, clas_scale=1.0):
         if hw_list is None:
             hw_list = [torch.Tensor([1]).long().to(encoder_out.device)]  # empty hotword list
             hw_list_pad = pad_list(hw_list, 0)
@@ -365,7 +365,7 @@ class NeatContextualParaformer(Paraformer):
             hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
 
         decoder_outs = self.decoder(
-            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed
+            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
         )
         decoder_out = decoder_outs[0]
         decoder_out = torch.log_softmax(decoder_out, dim=-1)