From df00f5fc0b9f61068df74349f6e001640931efc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Sat, 8 Jun 2024 16:54:14 +0800 Subject: [PATCH] fix bug --- .../llm_asr/demo_speech2text.py | 30 ++++++++----------- funasr/models/llm_asr/model.py | 2 +- 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/examples/industrial_data_pretraining/llm_asr/demo_speech2text.py b/examples/industrial_data_pretraining/llm_asr/demo_speech2text.py index eb7e72f74..ed02373dd 100644 --- a/examples/industrial_data_pretraining/llm_asr/demo_speech2text.py +++ b/examples/industrial_data_pretraining/llm_asr/demo_speech2text.py @@ -6,29 +6,23 @@ from funasr import AutoModel model = AutoModel( - model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch", - vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", - vad_kwargs={"max_single_segment_time": 60000}, - punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", - # spk_model="iic/speech_campplus_sv_zh-cn_16k-common", + model="/nfs/beinian.lzr/workspace/GPT-4o/Exp/exp6/4m-8gpu/exp6_speech2text_0607_linear_ddp", ) -res = model.generate( - input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", - cache={}, +jsonl = ( + "/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text/TestData/aishell1_test_speech2text.jsonl" ) -print(res) +with open(jsonl, "r") as f: + lines = f.readlines() +for i, line in enumerate(lines): + data_dict = json.loads(line.strip()) + data = data_dict["messages"] -""" can not use currently -from funasr import AutoFrontend + res = model.generate( + input=data, + cache={}, + ) -frontend = AutoFrontend(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch") - -fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2) - -for batch_idx, fbank_dict in enumerate(fbanks): - res = model.generate(**fbank_dict) print(res) -""" diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 78d9340a8..ff70c3ca6 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -532,7 +532,7 @@ class LLMASR2(nn.Module): loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight - def data_template(self, data_in): + def data_template(self, data): system, user, assistant = [], [], [] for i, item in enumerate(data): role = item["role"]