diff --git a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py index baad18884..a6629cdd0 100644 --- a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py +++ b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py @@ -17,13 +17,10 @@ inference_pipeline = pipeline( ) vads = inputs.split("|") - -cache_out = [] rec_result_all="outputs:" +param_dict = {"cache": []} for vad in vads: - rec_result = inference_pipeline(text_in=vad, cache=cache_out) - #print(rec_result) - cache_out = rec_result['cache'] + rec_result = inference_pipeline(text_in=vad, param_dict=param_dict) rec_result_all += rec_result['text'] print(rec_result_all) diff --git a/funasr/bin/punctuation_infer_vadrealtime.py b/funasr/bin/punctuation_infer_vadrealtime.py index ea8c04120..ce1cee8b0 100644 --- a/funasr/bin/punctuation_infer_vadrealtime.py +++ b/funasr/bin/punctuation_infer_vadrealtime.py @@ -226,7 +226,7 @@ def inference_modelscope( ): results = [] split_size = 10 - + cache_in = param_dict["cache"] if raw_inputs != None: line = raw_inputs.strip() key = "demo" @@ -234,34 +234,12 @@ def inference_modelscope( item = {'key': key, 'value': ""} results.append(item) return results - result, _, cache = text2punc(line, cache) - item = {'key': key, 'value': result, 'cache': cache} + result, _, cache = text2punc(line, cache_in) + param_dict["cache"] = cache + item = {'key': key, 'value': result} results.append(item) return results - for inference_text, _, _ in data_path_and_name_and_type: - with open(inference_text, "r", encoding="utf-8") as fin: - for line in fin: - line = line.strip() - segs = line.split("\t") - if len(segs) != 2: - continue - key = segs[0] - if len(segs[1]) == 0: - continue - result, _ = text2punc(segs[1]) - item = {'key': key, 'value': result} - results.append(item) - output_path = output_dir_v2 if output_dir_v2 is not None else output_dir - if output_path != None: - output_file_name = "infer.out" - Path(output_path).mkdir(parents=True, exist_ok=True) - output_file_path = (Path(output_path) / output_file_name).absolute() - with open(output_file_path, "w", encoding="utf-8") as fout: - for item_i in results: - key_out = item_i["key"] - value_out = item_i["value"] - fout.write(f"{key_out}\t{value_out}\n") return results return _forward diff --git a/funasr/runtime/python/websocket/ASR_server.py b/funasr/runtime/python/websocket/ASR_server.py index ee1a7c668..17f73bb54 100644 --- a/funasr/runtime/python/websocket/ASR_server.py +++ b/funasr/runtime/python/websocket/ASR_server.py @@ -53,7 +53,7 @@ speek = Queue() inference_pipeline_vad = pipeline( task=Tasks.voice_activity_detection, model=args.vad_model, - model_revision="v1.2.0", + model_revision=None, output_dir=None, batch_size=1, mode='online', @@ -62,7 +62,7 @@ inference_pipeline_vad = pipeline( param_dict_vad = {'in_cache': dict(), "is_final": False} # asr -param_dict_asr = dict() +param_dict_asr = {} # param_dict["hotword"] = "小五 小五月" # 设置热词,用空格隔开 inference_pipeline_asr = pipeline( task=Tasks.auto_speech_recognition, @@ -71,10 +71,11 @@ inference_pipeline_asr = pipeline( ngpu=args.ngpu, ) -inference_pipline_punc = pipeline( +param_dict_punc = {'cache': list()} +inference_pipeline_punc = pipeline( task=Tasks.punctuation, model=args.punc_model, - model_revision="v1.0.1", + model_revision=None, ngpu=args.ngpu, ) @@ -116,13 +117,16 @@ def vad(data): # 推理 def asr(): # 推理 global inference_pipeline2 - global speek + global speek, param_dict_punc while True: while not speek.empty(): audio_in = speek.get() speek.task_done() - rec_result = inference_pipeline_asr(audio_in=audio_in) - print(rec_result) + if len(audio_in) > 0: + rec_result = inference_pipeline_asr(audio_in=audio_in) + if 'text' in rec_result: + rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=param_dict_punc) + print(rec_result["text"]) time.sleep(0.1) time.sleep(0.1)