|
|
@@ -488,15 +488,20 @@ class Speech2TextParaformer:
|
|
|
|
|
|
nbest_hyps = nbest_hyps[: self.nbest]
|
|
|
else:
|
|
|
- yseq = am_scores.argmax(dim=-1)
|
|
|
- score = am_scores.max(dim=-1)[0]
|
|
|
- score = torch.sum(score, dim=-1)
|
|
|
- # pad with mask tokens to ensure compatibility with sos/eos tokens
|
|
|
- yseq = torch.tensor(
|
|
|
- [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
|
|
|
- )
|
|
|
+ if pre_token_length[i] == 0:
|
|
|
+ yseq = torch.tensor(
|
|
|
+ [self.asr_model.sos] + [self.asr_model.eos], device=yseq.device
|
|
|
+ )
|
|
|
+ score = torch.tensor(0.0, device=yseq.device)
|
|
|
+ else:
|
|
|
+ yseq = am_scores.argmax(dim=-1)
|
|
|
+ score = am_scores.max(dim=-1)[0]
|
|
|
+ score = torch.sum(score, dim=-1)
|
|
|
+ # pad with mask tokens to ensure compatibility with sos/eos tokens
|
|
|
+ yseq = torch.tensor(
|
|
|
+ [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
|
|
|
+ )
|
|
|
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
|
|
|
-
|
|
|
for hyp in nbest_hyps:
|
|
|
assert isinstance(hyp, (Hypothesis)), type(hyp)
|
|
|
|