mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
speech2speech
This commit is contained in:
parent
89c1dd5f08
commit
cd67bf6c73
@ -437,6 +437,7 @@ class OpenAIDatasetMultiTurn(torch.utils.data.Dataset):
|
||||
"attention_mask": attention_mask,
|
||||
"labels_ids": labels,
|
||||
}
|
||||
output["item"] = item
|
||||
if len(fbank) > 0:
|
||||
output["speech"] = fbank
|
||||
output["speech_lengths"] = fbank_lens
|
||||
|
||||
@ -77,7 +77,8 @@ class OpenAIIndexDSJsonl(torch.utils.data.Dataset): # torch.utils.data.Dataset
|
||||
assistant.append([content, {"wav_path": wav_path}])
|
||||
else:
|
||||
assistant.append(content)
|
||||
|
||||
if len(system) == 0:
|
||||
system = ["You are a helpful assistant."]
|
||||
system = system * len(user)
|
||||
|
||||
contents_i = {
|
||||
@ -113,6 +114,7 @@ class OpenAIIndexDSJsonl(torch.utils.data.Dataset): # torch.utils.data.Dataset
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
@tables.register("index_ds_classes", "OpenAIIndexDSJsonlForFullDuplexVAD")
|
||||
class OpenAIIndexDSJsonlForFullDuplexVAD(torch.utils.data.Dataset): # torch.utils.data.Dataset
|
||||
|
||||
@ -155,18 +157,24 @@ class OpenAIIndexDSJsonlForFullDuplexVAD(torch.utils.data.Dataset): # torch.uti
|
||||
data_dict = json.loads(line.strip())
|
||||
data = data_dict["messages"]
|
||||
for message in data:
|
||||
if message['role'] == 'user':
|
||||
message['content'] = message['content'].replace("/home/qinglin.zql/project/dataset/gpt-4o/vad", "/cpfs_speech/qinglin.zql/project/datasets/gpt-4o/vad")
|
||||
message['content'] = message['content'].replace("/cpfs_speech/qinglin.zql/project/datasets/gpt-4o/vad/alimeeting/wav", "/cpfs_speech/qinglin.zql/project/datasets/gpt-4o/vad/alimeeting/alimeeting_vad/wav")
|
||||
if message["role"] == "user":
|
||||
message["content"] = message["content"].replace(
|
||||
"/home/qinglin.zql/project/dataset/gpt-4o/vad",
|
||||
"/cpfs_speech/qinglin.zql/project/datasets/gpt-4o/vad",
|
||||
)
|
||||
message["content"] = message["content"].replace(
|
||||
"/cpfs_speech/qinglin.zql/project/datasets/gpt-4o/vad/alimeeting/wav",
|
||||
"/cpfs_speech/qinglin.zql/project/datasets/gpt-4o/vad/alimeeting/alimeeting_vad/wav",
|
||||
)
|
||||
|
||||
speech_length = data_dict.get("speech_length", -1) // 8
|
||||
text_length = data_dict.get("text_length", 0)
|
||||
task = data_dict['task']
|
||||
last_total_time = data[-1]['end_time'] - data[-1]['start_time']
|
||||
if task == 'turn-taking':
|
||||
true_time_span = data[-1]['turn-taking-gap_time-added']
|
||||
task = data_dict["task"]
|
||||
last_total_time = data[-1]["end_time"] - data[-1]["start_time"]
|
||||
if task == "turn-taking":
|
||||
true_time_span = data[-1]["turn-taking-gap_time-added"]
|
||||
elif task == "barge-in":
|
||||
true_time_span = last_total_time - data[-1]['barge-in-0']
|
||||
true_time_span = last_total_time - data[-1]["barge-in-0"]
|
||||
if speech_length > self.max_source_length:
|
||||
logging.info(
|
||||
f"speech_length: {speech_length} > {self.max_source_length}, drop it"
|
||||
@ -199,7 +207,7 @@ class OpenAIIndexDSJsonlForFullDuplexVAD(torch.utils.data.Dataset): # torch.uti
|
||||
"source_len": speech_length + text_length,
|
||||
"task": task,
|
||||
"true_time_span": true_time_span,
|
||||
"last_total_time": last_total_time
|
||||
"last_total_time": last_total_time,
|
||||
}
|
||||
|
||||
contents.append(contents_i)
|
||||
|
||||
@ -651,13 +651,16 @@ class Trainer:
|
||||
loss_dict["lr"] = scheduler.get_last_lr()[0]
|
||||
loss_dict["batch_num_epoch"] = len(dataloader_train)
|
||||
|
||||
self.train_loss_avg = (
|
||||
self.train_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
|
||||
) / (batch_idx + 1)
|
||||
loss_log = loss_dict["loss"].detach().cpu().item()
|
||||
acc_log = loss_dict["stats"]["acc"].detach().cpu().item()
|
||||
if torch.isnan(loss_dict["loss"]):
|
||||
logging.warning(f"loss is {loss_log}, set is to 0.0.\nitem:\n{batch['item']}")
|
||||
loss_log = 0.0
|
||||
acc_log = 0.0
|
||||
|
||||
self.train_loss_avg = (self.train_loss_avg * batch_idx + loss_log) / (batch_idx + 1)
|
||||
if "acc" in loss_dict["stats"]:
|
||||
self.train_acc_avg = (
|
||||
self.train_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item()
|
||||
) / (batch_idx + 1)
|
||||
self.train_acc_avg = (self.train_acc_avg * batch_idx + acc_log) / (batch_idx + 1)
|
||||
|
||||
self.log(loss_dict, tag="train")
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user