语帆 2 tahun lalu
induk
melakukan
e0fca115cb
1 mengubah file dengan 2 tambahan dan 2 penghapusan
  1. 2 2
      funasr/models/lcbnet/model.py

+ 2 - 2
funasr/models/lcbnet/model.py

@@ -443,8 +443,8 @@ class LCBNet(nn.Module):
             encoder_out = encoder_out[0]
         
         ocr_list_new = [[x + 1 if x != 0 else x for x in sublist] for sublist in ocr_sample_list]
-        ocr = torch.tensor(ocr_list_new)
-        ocr_lengths = ocr.new_full([1], dtype=torch.long, fill_value=ocr.size(1))
+        ocr = torch.tensor(ocr_list_new).to(device=kwargs["device"])
+        ocr_lengths = ocr.new_full([1], dtype=torch.long, fill_value=ocr.size(1)).to(device=kwargs["device"])
         ocr, ocr_lens, _ = self.text_encoder(ocr, ocr_lengths)
         fusion_out, _, _, _ = self.fusion_encoder(encoder_out,None, ocr, None)
         encoder_out = encoder_out + fusion_out