Browse Source

update device

shixian.shi 2 years ago
parent
commit
c3442d9566
2 changed files with 3 additions and 1 deletions
  1. 2 1
      funasr/models/bicif_paraformer/model.py
  2. 1 0
      funasr/models/campplus/model.py

+ 2 - 1
funasr/models/bicif_paraformer/model.py

@@ -252,7 +252,8 @@ class BiCifParaformer(Paraformer):
             meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
             meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
         
-        speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
         
         # Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

+ 1 - 0
funasr/models/campplus/model.py

@@ -110,6 +110,7 @@ class CAMPPlus(nn.Module):
         time2 = time.perf_counter()
         meta_data["load_data"] = f"{time2 - time1:0.3f}"
         speech, speech_lengths, speech_times = extract_feature(audio_sample_list)
+        speech = speech.to(device=kwargs["device"])
         time3 = time.perf_counter()
         meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
         meta_data["batch_data_time"] = np.array(speech_times).sum().item() / 16000.0