This commit is contained in:
游雁 2024-04-29 15:15:24 +08:00
parent 2779602177
commit f57b68121a

View File

@ -310,6 +310,7 @@ class SenseVoiceRWKV(nn.Module):
speech_lengths = speech_lengths[:, 0]
batch_size, frames, _ = speech.shape
_, text_tokens = text.shape
if self.activation_checkpoint:
from torch.utils.checkpoint import checkpoint
@ -331,6 +332,10 @@ class SenseVoiceRWKV(nn.Module):
stats["batch_size_x_frames"] = frames * batch_size
stats["batch_size_real_frames"] = speech_lengths.sum().item()
stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
stats["batch_size_x_tokens"] = text_tokens * batch_size
stats["batch_size_real_tokens"] = text_lengths.sum().item()
stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
stats["batch_size_x_frames_plus_tokens"] = (text_tokens + frames) * batch_size
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss: