mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
c607f7e8c4
commit
6c358b9a3c
@ -1376,11 +1376,13 @@ class LLMASR4(nn.Module):
|
||||
label = contents["assistant"][-1]
|
||||
self.llm = self.llm.to(dtype_map[llm_dtype])
|
||||
inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
|
||||
|
||||
llm_kwargs = kwargs.get("llm_kwargs", {})
|
||||
if not kwargs.get("tearchforing", False):
|
||||
|
||||
generated_ids = self.llm.generate(
|
||||
inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512)
|
||||
inputs_embeds=inputs_embeds,
|
||||
max_new_tokens=kwargs.get("max_length", 512),
|
||||
**llm_kwargs,
|
||||
)
|
||||
# generated_ids = [
|
||||
# output_ids[len(input_id) :]
|
||||
@ -1398,7 +1400,10 @@ class LLMASR4(nn.Module):
|
||||
attention_mask = batch.get("attention_mask", None)
|
||||
# attention_mask = attention_mask.to(dtype_map[llm_dtype])
|
||||
model_outputs = self.llm(
|
||||
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
labels=labels_ids,
|
||||
**llm_kwargs,
|
||||
)
|
||||
|
||||
preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user