Author: 游雁
Date:   2024-06-11 13:56:24 +08:00
parent 63e60cc43d
commit a8653d897d

@@ -717,9 +717,9 @@ class LLMASR2(nn.Module):
         dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
         with torch.cuda.amp.autocast(dtype=dtype_map[llm_dtype]):
             label = contents["assistant"][0]
-            self.llm = self.llm.to(dtype_map[llm_dtype])
-            inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
-            attention_mask = attention_mask.to(dtype_map[llm_dtype])
+            # self.llm = self.llm.to(dtype_map[llm_dtype])
+            # inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
             if not kwargs.get("tearchforing", False):
                 generated_ids = self.llm.generate(
@@ -739,6 +739,7 @@ class LLMASR2(nn.Module):
                 labels_ids = batch["labels_ids"]
                 labels_ids[labels_ids == -1] = -100
                 attention_mask = batch.get("attention_mask", None)
+                # attention_mask = attention_mask.to(dtype_map[llm_dtype])
                 model_outputs = self.llm(
                     inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
                 )
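
The commit removes the manual dtype casts and leaves precision handling to the enclosing autocast context. Below is a minimal sketch of that pattern, assuming a CUDA device; the nn.Linear toy model, the tensor shapes, and the fixed llm_dtype = "bf16" are illustrative stand-ins, not code from LLMASR2.

# Sketch (not from the commit): rely on torch.cuda.amp.autocast for low-precision
# execution instead of casting the model and inputs by hand with .to().
# The linear layer and shapes below are placeholders, not FunASR code.
import torch
import torch.nn as nn

dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
llm_dtype = "bf16"

model = nn.Linear(16, 16).cuda()                    # weights stay in fp32
inputs_embeds = torch.randn(2, 16, device="cuda")   # inputs stay in fp32

with torch.cuda.amp.autocast(dtype=dtype_map[llm_dtype]):
    # Autocast-eligible ops (e.g. the matmul inside nn.Linear) run in bf16 here,
    # so no explicit model or input casts are needed.
    out = model(inputs_embeds)

print(out.dtype)   # torch.bfloat16: computed under autocast while model.weight is still fp32

The commented-out lines had been casting the whole module and its inputs on every call; under autocast that is unnecessary, since the cast happens per op while the stored weights remain in fp32.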