mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
fix bug
This commit is contained in:
parent
e5be285347
commit
3d5e19792c
@ -16,12 +16,14 @@ jsonl = (
|
||||
with open(jsonl, "r") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
tearchforing = True
|
||||
for i, line in enumerate(lines):
|
||||
data_dict = json.loads(line.strip())
|
||||
data = data_dict["messages"]
|
||||
|
||||
res = model.generate(
|
||||
input=data,
|
||||
input=[data],
|
||||
tearchforing=tearchforing,
|
||||
cache={},
|
||||
)
|
||||
|
||||
|
||||
@ -568,6 +568,7 @@ class LLMASR2(nn.Module):
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
|
||||
for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
|
||||
@ -624,7 +625,7 @@ class LLMASR2(nn.Module):
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length]
|
||||
attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
|
||||
labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length]
|
||||
source_ids = torch.tensor(source_ids, dtype=torch.int64)
|
||||
source_ids = torch.tensor(source_ids_i, dtype=torch.int64)
|
||||
target_ids = torch.tensor(target_ids, dtype=torch.int64)
|
||||
|
||||
fbank = speech[0, :, :]
|
||||
@ -662,7 +663,7 @@ class LLMASR2(nn.Module):
|
||||
if kwargs.get("batch_size", 1) > 1:
|
||||
raise NotImplementedError("batch decoding is not implemented")
|
||||
|
||||
contents = self.data_template(data_in)
|
||||
contents = self.data_template(data_in[0])
|
||||
output = self.data_load_speech(contents, tokenizer, frontend, **kwargs)
|
||||
batch = to_device(output, kwargs["device"])
|
||||
|
||||
@ -676,7 +677,7 @@ class LLMASR2(nn.Module):
|
||||
|
||||
input_ids = batch["input_ids"]
|
||||
source_ids = batch["source_ids"]
|
||||
if kwargs.get("tearchforing", False):
|
||||
if not kwargs.get("tearchforing", False):
|
||||
input_ids = source_ids
|
||||
input_ids[input_ids < 0] = 0
|
||||
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
|
||||
@ -704,6 +705,23 @@ class LLMASR2(nn.Module):
|
||||
generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
|
||||
)[0]
|
||||
label = contents["assistant"][0]
|
||||
loss = None
|
||||
else:
|
||||
|
||||
labels_ids = batch["labels_ids"]
|
||||
labels_ids[labels_ids == -1] = -100
|
||||
attention_mask = batch.get("attention_mask", None)
|
||||
model_outputs = self.llm(
|
||||
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
|
||||
)
|
||||
|
||||
preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1]]
|
||||
response = tokenizer.batch_decode(
|
||||
preds,
|
||||
add_special_tokens=False,
|
||||
skip_special_tokens=kwargs.get("skip_special_tokens", True),
|
||||
)[0]
|
||||
loss = model_outputs.loss
|
||||
|
||||
ibest_writer = None
|
||||
if kwargs.get("output_dir") is not None:
|
||||
@ -713,10 +731,12 @@ class LLMASR2(nn.Module):
|
||||
|
||||
results = []
|
||||
result_i = {"key": key[0], "text": response, "label": label}
|
||||
if loss is not None:
|
||||
result_i["loss"] = loss
|
||||
results.append(result_i)
|
||||
|
||||
if ibest_writer is not None:
|
||||
ibest_writer["text"][key[0]] = text
|
||||
ibest_writer["text"][key[0]] = response
|
||||
ibest_writer["label"][key[0]] = label
|
||||
|
||||
return results, meta_data
|
||||
|
||||
Loading…
Reference in New Issue
Block a user