Bladeren bron

update func cif_wo_hidden

shixian.shi 2 jaren geleden
bovenliggende
commit
c73d1a8e81
3 gewijzigde bestanden met toevoegingen van 3 en 3 verwijderingen
  1. 1 1
      funasr/export/models/predictor/cif.py
  2. 1 1
      funasr/models/predictor/cif.py
  3. 1 1
      funasr/utils/timestamp_tools.py

+ 1 - 1
funasr/export/models/predictor/cif.py

@@ -288,7 +288,7 @@ def cif_wo_hidden(alphas, threshold: float):
 
 
         fire_place = integrate >= threshold
         fire_place = integrate >= threshold
         integrate = torch.where(fire_place,
         integrate = torch.where(fire_place,
-                                integrate - torch.ones([batch_size], device=alphas.device),
+                                integrate - torch.ones([batch_size], device=alphas.device)*threshold,
                                 integrate)
                                 integrate)
 
 
     fires = torch.stack(list_fires, 1)
     fires = torch.stack(list_fires, 1)

+ 1 - 1
funasr/models/predictor/cif.py

@@ -499,7 +499,7 @@ def cif_wo_hidden(alphas, threshold):
 
 
         fire_place = integrate >= threshold
         fire_place = integrate >= threshold
         integrate = torch.where(fire_place,
         integrate = torch.where(fire_place,
-                                integrate - torch.ones([batch_size], device=alphas.device),
+                                integrate - torch.ones([batch_size], device=alphas.device)*threshold,
                                 integrate)
                                 integrate)
 
 
     fires = torch.stack(list_fires, 1)
     fires = torch.stack(list_fires, 1)

+ 1 - 1
funasr/utils/timestamp_tools.py

@@ -19,7 +19,7 @@ def cif_wo_hidden(alphas, threshold):
         list_fires.append(integrate)
         list_fires.append(integrate)
         fire_place = integrate >= threshold
         fire_place = integrate >= threshold
         integrate = torch.where(fire_place,
         integrate = torch.where(fire_place,
-                                integrate - torch.ones([batch_size], device=alphas.device),
+                                integrate - torch.ones([batch_size], device=alphas.device)*threshold,
                                 integrate)
                                 integrate)
     fires = torch.stack(list_fires, 1)
     fires = torch.stack(list_fires, 1)
     return fires
     return fires