Merge pull request #1160 from alibaba-damo-academy/dev_lhn

fix loss normalization for ddp training
hnluo 2023-12-08 16:22:44 +08:00 committed by GitHub
commit aab1fcd6b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 8 deletions
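All eight code changes are the same one-line fix: when length_normalized_loss is enabled, the normalization weight handed to force_gatherable becomes a plain Python int instead of a tensor produced with .type_as(batch_size). A minimal sketch of the before/after expressions, with hypothetical values chosen only for illustration:

import torch

text_lengths = torch.tensor([7, 5, 9])  # token count per utterance
predictor_bias = 1                      # e.g. one appended <eos> token

# Before: the weight stays a tensor (dtype copied from batch_size).
batch_size = torch.tensor(3)
weight_old = (text_lengths + predictor_bias).sum().type_as(batch_size)

# After: the weight is a plain scalar; force_gatherable then turns it
# into a fresh tensor on loss.device (per the code comment in the
# diffs below), keeping per-token normalization consistent under DDP.
weight_new = int((text_lengths + predictor_bias).sum())

print(weight_old, weight_new)  # tensor(24) 24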


@@ -222,7 +222,7 @@ class ASRModel(FunASRModel):
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + 1).sum().type_as(batch_size)
+            batch_size = int((text_lengths + 1).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
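The weight returned here is what the trainer later uses to normalize the gathered losses. That reduction is not part of this diff; the sketch below only illustrates the weighted-mean pattern used by ESPnet-style trainers that FunASR builds on, with hypothetical per-replica values:

import torch

def reduce_weighted_loss(loss, weight):
    # Weighted mean: each replica's mean loss counts in proportion
    # to the number of tokens (or utterances) it processed, so the
    # result matches a single-process average over all tokens.
    return (loss * weight.to(loss.dtype)).sum() / weight.sum()

losses = torch.tensor([2.0, 3.0])  # per-replica mean losses
weights = torch.tensor([24, 16])   # per-replica token counts
print(reduce_weighted_loss(losses, weights))  # tensor(2.4000)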


@@ -233,7 +233,7 @@ class NeatContextualParaformer(Paraformer):
         stats["loss"] = torch.clone(loss.detach())
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight


@@ -255,7 +255,7 @@ class Paraformer(FunASRModel):
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
@@ -867,7 +867,7 @@ class ParaformerOnline(Paraformer):
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
@@ -1494,7 +1494,7 @@ class ParaformerBert(Paraformer):
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
@@ -1765,7 +1765,7 @@ class BiCifParaformer(Paraformer):
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
@@ -1967,7 +1967,7 @@ class ContextualParaformer(Paraformer):
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
@@ -2262,4 +2262,4 @@ class ContextualParaformer(Paraformer):
                 "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                               var_dict_tf[name_tf].shape))
-        return var_dict_torch_update
+        return var_dict_torch_update