Преглед на файлове

fix loss normalization for ddp training

haoneng.lhn преди 2 години
родител
ревизия
acb9a0fec8
променени са 3 файла, в които са добавени 8 реда и са изтрити 8 реда
  1. 1 1
      funasr/models/e2e_asr.py
  2. 1 1
      funasr/models/e2e_asr_contextual_paraformer.py
  3. 6 6
      funasr/models/e2e_asr_paraformer.py

+ 1 - 1
funasr/models/e2e_asr.py

@@ -222,7 +222,7 @@ class ASRModel(FunASRModel):
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + 1).sum().type_as(batch_size)
+            batch_size = int((text_lengths + 1).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 

+ 1 - 1
funasr/models/e2e_asr_contextual_paraformer.py

@@ -233,7 +233,7 @@ class NeatContextualParaformer(Paraformer):
         stats["loss"] = torch.clone(loss.detach())
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
     

+ 6 - 6
funasr/models/e2e_asr_paraformer.py

@@ -255,7 +255,7 @@ class Paraformer(FunASRModel):
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
@@ -867,7 +867,7 @@ class ParaformerOnline(Paraformer):
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
@@ -1494,7 +1494,7 @@ class ParaformerBert(Paraformer):
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
@@ -1765,7 +1765,7 @@ class BiCifParaformer(Paraformer):
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
@@ -1967,7 +1967,7 @@ class ContextualParaformer(Paraformer):
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
@@ -2262,4 +2262,4 @@ class ContextualParaformer(Paraformer):
                     "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                   var_dict_tf[name_tf].shape))
 
-        return var_dict_torch_update
+        return var_dict_torch_update