@@ -68,7 +68,7 @@ class NeatContextualParaformer(Paraformer):
         target_buffer_length: int = -1,
         inner_dim: int = 256,
         bias_encoder_type: str = 'lstm',
-        use_decoder_embedding: bool = True,
+        use_decoder_embedding: bool = False,
         crit_attn_weight: float = 0.0,
         crit_attn_smooth: float = 0.0,
         bias_encoder_dropout_rate: float = 0.0,
@@ -340,7 +340,7 @@ class NeatContextualParaformer(Paraformer):
             input_mask_expand_dim, 0)
         return sematic_embeds * tgt_mask, decoder_out * tgt_mask
 
-    def cal_decoder_with_predictor_with_hwlist_advanced(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None):
+    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None):
         if hw_list is None:
             hw_list = [torch.Tensor([1]).long().to(encoder_out.device)] # empty hotword list
             hw_list_pad = pad_list(hw_list, 0)
@@ -350,7 +350,6 @@ class NeatContextualParaformer(Paraformer):
                 hw_embed = self.bias_embed(hw_list_pad)
             hw_embed, (h_n, _) = self.bias_encoder(hw_embed)
         else:
-            # hw_list = hw_list[1:] + [hw_list[0]] # reorder
             hw_lengths = [len(i) for i in hw_list]
             hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
             if self.use_decoder_embedding:
@@ -366,7 +365,6 @@ class NeatContextualParaformer(Paraformer):
             if _h_n is not None:
                 h_n = _h_n
             hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
-            # import pdb; pdb.set_trace()
 
         decoder_outs = self.decoder(
             encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed
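
For orientation, a minimal usage sketch (not part of the patch) of the renamed cal_decoder_with_predictor entry point. It assumes `model` is an already-constructed NeatContextualParaformer whose encoder and predictor have been run on a batch; the tensor shapes and token IDs below are illustrative only, and hotword tokenization is assumed to happen elsewhere:

```python
import torch

# Illustrative shapes; d_model etc. depend on the actual model config.
batch, t_enc, t_tok, d_model = 2, 120, 15, 512

encoder_out = torch.randn(batch, t_enc, d_model)     # acoustic encoder output
encoder_out_lens = torch.tensor([120, 98])           # valid frame counts per utterance
sematic_embeds = torch.randn(batch, t_tok, d_model)  # predictor (CIF) token embeddings
ys_pad_lens = torch.tensor([15, 11])                 # predicted token counts per utterance

# Hotwords tokenized to ID lists elsewhere; each list is embedded and fed
# through the bias encoder, mirroring the `else` branch in the diff above.
hw_list = [[1421, 87, 332], [905, 12]]               # illustrative token IDs

decoder_outs = model.cal_decoder_with_predictor(
    encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=hw_list
)

# Omitting hw_list falls back to the single "empty hotword" entry built in the
# `if hw_list is None:` branch, i.e. decoding without contextual biasing.
decoder_outs_no_bias = model.cal_decoder_with_predictor(
    encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
)
```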