This commit is contained in:
游雁 2024-07-19 10:40:13 +08:00
parent 6aacee8f9e
commit 99634e859f

View File

@ -367,28 +367,28 @@ class OpenAIDatasetMultiTurn(torch.utils.data.Dataset):
if sub_str.startswith("!"):
try:
data_src = load_audio_text_image_video(sub_str[1:], fs=self.fs)
speech, speech_lengths = extract_fbank(
data_src,
data_type=self.data_type,
frontend=self.frontend,
is_final=True,
) # speech: [b, T, d]
if speech_lengths > self.max_source_length:
logging.info(
f"speech_lengths > max_source_length: {speech_lengths}>{self.max_source_length}, {item}"
)
badcase_flag = True
if self.permute:
speech = speech.permute(0, 2, 1)
# if speech_lengths > self.batch_size:
# continue
except Exception as e:
logging.error(
f"Loading wav failed! {str(e)}, {traceback.format_exc()}"
)
badcase_flag = True
continue
speech, speech_lengths = extract_fbank(
data_src,
data_type=self.data_type,
frontend=self.frontend,
is_final=True,
) # speech: [b, T, d]
if speech_lengths > self.max_source_length:
logging.info(
f"speech_lengths > max_source_length: {speech_lengths}>{self.max_source_length}, {item}"
)
badcase_flag = True
if self.permute:
speech = speech.permute(0, 2, 1)
# if speech_lengths > self.batch_size:
# continue
olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2
olens = 1 + (olens - 3 + 2 * 1) // 2
fake_token_len_i = (olens - 1) // 2 + 1