From 124c39a02c76f5cfafb44eef50f85f1f0c17d667 Mon Sep 17 00:00:00 2001
From: Xflick
Date: Thu, 8 Aug 2024 20:11:54 +0800
Subject: [PATCH] LLMASR4 streaming input

---
 funasr/models/llm_asr/model.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 91113d961..e4a5e77d9 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -1161,9 +1161,15 @@ class LLMASR4(nn.Module):
             if isinstance(user_prompt, (list, tuple)):
                 user_prompt, audio = user_prompt
             if i == 0:
-                source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+                if kwargs.get("infer_with_assistant_input", False):
+                    source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}"
+                else:
+                    source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
             else:
-                source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+                if kwargs.get("infer_with_assistant_input", False):
+                    source_input = f"<|im_start|>user\n{user_prompt}"
+                else:
+                    source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
             splits = pattern.split(source_input)
             source_ids = []
 