From 6c3d68927daa4852b663617ba727618975c66e91 Mon Sep 17 00:00:00 2001
From: "haoneng.lhn"
Date: Fri, 24 Mar 2023 11:30:51 +0800
Subject: [PATCH] update paraformer streaming recipe

---
 .../infer.py                                   |  58 ++----
 .../bin/asr_inference_paraformer_streaming.py  | 196 +++++++-----------
 2 files changed, 99 insertions(+), 155 deletions(-)

diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
index c1c541ba8..2eb9cc8bf 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
@@ -1,57 +1,37 @@
+import os
+import logging
 import torch
-import torchaudio
+import soundfile
+
 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
-
 from modelscope.utils.logger import get_logger
-import logging
+
 logger = get_logger(log_level=logging.CRITICAL)
 logger.setLevel(logging.CRITICAL)
+os.environ["MODELSCOPE_CACHE"] = "./"
 
 inference_pipeline = pipeline(
     task=Tasks.auto_speech_recognition,
     model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online',
     model_revision='v1.0.2')
 
-waveform, sample_rate = torchaudio.load("waihu.wav")
-speech_length = waveform.shape[1]
-speech = waveform[0]
+model_dir = os.path.join(os.environ["MODELSCOPE_CACHE"], "damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online")
+speech, sample_rate = soundfile.read(os.path.join(model_dir, "example/asr_example.wav"))
+speech_length = speech.shape[0]
 
-cache_en = {"start_idx": 0, "pad_left": 0, "stride": 10, "pad_right": 5, "cif_hidden": None, "cif_alphas": None}
-cache_de = {"decode_fsmn": None}
-cache = {"encoder": cache_en, "decoder": cache_de}
-param_dict = {}
-param_dict["cache"] = cache
-
-first_chunk = True
-speech_buffer = speech
-speech_cache = []
+sample_offset = 0
+step = 4800 #300ms
+param_dict = {"cache": dict(), "is_final": False}
 final_result = ""
-while len(speech_buffer) >= 960:
-    if first_chunk:
-        if len(speech_buffer) >= 14400:
-            rec_result = inference_pipeline(audio_in=speech_buffer[0:14400], param_dict=param_dict)
-            speech_buffer = speech_buffer[4800:]
-        else:
-            cache_en["stride"] = len(speech_buffer) // 960
-            cache_en["pad_right"] = 0
-            rec_result = inference_pipeline(audio_in=speech_buffer, param_dict=param_dict)
-            speech_buffer = []
-        cache_en["start_idx"] = -5
-        first_chunk = False
-    else:
-        cache_en["start_idx"] += 10
-        if len(speech_buffer) >= 4800:
-            cache_en["pad_left"] = 5
-            rec_result = inference_pipeline(audio_in=speech_buffer[:19200], param_dict=param_dict)
-            speech_buffer = speech_buffer[9600:]
-        else:
-            cache_en["stride"] = len(speech_buffer) // 960
-            cache_en["pad_right"] = 0
-            rec_result = inference_pipeline(audio_in=speech_buffer, param_dict=param_dict)
-            speech_buffer = []
-    if len(rec_result) !=0 and rec_result['text'] != "sil":
+for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
+    if sample_offset + step >= speech_length - 1:
+        step = speech_length - sample_offset
+        param_dict["is_final"] = True
+    rec_result = inference_pipeline(audio_in=speech[sample_offset: sample_offset + step],
+                                    param_dict=param_dict)
+    if len(rec_result) != 0 and rec_result['text'] != "sil" and rec_result['text'] != "waiting_for_more_voice":
         final_result += rec_result['text']
         print(rec_result)
 
 print(final_result)
diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py
index 9b572a0af..907f190b3 100644
--- a/funasr/bin/asr_inference_paraformer_streaming.py
+++ b/funasr/bin/asr_inference_paraformer_streaming.py
@@ -544,11 +544,6 @@ def inference_modelscope(
     )
     export_mode = False
-    if param_dict is not None:
-        hotword_list_or_file = param_dict.get('hotword')
-        export_mode = param_dict.get("export_mode", False)
-    else:
-        hotword_list_or_file = None
 
     if ngpu >= 1 and torch.cuda.is_available():
         device = "cuda"
@@ -578,7 +573,6 @@
         ngram_weight=ngram_weight,
         penalty=penalty,
         nbest=nbest,
-        hotword_list_or_file=hotword_list_or_file,
     )
     if export_mode:
         speech2text = Speech2TextExport(**speech2text_kwargs)
@@ -594,123 +588,92 @@
         **kwargs,
     ):
-        hotword_list_or_file = None
-        if param_dict is not None:
-            hotword_list_or_file = param_dict.get('hotword')
-        if 'hotword' in kwargs:
-            hotword_list_or_file = kwargs['hotword']
-        if hotword_list_or_file is not None or 'hotword' in kwargs:
-            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
-
         # 3. Build data-iterator
         if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, torch.Tensor):
-                raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = ASRTask.build_streaming_iterator(
-            data_path_and_name_and_type,
-            dtype=dtype,
-            fs=fs,
-            batch_size=batch_size,
-            key_file=key_file,
-            num_workers=num_workers,
-            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
-        )
+            if isinstance(raw_inputs, np.ndarray):
+                raw_inputs = torch.tensor(raw_inputs)
 
-        if param_dict is not None:
-            use_timestamp = param_dict.get('use_timestamp', True)
-        else:
-            use_timestamp = True
-
-        forward_time_total = 0.0
-        length_total = 0.0
-        finish_count = 0
-        file_count = 1
-        cache = None
+        is_final = False
+        if param_dict is not None and "cache" in param_dict:
+            cache = param_dict["cache"]
+        if param_dict is not None and "is_final" in param_dict:
+            is_final = param_dict["is_final"]
 
         # 7 .Start for-loop
         # FIXME(kamo): The output format should be discussed about
         asr_result_list = []
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        if output_path is not None:
-            writer = DatadirWriter(output_path)
+        results = []
+        asr_result = ""
+        wait = True
+        if len(cache) == 0:
+            cache["encoder"] = {"start_idx": 0, "pad_left": 0, "stride": 10, "pad_right": 5, "cif_hidden": None, "cif_alphas": None}
+            cache_de = {"decode_fsmn": None}
+            cache["decoder"] = cache_de
+            cache["first_chunk"] = True
+            cache["speech"] = []
+            cache["chunk_index"] = 0
+            cache["speech_chunk"] = []
+
+        if raw_inputs is not None:
+            if len(cache["speech"]) == 0:
+                cache["speech"] = raw_inputs
+            else:
+                cache["speech"] = torch.cat([cache["speech"], raw_inputs], dim=0)
+            if len(cache["speech_chunk"]) == 0:
+                cache["speech_chunk"] = raw_inputs
+            else:
+                cache["speech_chunk"] = torch.cat([cache["speech_chunk"], raw_inputs], dim=0)
+            while len(cache["speech_chunk"]) >= 960:
+                if cache["first_chunk"]:
+                    if len(cache["speech_chunk"]) >= 14400:
+                        speech = torch.unsqueeze(cache["speech_chunk"][0:14400], axis=0)
+                        speech_length = torch.tensor([14400])
+                        results = speech2text(cache, speech, speech_length)
+                        cache["speech_chunk"] = cache["speech_chunk"][4800:]
+                        cache["first_chunk"] = False
+                        cache["encoder"]["start_idx"] = -5
+                        wait = False
+                    else:
+                        if is_final:
+                            cache["encoder"]["stride"] = len(cache["speech_chunk"]) // 960
+                            cache["encoder"]["pad_right"] = 0
+                            speech = torch.unsqueeze(cache["speech_chunk"], axis=0)
+                            speech_length = torch.tensor([len(cache["speech_chunk"])])
+                            results = speech2text(cache, speech, speech_length)
+                            cache["speech_chunk"] = []
+                            wait = False
+                        else:
+                            break
+                else:
+                    if len(cache["speech_chunk"]) >= 19200:
+                        cache["encoder"]["start_idx"] += 10
+                        cache["encoder"]["pad_left"] = 5
+                        speech = torch.unsqueeze(cache["speech_chunk"][:19200], axis=0)
+                        speech_length = torch.tensor([19200])
+                        results = speech2text(cache, speech, speech_length)
+                        cache["speech_chunk"] = cache["speech_chunk"][9600:]
+                        wait = False
+                    else:
+                        if is_final:
+                            cache["encoder"]["stride"] = len(cache["speech_chunk"]) // 960
+                            cache["encoder"]["pad_right"] = 0
+                            speech = torch.unsqueeze(cache["speech_chunk"], axis=0)
+                            speech_length = torch.tensor([len(cache["speech_chunk"])])
+                            results = speech2text(cache, speech, speech_length)
+                            cache["speech_chunk"] = []
+                            wait = False
+                        else:
+                            break
+
+            if len(results) >= 1:
+                asr_result += results[0][0]
+            if asr_result == "":
+                asr_result = "sil"
+            if wait:
+                asr_result = "waiting_for_more_voice"
+            item = {'key': "utt", 'value': asr_result}
+            asr_result_list.append(item)
         else:
-            writer = None
-        if param_dict is not None and "cache" in param_dict:
-            cache = param_dict["cache"]
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-            # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
-            logging.info("decoding, utt_id: {}".format(keys))
-            # N-best list of (text, token, token_int, hyp_object)
-
-            time_beg = time.time()
-            results = speech2text(cache=cache, **batch)
-            if len(results) < 1:
-                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
-                results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
-            time_end = time.time()
-            forward_time = time_end - time_beg
-            lfr_factor = results[0][-1]
-            length = results[0][-2]
-            forward_time_total += forward_time
-            length_total += length
-            rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time,
-                                                                                               100 * forward_time / (
-                                                                                                   length * lfr_factor))
-            logging.info(rtf_cur)
-
-            for batch_id in range(_bs):
-                result = [results[batch_id][:-2]]
-
-                key = keys[batch_id]
-                for n, result in zip(range(1, nbest + 1), result):
-                    text, token, token_int, hyp = result[0], result[1], result[2], result[3]
-                    time_stamp = None if len(result) < 5 else result[4]
-                    # Create a directory: outdir/{n}best_recog
-                    if writer is not None:
-                        ibest_writer = writer[f"{n}best_recog"]
-
-                        # Write the result to each file
-                        ibest_writer["token"][key] = " ".join(token)
-                        # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
-                        ibest_writer["score"][key] = str(hyp.score)
-                        ibest_writer["rtf"][key] = rtf_cur
-
-                    if text is not None:
-                        if use_timestamp and time_stamp is not None:
-                            postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
-                        else:
-                            postprocessed_result = postprocess_utils.sentence_postprocess(token)
-                        time_stamp_postprocessed = ""
-                        if len(postprocessed_result) == 3:
-                            text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
-                                                                                       postprocessed_result[1], \
-                                                                                       postprocessed_result[2]
-                        else:
-                            text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
-                        item = {'key': key, 'value': text_postprocessed}
-                        if time_stamp_postprocessed != "":
-                            item['time_stamp'] = time_stamp_postprocessed
-                        asr_result_list.append(item)
-                        finish_count += 1
-                        # asr_utils.print_progress(finish_count / file_count)
-                        if writer is not None:
-                            ibest_writer["text"][key] = text_postprocessed
-
-                    logging.info("decoding, utt: {}, predictions: {}".format(key, text))
-        rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total,
-                                                                                                           forward_time_total,
-                                                                                                           100 * forward_time_total / (
-                                                                                                               length_total * lfr_factor))
-        logging.info(rtf_avg)
-        if writer is not None:
-            ibest_writer["rtf"]["rtf_avf"] = rtf_avg
+            return []
         return asr_result_list
     return _forward
@@ -905,3 +868,4 @@ if __name__ == "__main__":
 
 # rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
 # print(rec_result)
+
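
Note on the streaming interface introduced by this patch (not part of the diff): the caller keeps a single param_dict alive across pipeline calls, audio is buffered inside param_dict["cache"], decoding starts once roughly 900 ms of 16 kHz audio (14400 samples) has accumulated and then continues over 1.2 s windows (19200 samples) advanced in 600 ms steps (9600 samples), and the strings "sil" and "waiting_for_more_voice" are placeholder outputs rather than recognized text. A minimal caller sketch under those assumptions; transcribe_stream and audio_chunks are hypothetical, only the param_dict / is_final handling comes from the recipe above:

    def transcribe_stream(inference_pipeline, audio_chunks):
        # audio_chunks: hypothetical list of 1-D float32 arrays of 16 kHz audio;
        # chunk sizes may vary, the recipe buffers internally via param_dict["cache"].
        param_dict = {"cache": dict(), "is_final": False}  # reused across all calls
        final_result = ""
        for i, chunk in enumerate(audio_chunks):
            param_dict["is_final"] = (i == len(audio_chunks) - 1)  # flush on the last chunk
            rec_result = inference_pipeline(audio_in=chunk, param_dict=param_dict)
            text = rec_result["text"] if len(rec_result) != 0 else ""
            if text not in ("", "sil", "waiting_for_more_voice"):
                final_result += text
        return final_result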