|
@@ -278,9 +278,10 @@ class NeatContextualParaformer(Paraformer):
|
|
|
|
|
|
|
|
# 1. Forward decoder
|
|
# 1. Forward decoder
|
|
|
decoder_outs = self.decoder(
|
|
decoder_outs = self.decoder(
|
|
|
- encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info, ret_attn=(ideal_attn is not None)
|
|
|
|
|
|
|
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
|
|
|
)
|
|
)
|
|
|
decoder_out, _, attn = decoder_outs[0], decoder_outs[1], decoder_outs[2]
|
|
decoder_out, _, attn = decoder_outs[0], decoder_outs[1], decoder_outs[2]
|
|
|
|
|
+
|
|
|
if self.crit_attn_weight > 0 and attn.shape[-1] > 1:
|
|
if self.crit_attn_weight > 0 and attn.shape[-1] > 1:
|
|
|
ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0)
|
|
ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0)
|
|
|
attn_non_blank = attn[:,:,:,:-1]
|
|
attn_non_blank = attn[:,:,:,:-1]
|