Sfoglia il codice sorgente

fix bug, ys_pad_masked in sampler of paraformer

游雁 3 anni fa
parent
commit
cfb2fda87c
1 ha cambiato i file con 3 aggiunte e 3 eliminazioni
  1. 3 3
      funasr/models/e2e_asr_paraformer.py

+ 3 - 3
funasr/models/e2e_asr_paraformer.py

@@ -499,11 +499,11 @@ class Paraformer(AbsESPnetModel):
     def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
 
         tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
-        ys_pad = ys_pad * tgt_mask[:, :, 0]
+        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
         if self.share_embedding:
-            ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
+            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
         else:
-            ys_pad_embed = self.decoder.embed(ys_pad)
+            ys_pad_embed = self.decoder.embed(ys_pad_masked)
         with torch.no_grad():
             decoder_outs = self.decoder(
                 encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens