Commit 5de8bfdcd8 by 游雁 on 2024-06-13 15:38:17 +08:00 (parent e42c693f0b)


@@ -21,6 +21,8 @@ from funasr.register import tables
 from funasr.train_utils.device_funcs import to_device
 import traceback
 
+dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
+
 
 @tables.register("model_classes", "LLMASR")
 class LLMASR(nn.Module):
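The new module-level `dtype_map` resolves the string dtype names used in configs into the corresponding `torch.dtype` objects, so the training and inference paths below can share one lookup instead of each defining their own. A minimal sketch of the lookup in isolation (illustration only, not FunASR code):

import torch

dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}

# Configs carry the dtype as a plain string; the map turns it into a torch.dtype.
t = torch.zeros(2, 2, dtype=dtype_map["bf16"])
assert t.dtype is torch.bfloat16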
@@ -449,6 +451,7 @@ class LLMASR2(nn.Module):
         model.eval()
         self.llm = model
         llm_dim = model.get_input_embeddings().weight.shape[-1]
+        self.llm_dtype = llm_conf.get("llm_dtype", "fp32")
 
         # adaptor
         adaptor_class = tables.adaptor_classes.get(audio_adaptor)
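Storing `self.llm_dtype` at construction time lets the forward pass decide later whether autocast is needed at all. A hedged sketch of the lookup, assuming `llm_conf` behaves like a plain dict (the real object may be a dict-like config; the snippet values are assumptions):

llm_conf = {"llm_dtype": "bf16"}  # assumed config snippet for illustration

# Falls back to full precision when the key is absent.
llm_dtype = llm_conf.get("llm_dtype", "fp32")
assert llm_dtype == "bf16"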
@@ -527,12 +530,15 @@ class LLMASR2(nn.Module):
                 batch_idx, :min_len, :
             ]
 
-        labels_ids[labels_ids == -1] = -100
-        attention_mask[attention_mask < 0] = 0
-        model_outputs = self.llm(
-            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
-        )
-        loss = model_outputs.loss
+        with torch.cuda.amp.autocast(
+            enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
+        ):
+            labels_ids[labels_ids == -1] = -100
+            attention_mask[attention_mask < 0] = 0
+            model_outputs = self.llm(
+                inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
+            )
+            loss = model_outputs.loss
 
         stats = {}
         with torch.no_grad():
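The training-side change wraps the LLM forward and loss computation in `torch.cuda.amp.autocast`, which is disabled for fp32 (making the wrapper a no-op) and otherwise runs autocast-eligible ops such as matmuls in the reduced dtype. A self-contained sketch of the same pattern, with a toy `nn.Linear` standing in for `self.llm` (model, tensors, and sizes here are assumptions for illustration):

import torch
import torch.nn as nn
import torch.nn.functional as F

dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
llm_dtype = "bf16"  # would come from self.llm_dtype in the real model

model = nn.Linear(16, 4).cuda()            # toy stand-in for the LLM
inputs = torch.randn(2, 16, device="cuda")
labels = torch.randint(0, 4, (2,), device="cuda")

# enabled=False for fp32 means the context manager changes nothing;
# otherwise eligible ops run in dtype_map[llm_dtype].
with torch.cuda.amp.autocast(enabled=llm_dtype != "fp32", dtype=dtype_map[llm_dtype]):
    logits = model(inputs)
    loss = F.cross_entropy(logits, labels)

Note that `enabled=llm_dtype != "fp32"` is equivalent to the more verbose `True if self.llm_dtype != "fp32" else False` in the diff.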
@@ -737,7 +743,6 @@ class LLMASR2(nn.Module):
         llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
         llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
-        dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
 
         with torch.cuda.amp.autocast(
             enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
         ):
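On the inference path, the local `dtype_map` becomes redundant once the module-level one exists, which is why this hunk deletes it. The surrounding lines also fix the resolution order of the dtype flags: `bf16` overrides `fp16`, which overrides whatever `llm_dtype` held before. A small sketch of that precedence, assuming the flags arrive via `kwargs` as shown (the initial `llm_dtype` default is an assumption, since its source is outside this hunk):

kwargs = {"fp16": True, "bf16": True}  # assumed inference flags

llm_dtype = "fp32"  # assumed starting default; set earlier in the real code
llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
assert llm_dtype == "bf16"  # bf16 wins when both flags are set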