speech2speech

This commit is contained in:
游雁 2024-09-14 14:23:36 +08:00
parent 89c1dd5f08
commit cd67bf6c73
3 changed files with 28 additions and 16 deletions

View File

@ -437,6 +437,7 @@ class OpenAIDatasetMultiTurn(torch.utils.data.Dataset):
"attention_mask": attention_mask,
"labels_ids": labels,
}
output["item"] = item
if len(fbank) > 0:
output["speech"] = fbank
output["speech_lengths"] = fbank_lens

View File

@ -77,7 +77,8 @@ class OpenAIIndexDSJsonl(torch.utils.data.Dataset): # torch.utils.data.Dataset
assistant.append([content, {"wav_path": wav_path}])
else:
assistant.append(content)
if len(system) == 0:
system = ["You are a helpful assistant."]
system = system * len(user)
contents_i = {
@ -113,6 +114,7 @@ class OpenAIIndexDSJsonl(torch.utils.data.Dataset): # torch.utils.data.Dataset
return 0
@tables.register("index_ds_classes", "OpenAIIndexDSJsonlForFullDuplexVAD")
class OpenAIIndexDSJsonlForFullDuplexVAD(torch.utils.data.Dataset): # torch.utils.data.Dataset
@ -155,18 +157,24 @@ class OpenAIIndexDSJsonlForFullDuplexVAD(torch.utils.data.Dataset): # torch.uti
data_dict = json.loads(line.strip())
data = data_dict["messages"]
for message in data:
if message['role'] == 'user':
message['content'] = message['content'].replace("/home/qinglin.zql/project/dataset/gpt-4o/vad", "/cpfs_speech/qinglin.zql/project/datasets/gpt-4o/vad")
message['content'] = message['content'].replace("/cpfs_speech/qinglin.zql/project/datasets/gpt-4o/vad/alimeeting/wav", "/cpfs_speech/qinglin.zql/project/datasets/gpt-4o/vad/alimeeting/alimeeting_vad/wav")
if message["role"] == "user":
message["content"] = message["content"].replace(
"/home/qinglin.zql/project/dataset/gpt-4o/vad",
"/cpfs_speech/qinglin.zql/project/datasets/gpt-4o/vad",
)
message["content"] = message["content"].replace(
"/cpfs_speech/qinglin.zql/project/datasets/gpt-4o/vad/alimeeting/wav",
"/cpfs_speech/qinglin.zql/project/datasets/gpt-4o/vad/alimeeting/alimeeting_vad/wav",
)
speech_length = data_dict.get("speech_length", -1) // 8
text_length = data_dict.get("text_length", 0)
task = data_dict['task']
last_total_time = data[-1]['end_time'] - data[-1]['start_time']
if task == 'turn-taking':
true_time_span = data[-1]['turn-taking-gap_time-added']
task = data_dict["task"]
last_total_time = data[-1]["end_time"] - data[-1]["start_time"]
if task == "turn-taking":
true_time_span = data[-1]["turn-taking-gap_time-added"]
elif task == "barge-in":
true_time_span = last_total_time - data[-1]['barge-in-0']
true_time_span = last_total_time - data[-1]["barge-in-0"]
if speech_length > self.max_source_length:
logging.info(
f"speech_length: {speech_length} > {self.max_source_length}, drop it"
@ -199,7 +207,7 @@ class OpenAIIndexDSJsonlForFullDuplexVAD(torch.utils.data.Dataset): # torch.uti
"source_len": speech_length + text_length,
"task": task,
"true_time_span": true_time_span,
"last_total_time": last_total_time
"last_total_time": last_total_time,
}
contents.append(contents_i)

View File

@ -651,13 +651,16 @@ class Trainer:
loss_dict["lr"] = scheduler.get_last_lr()[0]
loss_dict["batch_num_epoch"] = len(dataloader_train)
self.train_loss_avg = (
self.train_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
) / (batch_idx + 1)
loss_log = loss_dict["loss"].detach().cpu().item()
acc_log = loss_dict["stats"]["acc"].detach().cpu().item()
if torch.isnan(loss_dict["loss"]):
logging.warning(f"loss is {loss_log}, set is to 0.0.\nitem:\n{batch['item']}")
loss_log = 0.0
acc_log = 0.0
self.train_loss_avg = (self.train_loss_avg * batch_idx + loss_log) / (batch_idx + 1)
if "acc" in loss_dict["stats"]:
self.train_acc_avg = (
self.train_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item()
) / (batch_idx + 1)
self.train_acc_avg = (self.train_acc_avg * batch_idx + acc_log) / (batch_idx + 1)
self.log(loss_dict, tag="train")