游雁 2 gadi atpakaļ
vecāks
revīzija
9fb473bc89
1 mainīts fails ar 2 papildinājumiem un 2 dzēšanām
  1. 2 2
      funasr/models/llm_asr/model.py

+ 2 - 2
funasr/models/llm_asr/model.py

@@ -216,8 +216,8 @@ class LLMASRNAR(nn.Module):
         self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
     
-        audio_mask = kwargs.get("audio_mask")
-        audio_token_lengths = audio_mask.sum(-1) if audio_mask else None
+        audio_mask = kwargs.get("audio_mask", None)
+        audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
 
         batch = {"speech": speech, "speech_lengths": speech_lengths}
         enc, enc_lens = self.audio_encoder.encode(**batch)