e2e_asr_paraformer.py

import logging

import torch
import torch.nn as nn

from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
from funasr.models.predictor.cif import CifPredictorV2, CifPredictorV3
from funasr.export.models.predictor.cif import CifPredictorV2 as CifPredictorV2_export
from funasr.export.models.predictor.cif import CifPredictorV3 as CifPredictorV3_export
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export
from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoderOnline as ParaformerSANMDecoderOnline_export

class Paraformer(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='model',
        **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]
        # Swap each trained submodule for its export-friendly counterpart.
        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        elif isinstance(model.encoder, ConformerEncoder):
            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
        if isinstance(model.predictor, CifPredictorV2):
            self.predictor = CifPredictorV2_export(model.predictor)
        if isinstance(model.decoder, ParaformerSANMDecoder):
            self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
        elif isinstance(model.decoder, ParaformerDecoderSAN):
            self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
        self.feats_dim = feats_dim
        self.model_name = model_name
        # MakePadMask is used for ONNX export; sequence_mask in eager mode.
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        # Encode the padded feature batch.
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        enc, enc_len = self.encoder(**batch)
        # Predict one acoustic embedding per output token with the CIF predictor.
        mask = self.make_pad_mask(enc_len)[:, None, :]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
        pre_token_length = pre_token_length.floor().type(torch.int32)
        # Decode all tokens in parallel and return per-token log-probabilities.
        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, pre_token_length

    def get_dummy_inputs(self):
        speech = torch.randn(2, 30, self.feats_dim)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        return (speech, speech_lengths)

    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
        import numpy as np
        fbank = np.loadtxt(txt_file)
        fbank_lengths = np.array([fbank.shape[0]], dtype=np.int32)
        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
        return (speech, speech_lengths)

    def get_input_names(self):
        return ['speech', 'speech_lengths']

    def get_output_names(self):
        return ['logits', 'token_num']

    def get_dynamic_axes(self):
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }
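

# The four get_* helpers above provide exactly the pieces torch.onnx.export
# needs. A minimal export sketch, assuming `torch_model` is a trained funasr
# Paraformer whose submodules match the isinstance checks above; the output
# path and opset version are illustrative, not values from this repository.
def _export_paraformer_example(torch_model, out_path="paraformer.onnx"):
    wrapper = Paraformer(torch_model, onnx=True)
    wrapper.eval()
    torch.onnx.export(
        wrapper,
        wrapper.get_dummy_inputs(),
        out_path,
        input_names=wrapper.get_input_names(),
        output_names=wrapper.get_output_names(),
        dynamic_axes=wrapper.get_dynamic_axes(),
        opset_version=14,  # assumed; any opset covering the used ops works
    )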


class BiCifParaformer(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='model',
        **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]
        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        elif isinstance(model.encoder, ConformerEncoder):
            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
        else:
            logging.warning("Unsupported encoder type to export.")
        if isinstance(model.predictor, CifPredictorV3):
            self.predictor = CifPredictorV3_export(model.predictor)
        else:
            logging.warning("Unsupported predictor type to export.")
        if isinstance(model.decoder, ParaformerSANMDecoder):
            self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
        elif isinstance(model.decoder, ParaformerDecoderSAN):
            self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
        else:
            logging.warning("Unsupported decoder type to export.")
        self.feats_dim = feats_dim
        self.model_name = model_name
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        # Encode the padded feature batch.
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        enc, enc_len = self.encoder(**batch)
        # Predict one acoustic embedding per output token with the CIF predictor.
        mask = self.make_pad_mask(enc_len)[:, None, :]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
        pre_token_length = pre_token_length.round().type(torch.int32)
        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        # Also return upsampled alphas and CIF peaks so token timestamps can be
        # derived downstream.
        us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
        return decoder_out, pre_token_length, us_alphas, us_cif_peak

    def get_dummy_inputs(self):
        speech = torch.randn(2, 30, self.feats_dim)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        return (speech, speech_lengths)

    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
        import numpy as np
        fbank = np.loadtxt(txt_file)
        fbank_lengths = np.array([fbank.shape[0]], dtype=np.int32)
        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
        return (speech, speech_lengths)

    def get_input_names(self):
        return ['speech', 'speech_lengths']

    def get_output_names(self):
        return ['logits', 'token_num', 'us_alphas', 'us_cif_peak']

    def get_dynamic_axes(self):
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
            'us_alphas': {
                0: 'batch_size',
                1: 'alphas_length'
            },
            'us_cif_peak': {
                0: 'batch_size',
                1: 'alphas_length'
            },
        }
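

# BiCifParaformer exposes the upsampled alphas and CIF peaks so that token
# timestamps can be recovered outside the exported graph. A minimal sketch of
# one way to do that for a single utterance; the firing threshold and the
# seconds-per-upsampled-frame rate are assumptions, not values from this file.
def _peaks_to_times_example(us_cif_peak_1utt: torch.Tensor,
                            threshold: float = 1.0 - 1e-4,
                            frame_shift_s: float = 0.01):
    # Frames whose peak value crosses the threshold mark a token firing.
    fire_idx = torch.nonzero(us_cif_peak_1utt >= threshold).squeeze(-1)
    # Rough start time of each predicted token, in seconds.
    return fire_idx.float() * frame_shift_s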


class ParaformerOnline_encoder_predictor(nn.Module):
    """
    Author: Speech Lab, Alibaba Group, China
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='model',
        **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]
        if isinstance(model.encoder, (SANMEncoder, SANMEncoderChunkOpt)):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        elif isinstance(model.encoder, ConformerEncoder):
            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
        if isinstance(model.predictor, CifPredictorV2):
            self.predictor = CifPredictorV2_export(model.predictor)
        self.feats_dim = feats_dim
        self.model_name = model_name
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        # "online" switches the encoder into chunk/streaming mode.
        batch = {"speech": speech, "speech_lengths": speech_lengths, "online": True}
        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        # Only the predictor's CNN front-end runs here; the integrate-and-fire
        # accumulation is left to the streaming runtime (see the sketch below).
        alphas, _ = self.predictor.forward_cnn(enc, mask)
        return enc, enc_len, alphas

    def get_dummy_inputs(self):
        speech = torch.randn(2, 30, self.feats_dim)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        return (speech, speech_lengths)

    def get_input_names(self):
        return ['speech', 'speech_lengths']

    def get_output_names(self):
        return ['enc', 'enc_len', 'alphas']

    def get_dynamic_axes(self):
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'enc': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'enc_len': {
                0: 'batch_size',
            },
            'alphas': {
                0: 'batch_size',
                1: 'feats_length'
            },
        }
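

# Because the export stops at forward_cnn, the integrate-and-fire step can run
# incrementally in the host code, carrying state across chunks. A minimal
# sketch of standard CIF firing; the `carry` name is mine, and the threshold
# of 1.0 is the conventional CIF value, not a constant read from this file.
def _cif_fire_example(alphas_chunk: torch.Tensor, carry: float = 0.0,
                      threshold: float = 1.0):
    fired = 0
    acc = carry
    for a in alphas_chunk.tolist():
        acc += a
        if acc >= threshold:
            fired += 1          # one token fires
            acc -= threshold    # remainder seeds the next token
    return fired, acc           # pass `acc` as `carry` for the next chunk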


class ParaformerOnline_decoder(nn.Module):
    """
    Author: Speech Lab, Alibaba Group, China
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='model',
        **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]
        if isinstance(model.decoder, ParaformerDecoderSAN):
            self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
        elif isinstance(model.decoder, ParaformerSANMDecoder):
            self.decoder = ParaformerSANMDecoderOnline_export(model.decoder, onnx=onnx)
        self.feats_dim = feats_dim
        self.model_name = model_name
        self.enc_size = model.encoder._output_size
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    def forward(
        self,
        enc: torch.Tensor,
        enc_len: torch.Tensor,
        acoustic_embeds: torch.Tensor,
        acoustic_embeds_len: torch.Tensor,
        *args,
    ):
        # *args carries the per-layer decoder caches between streaming calls.
        decoder_out, out_caches = self.decoder(enc, enc_len, acoustic_embeds, acoustic_embeds_len, *args)
        sample_ids = decoder_out.argmax(dim=-1)
        return decoder_out, sample_ids, out_caches

    # The export metadata is delegated to the wrapped decoder, which knows the
    # shapes and names of its cache inputs.
    def get_dummy_inputs(self):
        return self.decoder.get_dummy_inputs(enc_size=self.enc_size)

    def get_input_names(self):
        return self.decoder.get_input_names()

    def get_output_names(self):
        return self.decoder.get_output_names()

    def get_dynamic_axes(self):
        return self.decoder.get_dynamic_axes()
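

# Once exported, the offline graph can be driven straight from ONNX Runtime.
# A minimal inference sketch; the model path, the random stand-in features,
# and the greedy argmax decode are assumptions for illustration only.
def _run_offline_example(onnx_path="paraformer.onnx"):
    import numpy as np
    import onnxruntime as ort
    sess = ort.InferenceSession(onnx_path)
    feats = np.random.randn(1, 100, 560).astype(np.float32)  # stand-in fbank features
    feats_len = np.array([100], dtype=np.int32)
    logits, token_num = sess.run(
        ["logits", "token_num"],
        {"speech": feats, "speech_lengths": feats_len},
    )
    # Greedy decode: best token id per step, truncated to the predicted length.
    return logits[0].argmax(axis=-1)[: int(token_num[0])]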