游雁 2024-07-26 11:53:21 +08:00
parent c607f7e8c4
commit 6c358b9a3c

@@ -1376,11 +1376,13 @@ class LLMASR4(nn.Module):
         label = contents["assistant"][-1]
         self.llm = self.llm.to(dtype_map[llm_dtype])
         inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
         llm_kwargs = kwargs.get("llm_kwargs", {})
         if not kwargs.get("tearchforing", False):
             generated_ids = self.llm.generate(
-                inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512)
+                inputs_embeds=inputs_embeds,
+                max_new_tokens=kwargs.get("max_length", 512),
+                **llm_kwargs,
             )
             # generated_ids = [
             #     output_ids[len(input_id) :]
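
For context, here is a minimal, self-contained sketch (not FunASR's own API; the checkpoint name and decoding options are placeholder assumptions) of how a caller-supplied `llm_kwargs` dict reaches Hugging Face `generate()`, which is what the hunk above wires up:

```python
# Sketch only: "Qwen/Qwen2-0.5B" is a hypothetical stand-in for the wrapped LLM.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "Qwen/Qwen2-0.5B"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)

# Embed the prompt ourselves, mirroring the inputs_embeds path in the diff.
inputs = tokenizer("Transcribe the audio:", return_tensors="pt")
inputs_embeds = model.get_input_embeddings()(inputs.input_ids)

# Caller-supplied decoding options, forwarded verbatim like **llm_kwargs above.
llm_kwargs = {"do_sample": False, "num_beams": 1, "repetition_penalty": 1.05}

generated_ids = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=inputs.attention_mask,
    max_new_tokens=512,  # matches the kwargs.get("max_length", 512) default
    **llm_kwargs,
)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
```

With only `inputs_embeds` (no `input_ids`), recent transformers versions return just the newly generated token ids, so no prompt-stripping is needed; that matches the commented-out slicing in the hunk above.
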
@@ -1398,7 +1400,10 @@ class LLMASR4(nn.Module):
             attention_mask = batch.get("attention_mask", None)
             # attention_mask = attention_mask.to(dtype_map[llm_dtype])
             model_outputs = self.llm(
-                inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                labels=labels_ids,
+                **llm_kwargs,
             )
             preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
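
And a sketch of the teacher-forcing branch, continuing from the snippet above: calling the model's `forward()` with `labels` makes it return a cross-entropy loss, and the same `llm_kwargs` dict can carry `forward()` options (`use_cache` here is just an assumed example):

```python
# Continues the sketch above (reuses model, inputs, inputs_embeds).
llm_kwargs = {"use_cache": False}  # assumed example of a forward() kwarg

model_outputs = model(
    inputs_embeds=inputs_embeds,
    attention_mask=inputs.attention_mask,
    labels=inputs.input_ids,  # teacher forcing: targets aligned with the inputs
    **llm_kwargs,
)
print(model_outputs.loss)

# Greedy per-position predictions, as in the diff's argmax over the logits
# (the diff additionally slices off the source/prompt positions).
preds = torch.argmax(model_outputs.logits, -1)
```
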