|
|
@@ -255,7 +255,7 @@ class Paraformer(FunASRModel):
|
|
|
|
|
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
|
|
if self.length_normalized_loss:
|
|
|
- batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
|
|
|
+ batch_size = int((text_lengths + self.predictor_bias).sum())
|
|
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
|
|
return loss, stats, weight
|
|
|
|
|
|
@@ -867,7 +867,7 @@ class ParaformerOnline(Paraformer):
|
|
|
|
|
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
|
|
if self.length_normalized_loss:
|
|
|
- batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
|
|
|
+ batch_size = int((text_lengths + self.predictor_bias).sum())
|
|
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
|
|
return loss, stats, weight
|
|
|
|
|
|
@@ -1494,7 +1494,7 @@ class ParaformerBert(Paraformer):
|
|
|
|
|
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
|
|
if self.length_normalized_loss:
|
|
|
- batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
|
|
|
+ batch_size = int((text_lengths + self.predictor_bias).sum())
|
|
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
|
|
return loss, stats, weight
|
|
|
|
|
|
@@ -1765,7 +1765,7 @@ class BiCifParaformer(Paraformer):
|
|
|
|
|
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
|
|
if self.length_normalized_loss:
|
|
|
- batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
|
|
|
+ batch_size = int((text_lengths + self.predictor_bias).sum())
|
|
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
|
|
return loss, stats, weight
|
|
|
|
|
|
@@ -1967,7 +1967,7 @@ class ContextualParaformer(Paraformer):
|
|
|
|
|
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
|
|
if self.length_normalized_loss:
|
|
|
- batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
|
|
|
+ batch_size = int((text_lengths + self.predictor_bias).sum())
|
|
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
|
|
return loss, stats, weight
|
|
|
|
|
|
@@ -2262,4 +2262,4 @@ class ContextualParaformer(Paraformer):
|
|
|
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
|
|
|
var_dict_tf[name_tf].shape))
|
|
|
|
|
|
- return var_dict_torch_update
|
|
|
+ return var_dict_torch_update
|