fp16

Commit 23008c7cac (parent ea85c483ad) in https://github.com/modelscope/FunASR.
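The hunk below threads an `llm_dtype` option through the inference path of `LLMASR2`: a `dtype_map` translates `"fp32"` (the default), `"fp16"`, or `"bf16"` into a torch dtype; the LLM weights, `inputs_embeds`, and `attention_mask` are cast to that dtype; and both the generate branch and the teacher-forcing branch are re-indented under a `torch.cuda.amp.autocast` context.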
```diff
@@ -706,37 +706,43 @@ class LLMASR2(nn.Module):
                 batch_idx, :min_len, :
             ]
 
-        label = contents["assistant"][0]
-        if not kwargs.get("tearchforing", False):
-            generated_ids = self.llm.generate(
-                inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512)
-            )
-            # generated_ids = [
-            #     output_ids[len(input_id) :]
-            #     for input_id, output_ids in zip(input_ids, generated_ids)
-            # ]
-            response = tokenizer.batch_decode(
-                generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
-            )[0]
+        llm_dtype = kwargs.get("llm_dtype", "fp32")
+        dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
+        with torch.cuda.amp.autocast(dtype=dtype_map[llm_dtype]):
+            label = contents["assistant"][0]
+            self.llm = self.llm.to(dtype_map[llm_dtype])
+            inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
+            attention_mask = attention_mask.to(dtype_map[llm_dtype])
+            if not kwargs.get("tearchforing", False):
+                generated_ids = self.llm.generate(
+                    inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512)
+                )
+                # generated_ids = [
+                #     output_ids[len(input_id) :]
+                #     for input_id, output_ids in zip(input_ids, generated_ids)
+                # ]
+                response = tokenizer.batch_decode(
+                    generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
+                )[0]
 
-            loss = None
-        else:
+                loss = None
+            else:
 
-            labels_ids = batch["labels_ids"]
-            labels_ids[labels_ids == -1] = -100
-            attention_mask = batch.get("attention_mask", None)
-            model_outputs = self.llm(
-                inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
-            )
+                labels_ids = batch["labels_ids"]
+                labels_ids[labels_ids == -1] = -100
+                attention_mask = batch.get("attention_mask", None)
+                model_outputs = self.llm(
+                    inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
+                )
 
-            preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
-            response = tokenizer.batch_decode(
-                preds,
-                add_special_tokens=False,
-                skip_special_tokens=kwargs.get("skip_special_tokens", True),
-            )[0]
-            loss = model_outputs.loss.item()
+                preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
+                response = tokenizer.batch_decode(
+                    preds,
+                    add_special_tokens=False,
+                    skip_special_tokens=kwargs.get("skip_special_tokens", True),
+                )[0]
+                loss = model_outputs.loss.item()
 
         ibest_writer = None
         if kwargs.get("output_dir") is not None:
```
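For readers adapting the pattern outside this file, here is a minimal, self-contained sketch of the same dtype switch. `DTYPE_MAP` and `autocast_forward` are illustrative names, not FunASR API, and the sketch uses the newer `torch.amp.autocast(device_type, ...)` spelling, since `torch.cuda.amp.autocast` is deprecated in recent PyTorch releases.

```python
import torch
import torch.nn as nn

# String flag -> torch dtype, mirroring the commit's dtype_map.
DTYPE_MAP = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}


@torch.no_grad()
def autocast_forward(model: nn.Module, inputs_embeds: torch.Tensor, llm_dtype: str = "fp32") -> torch.Tensor:
    dtype = DTYPE_MAP[llm_dtype]
    model = model.to(dtype)                   # cast weights, as the commit does with self.llm
    inputs_embeds = inputs_embeds.to(dtype)   # inputs must match the weight dtype
    device_type = inputs_embeds.device.type   # "cuda" or "cpu"
    # Disable autocast for fp32: it only targets reduced-precision dtypes.
    with torch.amp.autocast(device_type, dtype=dtype, enabled=dtype != torch.float32):
        return model(inputs_embeds)


if __name__ == "__main__":
    model = nn.Linear(8, 8)
    x = torch.randn(2, 8)
    print(autocast_forward(model, x, llm_dtype="bf16").dtype)  # torch.bfloat16
```

Because the commit also casts the weights with `self.llm.to(dtype_map[llm_dtype])`, the autocast context is largely belt-and-braces here: the fp16/bf16 math already happens in the weights' dtype. Note also that `"tearchforing"` is spelled that way in the source (apparently a typo for "teacher forcing"); callers must pass the same misspelled key to reach the teacher-forcing branch.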
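The teacher-forcing branch's `labels_ids[labels_ids == -1] = -100` exists because Hugging Face models (and `torch.nn.functional.cross_entropy`) skip target positions equal to -100 when computing the loss. A small sketch with made-up shapes:

```python
import torch
import torch.nn.functional as F

# (batch, seq, vocab) logits and (batch, seq) targets; shapes are made up.
logits = torch.randn(2, 5, 10)
labels = torch.randint(0, 10, (2, 5))
labels[:, -2:] = -1            # padded positions arrive marked with -1
labels[labels == -1] = -100    # remap to the ignore index, as the diff does
# cross_entropy expects (batch, vocab, seq); -100 positions contribute nothing.
loss = F.cross_entropy(logits.transpose(1, 2), labels, ignore_index=-100)
print(loss.item())
```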