mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
batch
This commit is contained in:
parent
2779602177
commit
f57b68121a
@ -310,6 +310,7 @@ class SenseVoiceRWKV(nn.Module):
|
|||||||
speech_lengths = speech_lengths[:, 0]
|
speech_lengths = speech_lengths[:, 0]
|
||||||
|
|
||||||
batch_size, frames, _ = speech.shape
|
batch_size, frames, _ = speech.shape
|
||||||
|
_, text_tokens = text.shape
|
||||||
|
|
||||||
if self.activation_checkpoint:
|
if self.activation_checkpoint:
|
||||||
from torch.utils.checkpoint import checkpoint
|
from torch.utils.checkpoint import checkpoint
|
||||||
@ -331,6 +332,10 @@ class SenseVoiceRWKV(nn.Module):
|
|||||||
stats["batch_size_x_frames"] = frames * batch_size
|
stats["batch_size_x_frames"] = frames * batch_size
|
||||||
stats["batch_size_real_frames"] = speech_lengths.sum().item()
|
stats["batch_size_real_frames"] = speech_lengths.sum().item()
|
||||||
stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
|
stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
|
||||||
|
stats["batch_size_x_tokens"] = text_tokens * batch_size
|
||||||
|
stats["batch_size_real_tokens"] = text_lengths.sum().item()
|
||||||
|
stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
|
||||||
|
stats["batch_size_x_frames_plus_tokens"] = (text_tokens + frames) * batch_size
|
||||||
|
|
||||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||||
if self.length_normalized_loss:
|
if self.length_normalized_loss:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user