Merge branch 'dev_gzf_deepspeed' of github.com:alibaba-damo-academy/FunASR into dev_gzf_deepspeed

merge
This commit is contained in:
游雁 2024-06-26 17:03:59 +08:00
commit 33591ef555

View File

@@ -1247,7 +1247,8 @@ class LLMASR4(nn.Module):
return output
-def inference(
+def inference_prepare(
self,
data_in,
data_lengths=None,
@@ -1326,6 +1327,22 @@ class LLMASR4(nn.Module):
] = speech_token
speech_idx += 1
return inputs_embeds, contents, batch, source_ids, meta_data
def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
data_in, data_lengths, key, tokenizer, frontend, **kwargs
)
llm_dtype = kwargs.get("llm_dtype", "fp32")
if llm_dtype == "fp32":