Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
fix loss normalization for ddp training
This commit is contained in:
parent 4152cf4615
commit acb9a0fec8
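Every hunk below applies the same one-line change: when length_normalized_loss is enabled, the normalization weight is computed as a plain Python int rather than as a tensor. A hedged summary of the pattern (identifiers as in the diff; the rationale in the comments is inferred from the commit message, not stated in the diff itself):

    # Before: the token count stays a torch tensor, cast to the dtype of the
    # incoming batch_size, so its type depends on surrounding state.
    batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)

    # After: a plain int, which force_gatherable (see below) turns into a
    # fresh tensor on loss.device, giving DDP a clean weight for averaging.
    batch_size = int((text_lengths + self.predictor_bias).sum())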
@@ -222,7 +222,7 @@ class ASRModel(FunASRModel):
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + 1).sum().type_as(batch_size)
+            batch_size = int((text_lengths + 1).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
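The comment in each hunk refers to a force_gatherable helper. A minimal sketch of such a helper, assuming ESPnet-style semantics (the name force_gatherable_sketch and the exact container handling are illustrative, not FunASR's actual implementation): recursively move tensors to the target device and wrap bare Python scalars, such as the int() batch_size above, so torch.nn.DataParallel can gather them.

    import torch

    def force_gatherable_sketch(data, device):
        # Recurse into containers so a (loss, stats, weight) tuple is handled whole.
        if isinstance(data, dict):
            return {k: force_gatherable_sketch(v, device) for k, v in data.items()}
        if isinstance(data, (list, tuple)):
            return type(data)(force_gatherable_sketch(v, device) for v in data)
        if isinstance(data, torch.Tensor):
            data = data.to(device)
            # DataParallel cannot gather 0-dim tensors; give them a length-1 dim.
            return data.unsqueeze(0) if data.dim() == 0 else data
        if isinstance(data, (int, float)):
            # A plain scalar weight becomes a fresh tensor on the loss device,
            # carrying no dtype or autograd history from elsewhere in the model.
            return torch.tensor([float(data)], device=device)
        return data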
@@ -233,7 +233,7 @@ class NeatContextualParaformer(Paraformer):
         stats["loss"] = torch.clone(loss.detach())
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
@@ -255,7 +255,7 @@ class Paraformer(FunASRModel):
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
@@ -867,7 +867,7 @@ class ParaformerOnline(Paraformer):
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
@@ -1494,7 +1494,7 @@ class ParaformerBert(Paraformer):
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
@@ -1765,7 +1765,7 @@ class BiCifParaformer(Paraformer):
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
@@ -1967,7 +1967,7 @@ class ContextualParaformer(Paraformer):
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
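For completeness, one way the returned (loss, stats, weight) triple can be consumed. This is a hypothetical single-process training step, not FunASR's actual trainer; model, batch, and optimizer are placeholders, and dividing the loss by weight is an assumption about the trainer contract:

    def train_step(model, batch, optimizer):
        loss, stats, weight = model(**batch)
        # Assumption: the trainer normalizes by weight. With
        # length_normalized_loss enabled, weight is the summed token count
        # (text_lengths + predictor_bias), so this averages the loss per token;
        # otherwise it is the utterance count.
        (loss / weight).backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        return stats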
@@ -2262,4 +2262,4 @@ class ContextualParaformer(Paraformer):
                                 "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                               var_dict_tf[name_tf].shape))
 
-            return var_dict_torch_update
+        return var_dict_torch_update