|
|
@@ -1510,8 +1510,13 @@ class Speech2TextTransducer:
|
|
|
if isinstance(speech, np.ndarray):
|
|
|
speech = torch.tensor(speech)
|
|
|
|
|
|
- feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
|
|
|
- feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
|
|
|
+ if self.frontend is not None:
|
|
|
+ speech = torch.unsqueeze(speech, axis=0)
|
|
|
+ speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
|
|
|
+ feats, feats_lengths = self.frontend(speech, speech_lengths)
|
|
|
+ else:
|
|
|
+ feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
|
|
|
+ feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
|
|
|
|
|
|
if self.asr_model.normalize is not None:
|
|
|
feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
|
|
|
@@ -1536,14 +1541,19 @@ class Speech2TextTransducer:
|
|
|
|
|
|
if isinstance(speech, np.ndarray):
|
|
|
speech = torch.tensor(speech)
|
|
|
-
|
|
|
- feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
|
|
|
- feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
|
|
|
+
|
|
|
+ if self.frontend is not None:
|
|
|
+ speech = torch.unsqueeze(speech, axis=0)
|
|
|
+ speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
|
|
|
+ feats, feats_lengths = self.frontend(speech, speech_lengths)
|
|
|
+ else:
|
|
|
+ feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
|
|
|
+ feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
|
|
|
|
|
|
feats = to_device(feats, device=self.device)
|
|
|
feats_lengths = to_device(feats_lengths, device=self.device)
|
|
|
|
|
|
- enc_out, _ = self.asr_model.encoder(feats, feats_lengths)
|
|
|
+ enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
|
|
|
|
|
|
nbest_hyps = self.beam_search(enc_out[0])
|
|
|
|