|
@@ -255,7 +255,8 @@ class Paraformer(FunASRModel):
|
|
|
|
|
|
|
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
|
|
if self.length_normalized_loss:
|
|
if self.length_normalized_loss:
|
|
|
- batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
|
|
|
|
|
|
|
+ batch_size = int((text_lengths + self.predictor_bias).sum())
|
|
|
|
|
+
|
|
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
|
|
return loss, stats, weight
|
|
return loss, stats, weight
|
|
|
|
|
|
|
@@ -867,7 +868,8 @@ class ParaformerOnline(Paraformer):
|
|
|
|
|
|
|
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
|
|
if self.length_normalized_loss:
|
|
if self.length_normalized_loss:
|
|
|
- batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
|
|
|
|
|
|
|
+ batch_size = int((text_lengths + self.predictor_bias).sum())
|
|
|
|
|
+
|
|
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
|
|
return loss, stats, weight
|
|
return loss, stats, weight
|
|
|
|
|
|
|
@@ -1494,7 +1496,8 @@ class ParaformerBert(Paraformer):
|
|
|
|
|
|
|
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
|
|
if self.length_normalized_loss:
|
|
if self.length_normalized_loss:
|
|
|
- batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
|
|
|
|
|
|
|
+ batch_size = int((text_lengths + self.predictor_bias).sum())
|
|
|
|
|
+
|
|
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
|
|
return loss, stats, weight
|
|
return loss, stats, weight
|
|
|
|
|
|
|
@@ -1765,7 +1768,8 @@ class BiCifParaformer(Paraformer):
|
|
|
|
|
|
|
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
|
|
if self.length_normalized_loss:
|
|
if self.length_normalized_loss:
|
|
|
- batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
|
|
|
|
|
|
|
+ batch_size = int((text_lengths + self.predictor_bias).sum())
|
|
|
|
|
+
|
|
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
|
|
return loss, stats, weight
|
|
return loss, stats, weight
|
|
|
|
|
|
|
@@ -1967,7 +1971,8 @@ class ContextualParaformer(Paraformer):
|
|
|
|
|
|
|
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
|
|
if self.length_normalized_loss:
|
|
if self.length_normalized_loss:
|
|
|
- batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
|
|
|
|
|
|
|
+ batch_size = int((text_lengths + self.predictor_bias).sum())
|
|
|
|
|
+
|
|
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
|
|
return loss, stats, weight
|
|
return loss, stats, weight
|
|
|
|
|
|