add inference prepare func (#1848)

This commit is contained in:
PerfeZ 2024-06-26 15:35:31 +08:00 committed by GitHub
parent d9bdd0eb67
commit e3eb52f8bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1241,7 +1241,8 @@ class LLMASR4(nn.Module):
return output
def inference(
def inference_prepare(
self,
data_in,
data_lengths=None,
@ -1319,6 +1320,22 @@ class LLMASR4(nn.Module):
] = speech_token
speech_idx += 1
return inputs_embeds, contents, batch, source_ids, meta_data
def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
data_in, data_lengths, key, tokenizer, frontend, **kwargs
)
llm_dtype = kwargs.get("llm_dtype", "fp32")
if llm_dtype == "fp32":