mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
batch
This commit is contained in:
parent
2779602177
commit
f57b68121a
@ -310,6 +310,7 @@ class SenseVoiceRWKV(nn.Module):
|
||||
speech_lengths = speech_lengths[:, 0]
|
||||
|
||||
batch_size, frames, _ = speech.shape
|
||||
_, text_tokens = text.shape
|
||||
|
||||
if self.activation_checkpoint:
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
@ -331,6 +332,10 @@ class SenseVoiceRWKV(nn.Module):
|
||||
stats["batch_size_x_frames"] = frames * batch_size
|
||||
stats["batch_size_real_frames"] = speech_lengths.sum().item()
|
||||
stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
|
||||
stats["batch_size_x_tokens"] = text_tokens * batch_size
|
||||
stats["batch_size_real_tokens"] = text_lengths.sum().item()
|
||||
stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
|
||||
stats["batch_size_x_frames_plus_tokens"] = (text_tokens + frames) * batch_size
|
||||
|
||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||
if self.length_normalized_loss:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user