From 267e2d09e643d6d0fd45a0d874bbaa0927152def Mon Sep 17 00:00:00 2001 From: lzr265946 Date: Fri, 17 Feb 2023 11:12:03 +0800 Subject: [PATCH 1/2] support paraformer-large-contextual with vad and punc model --- funasr/bin/asr_inference_paraformer_vad.py | 6 ++ .../bin/asr_inference_paraformer_vad_punc.py | 85 +++++++++++++++++-- 2 files changed, 84 insertions(+), 7 deletions(-) diff --git a/funasr/bin/asr_inference_paraformer_vad.py b/funasr/bin/asr_inference_paraformer_vad.py index c01c6ba5e..78fc5f32f 100644 --- a/funasr/bin/asr_inference_paraformer_vad.py +++ b/funasr/bin/asr_inference_paraformer_vad.py @@ -167,6 +167,11 @@ def inference_modelscope( level=log_level, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) + + if param_dict is not None: + hotword_list_or_file = param_dict.get('hotword') + else: + hotword_list_or_file = None if ngpu >= 1 and torch.cuda.is_available(): device = "cuda" @@ -206,6 +211,7 @@ def inference_modelscope( ngram_weight=ngram_weight, penalty=penalty, nbest=nbest, + hotword_list_or_file=hotword_list_or_file, ) speech2text = Speech2Text(**speech2text_kwargs) text2punc = None diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py index f194830b2..408b5b968 100644 --- a/funasr/bin/asr_inference_paraformer_vad_punc.py +++ b/funasr/bin/asr_inference_paraformer_vad_punc.py @@ -5,6 +5,10 @@ import argparse import logging import sys import time +import os +import codecs +import tempfile +import requests from pathlib import Path from typing import Optional from typing import Sequence @@ -41,7 +45,7 @@ from funasr.models.frontend.wav_frontend import WavFrontend from funasr.tasks.vad import VADTask from funasr.utils.timestamp_tools import time_stamp_lfr6_pl from funasr.bin.punctuation_infer import Text2Punc -from funasr.models.e2e_asr_paraformer import BiCifParaformer +from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer header_colors = '\033[95m' end_colors = '\033[0m' @@ -79,6 +83,7 @@ class Speech2Text: penalty: float = 0.0, nbest: int = 1, frontend_conf: dict = None, + hotword_list_or_file: str = None, **kwargs, ): assert check_argument_types() @@ -169,6 +174,58 @@ class Speech2Text: self.asr_train_args = asr_train_args self.converter = converter self.tokenizer = tokenizer + + # 6. [Optional] Build hotword list from str, local file or url + # for None + if hotword_list_or_file is None: + self.hotword_list = None + # for text str input + elif not os.path.exists(hotword_list_or_file) and not hotword_list_or_file.startswith('http'): + logging.info("Attempting to parse hotwords as str...") + self.hotword_list = [] + hotword_str_list = [] + for hw in hotword_list_or_file.strip().split(): + hotword_str_list.append(hw) + self.hotword_list.append(self.converter.tokens2ids([i for i in hw])) + self.hotword_list.append([self.asr_model.sos]) + hotword_str_list.append('') + logging.info("Hotword list: {}.".format(hotword_str_list)) + # for local txt inputs + elif os.path.exists(hotword_list_or_file): + logging.info("Attempting to parse hotwords from local txt...") + self.hotword_list = [] + hotword_str_list = [] + with codecs.open(hotword_list_or_file, 'r') as fin: + for line in fin.readlines(): + hw = line.strip() + hotword_str_list.append(hw) + self.hotword_list.append(self.converter.tokens2ids([i for i in hw])) + self.hotword_list.append([self.asr_model.sos]) + hotword_str_list.append('') + logging.info("Initialized hotword list from file: {}, hotword list: {}." + .format(hotword_list_or_file, hotword_str_list)) + # for url, download and generate txt + else: + logging.info("Attempting to parse hotwords from url...") + work_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(work_dir): + os.makedirs(work_dir) + text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file)) + local_file = requests.get(hotword_list_or_file) + open(text_file_path, "wb").write(local_file.content) + hotword_list_or_file = text_file_path + self.hotword_list = [] + hotword_str_list = [] + with codecs.open(hotword_list_or_file, 'r') as fin: + for line in fin.readlines(): + hw = line.strip() + hotword_str_list.append(hw) + self.hotword_list.append(self.converter.tokens2ids([i for i in hw])) + self.hotword_list.append([self.asr_model.sos]) + hotword_str_list.append('') + logging.info("Initialized hotword list from file: {}, hotword list: {}." + .format(hotword_list_or_file, hotword_str_list)) + is_use_lm = lm_weight != 0.0 and lm_file is not None if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm: beam_search = None @@ -233,8 +290,15 @@ class Speech2Text: pre_token_length = pre_token_length.round().long() if torch.max(pre_token_length) < 1: return [] - decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length) - decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] + + if not isinstance(self.asr_model, ContextualParaformer): + if self.hotword_list: + logging.warning("Hotword is given but asr model is not a ContextualParaformer.") + decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length) + decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] + else: + decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list) + decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] if isinstance(self.asr_model, BiCifParaformer): _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len, @@ -282,10 +346,11 @@ class Speech2Text: else: text = None - - timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time) - results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor)) - + if isinstance(self.asr_model, BiCifParaformer): + timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time) + results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor)) + else: + results.append((text, token, token_int, enc_len_batch_total, lfr_factor)) # assert check_return_type(results) return results @@ -512,6 +577,11 @@ def inference_modelscope( format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) + if param_dict is not None: + hotword_list_or_file = param_dict.get('hotword') + else: + hotword_list_or_file = None + if ngpu >= 1 and torch.cuda.is_available(): device = "cuda" else: @@ -550,6 +620,7 @@ def inference_modelscope( ngram_weight=ngram_weight, penalty=penalty, nbest=nbest, + hotword_list_or_file=hotword_list_or_file, ) speech2text = Speech2Text(**speech2text_kwargs) text2punc = None From 861f2f26056aea26da1a9eedc2f0288e07254ec2 Mon Sep 17 00:00:00 2001 From: lzr265946 Date: Fri, 17 Feb 2023 11:12:26 +0800 Subject: [PATCH 2/2] fix demo in egs_modelscope vad --- egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.py b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.py index 6061413e5..71af48656 100644 --- a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.py +++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.py @@ -7,7 +7,7 @@ if __name__ == '__main__': inference_pipline = pipeline( task=Tasks.voice_activity_detection, model="damo/speech_fsmn_vad_zh-cn-8k-common", - model_revision='v1.1.1', + model_revision=None, output_dir='./output_dir', batch_size=1, )