|
|
@@ -63,8 +63,9 @@ class Paraformer(nn.Module):
|
|
|
|
|
|
decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
|
|
|
decoder_out = torch.log_softmax(decoder_out, dim=-1)
|
|
|
+ sample_ids = decoder_out.argmax(dim=-1)
|
|
|
|
|
|
- return decoder_out, pre_token_length
|
|
|
+ return decoder_out, sample_ids
|
|
|
|
|
|
# def get_output_size(self):
|
|
|
# return self.model.encoders[0].size
|
|
|
@@ -74,6 +75,14 @@ class Paraformer(nn.Module):
|
|
|
speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
|
|
|
return (speech, speech_lengths)
|
|
|
|
|
|
+ def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
|
|
|
+ import numpy as np
|
|
|
+ fbank = np.loadtxt(txt_file)
|
|
|
+ fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
|
|
|
+ speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
|
|
|
+ speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
|
|
|
+ return (speech, speech_lengths)
|
|
|
+
|
|
|
def get_input_names(self):
|
|
|
return ['speech', 'speech_lengths']
|
|
|
|