sensevoice

This commit is contained in:
游雁 2024-06-19 22:45:29 +08:00
parent 6659755acf
commit 7e9bdc7037

View File

@ -335,6 +335,12 @@ class OpenAIDatasetMultiTurn(torch.utils.data.Dataset):
):
if i >= self.multiturn_num_max:
break
if len(input_ids) > self.max_token_length:
logging.info(
f"input_ids > max_token_length: {len(input_ids)}>{self.max_token_length}, {item}"
)
break
if i == 0:
source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
else:
@ -373,6 +379,11 @@ class OpenAIDatasetMultiTurn(torch.utils.data.Dataset):
frontend=self.frontend,
is_final=True,
) # speech: [b, T, d]
if speech_lengths > self.max_source_length:
logging.info(
f"speech_lengths > max_source_length: {speech_lengths}>{self.max_source_length}, {item}"
)
badcase_flag = True
if self.permute:
speech = speech.permute(0, 2, 1)
# if speech_lengths > self.batch_size:
@ -400,18 +411,9 @@ class OpenAIDatasetMultiTurn(torch.utils.data.Dataset):
fbank_mask += fbank_mask_i
fbank_lens.append(speech_lengths)
if len(input_ids) > self.max_token_length:
logging.info(
f"input_ids > max_token_length: {len(input_ids)}>{self.max_token_length}, {item}"
)
badcase_flag = True
if speech_lengths > self.max_source_length:
logging.info(
f"speech_lengths > max_source_length: {speech_lengths}>{self.max_source_length}, {item}"
)
badcase_flag = True
if badcase_flag:
continue
input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length]
attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length]