|
|
@@ -216,8 +216,8 @@ class LLMASRNAR(nn.Module):
|
|
|
self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
|
- audio_mask = kwargs.get("audio_mask")
|
|
|
- audio_token_lengths = audio_mask.sum(-1) if audio_mask else None
|
|
|
+ audio_mask = kwargs.get("audio_mask", None)
|
|
|
+ audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
|
|
|
|
|
|
batch = {"speech": speech, "speech_lengths": speech_lengths}
|
|
|
enc, enc_lens = self.audio_encoder.encode(**batch)
|