add batch support for token extraction

This commit is contained in:
志浩 2024-09-24 17:33:41 +08:00
parent bc0608d380
commit 4f96a06d13
2 changed files with 5 additions and 2 deletions

View File

@@ -103,7 +103,7 @@ class WhisperFrontend(nn.Module):
feat = self.pad_or_trim(input[i], self.pad_samples)
else:
feat = input[i]
feat, feat_len = self.log_mel_spectrogram(feat[None, :], input_lengths[0])
feat, feat_len = self.log_mel_spectrogram(feat[None, :], input_lengths[i])
feats.append(feat[0])
feats_lens.append(feat_len)
feats_lens = torch.as_tensor(feats_lens)

View File

@@ -2023,8 +2023,11 @@ class SenseVoiceL(nn.Module):
)
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
if data_lengths is None:
data_lengths = [x.shape[0] for x in audio_sample_list]
speech, speech_lengths = extract_fbank(
audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend,
data_len=data_lengths
)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"