Explorar o código

update neat contextual paraformer

shixian.shi hai 2 anos
pai
achega
e1d535e697
Modificouse 1 ficheiro con 1 adición e 26 borrados
  1. 1 26
      funasr/models/e2e_asr_contextual_paraformer.py

+ 1 - 26
funasr/models/e2e_asr_contextual_paraformer.py

@@ -291,7 +291,7 @@ class NeatContextualParaformer(Paraformer):
             loss_ideal = None
         '''
         loss_ideal = None
-        
+
         if decoder_out_1st is None:
             decoder_out_1st = decoder_out
         # 2. Compute attention loss
@@ -362,11 +362,6 @@ class NeatContextualParaformer(Paraformer):
             hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True,
                                                             enforce_sorted=False)
             _, (h_n, _) = self.bias_encoder(hw_embed)
-            # hw_embed, _ = torch.nn.utils.rnn.pad_packed_sequence(hw_embed, batch_first=True)
-            if h_n.shape[1] > 2000: # large hotword list
-                _h_n = self.pick_hwlist_group(h_n.squeeze(0), encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)
-                if _h_n is not None:
-                    h_n = _h_n
             hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
         
         decoder_outs = self.decoder(
@@ -375,23 +370,3 @@ class NeatContextualParaformer(Paraformer):
         decoder_out = decoder_outs[0]
         decoder_out = torch.log_softmax(decoder_out, dim=-1)
         return decoder_out, ys_pad_lens
-
-    def pick_hwlist_group(self, hw_embed, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
-        max_attn_score = 0.0
-        # max_attn_index = 0
-        argmax_g = None
-        non_blank = hw_embed[-1]
-        hw_embed_groups = hw_embed[:-1].split(2000)
-        for i, g in enumerate(hw_embed_groups):
-            g = torch.cat([g, non_blank.unsqueeze(0)], dim=0)
-            _ = self.decoder(
-                encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=g.unsqueeze(0)
-            )
-            attn = self.decoder.bias_decoder.src_attn.attn[0]
-            _max_attn_score = attn.max(0)[0][:,:-1].max()
-            if _max_attn_score > max_attn_score:
-                max_attn_score = _max_attn_score
-                # max_attn_index = i
-                argmax_g = g
-        # import pdb; pdb.set_trace()
-        return argmax_g