# fsmn_encoder.py
  1. import numpy as np
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from typing import Tuple
  6. class LinearTransform(nn.Module):
  7. def __init__(self, input_dim, output_dim, quantize=0):
  8. super(LinearTransform, self).__init__()
  9. self.input_dim = input_dim
  10. self.output_dim = output_dim
  11. self.linear = nn.Linear(input_dim, output_dim, bias=False)
  12. self.quantize = quantize
  13. self.quant = torch.quantization.QuantStub()
  14. self.dequant = torch.quantization.DeQuantStub()
  15. def forward(self, input):
  16. if self.quantize:
  17. output = self.quant(input)
  18. else:
  19. output = input
  20. output = self.linear(output)
  21. if self.quantize:
  22. output = self.dequant(output)
  23. return output
  24. class AffineTransform(nn.Module):
  25. def __init__(self, input_dim, output_dim, quantize=0):
  26. super(AffineTransform, self).__init__()
  27. self.input_dim = input_dim
  28. self.output_dim = output_dim
  29. self.quantize = quantize
  30. self.linear = nn.Linear(input_dim, output_dim)
  31. self.quant = torch.quantization.QuantStub()
  32. self.dequant = torch.quantization.DeQuantStub()
  33. def forward(self, input):
  34. if self.quantize:
  35. output = self.quant(input)
  36. else:
  37. output = input
  38. output = self.linear(output)
  39. if self.quantize:
  40. output = self.dequant(output)
  41. return output
  42. class FSMNBlock(nn.Module):
  43. def __init__(
  44. self,
  45. input_dim: int,
  46. output_dim: int,
  47. lorder=None,
  48. rorder=None,
  49. lstride=1,
  50. rstride=1,
  51. quantize=0
  52. ):
  53. super(FSMNBlock, self).__init__()
  54. self.dim = input_dim
  55. if lorder is None:
  56. return
  57. self.lorder = lorder
  58. self.rorder = rorder
  59. self.lstride = lstride
  60. self.rstride = rstride
  61. self.conv_left = nn.Conv2d(
  62. self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
  63. if self.rorder > 0:
  64. self.conv_right = nn.Conv2d(
  65. self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
  66. else:
  67. self.conv_right = None
  68. self.quantize = quantize
  69. self.quant = torch.quantization.QuantStub()
  70. self.dequant = torch.quantization.DeQuantStub()
  71. def forward(self, input):
  72. x = torch.unsqueeze(input, 1)
  73. x_per = x.permute(0, 3, 2, 1)
  74. y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
  75. if self.quantize:
  76. y_left = self.quant(y_left)
  77. y_left = self.conv_left(y_left)
  78. if self.quantize:
  79. y_left = self.dequant(y_left)
  80. out = x_per + y_left
  81. if self.conv_right is not None:
  82. y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
  83. y_right = y_right[:, :, self.rstride:, :]
  84. if self.quantize:
  85. y_right = self.quant(y_right)
  86. y_right = self.conv_right(y_right)
  87. if self.quantize:
  88. y_right = self.dequant(y_right)
  89. out += y_right
  90. out_per = out.permute(0, 3, 2, 1)
  91. output = out_per.squeeze(1)
  92. return output
  93. class RectifiedLinear(nn.Module):
  94. def __init__(self, input_dim, output_dim):
  95. super(RectifiedLinear, self).__init__()
  96. self.dim = input_dim
  97. self.relu = nn.ReLU()
  98. self.dropout = nn.Dropout(0.1)
  99. def forward(self, input):
  100. out = self.relu(input)
  101. # out = self.dropout(out)
  102. return out
  103. def _build_repeats(
  104. fsmn_layers: int,
  105. linear_dim: int,
  106. proj_dim: int,
  107. lorder: int,
  108. rorder: int,
  109. lstride=1,
  110. rstride=1,
  111. ):
  112. repeats = [
  113. nn.Sequential(
  114. LinearTransform(linear_dim, proj_dim),
  115. FSMNBlock(proj_dim, proj_dim, lorder, rorder, 1, 1),
  116. AffineTransform(proj_dim, linear_dim),
  117. RectifiedLinear(linear_dim, linear_dim))
  118. for i in range(fsmn_layers)
  119. ]
  120. return nn.Sequential(*repeats)
'''
FSMN net for keyword spotting
input_dim: input dimension
linear_dim: fsmn input dimension
proj_dim: fsmn projection dimension
lorder: fsmn left order
rorder: fsmn right order
num_syn: output dimension
fsmn_layers: no. of sequential fsmn layers
'''
  131. class FSMN(nn.Module):
  132. def __init__(
  133. self,
  134. input_dim: int,
  135. input_affine_dim: int,
  136. fsmn_layers: int,
  137. linear_dim: int,
  138. proj_dim: int,
  139. lorder: int,
  140. rorder: int,
  141. lstride: int,
  142. rstride: int,
  143. output_affine_dim: int,
  144. output_dim: int,
  145. ):
  146. super(FSMN, self).__init__()
  147. self.input_dim = input_dim
  148. self.input_affine_dim = input_affine_dim
  149. self.fsmn_layers = fsmn_layers
  150. self.linear_dim = linear_dim
  151. self.proj_dim = proj_dim
  152. self.lorder = lorder
  153. self.rorder = rorder
  154. self.lstride = lstride
  155. self.rstride = rstride
  156. self.output_affine_dim = output_affine_dim
  157. self.output_dim = output_dim
  158. self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
  159. self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
  160. self.relu = RectifiedLinear(linear_dim, linear_dim)
  161. self.fsmn = _build_repeats(fsmn_layers,
  162. linear_dim,
  163. proj_dim,
  164. lorder, rorder,
  165. lstride, rstride)
  166. self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
  167. self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
  168. self.softmax = nn.Softmax(dim=-1)
  169. def fuse_modules(self):
  170. pass
  171. def forward(
  172. self,
  173. input: torch.Tensor,
  174. in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
  175. ) -> torch.Tensor:
  176. """
  177. Args:
  178. input (torch.Tensor): Input tensor (B, T, D)
  179. in_cache(torhc.Tensor): (B, D, C), C is the accumulated cache size
  180. """
  181. x1 = self.in_linear1(input)
  182. x2 = self.in_linear2(x1)
  183. x3 = self.relu(x2)
  184. x4 = self.fsmn(x3)
  185. x5 = self.out_linear1(x4)
  186. x6 = self.out_linear2(x5)
  187. x7 = self.softmax(x6)
  188. return x7
  189. # return x6, in_cache
  190. '''
  191. one deep fsmn layer
  192. dimproj: projection dimension, input and output dimension of memory blocks
  193. dimlinear: dimension of mapping layer
  194. lorder: left order
  195. rorder: right order
  196. lstride: left stride
  197. rstride: right stride
  198. '''
  199. class DFSMN(nn.Module):
  200. def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
  201. super(DFSMN, self).__init__()
  202. self.lorder = lorder
  203. self.rorder = rorder
  204. self.lstride = lstride
  205. self.rstride = rstride
  206. self.expand = AffineTransform(dimproj, dimlinear)
  207. self.shrink = LinearTransform(dimlinear, dimproj)
  208. self.conv_left = nn.Conv2d(
  209. dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
  210. if rorder > 0:
  211. self.conv_right = nn.Conv2d(
  212. dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
  213. else:
  214. self.conv_right = None
  215. def forward(self, input):
  216. f1 = F.relu(self.expand(input))
  217. p1 = self.shrink(f1)
  218. x = torch.unsqueeze(p1, 1)
  219. x_per = x.permute(0, 3, 2, 1)
  220. y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
  221. if self.conv_right is not None:
  222. y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
  223. y_right = y_right[:, :, self.rstride:, :]
  224. out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
  225. else:
  226. out = x_per + self.conv_left(y_left)
  227. out1 = out.permute(0, 3, 2, 1)
  228. output = input + out1.squeeze(1)
  229. return output
  230. '''
  231. build stacked dfsmn layers
  232. '''
  233. def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
  234. repeats = [
  235. nn.Sequential(
  236. DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
  237. for i in range(fsmn_layers)
  238. ]
  239. return nn.Sequential(*repeats)
  240. if __name__ == '__main__':
  241. fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
  242. print(fsmn)
  243. num_params = sum(p.numel() for p in fsmn.parameters())
  244. print('the number of model params: {}'.format(num_params))
  245. x = torch.zeros(128, 200, 400) # batch-size * time * dim
  246. y, _ = fsmn(x) # batch-size * time * dim
  247. print('input shape: {}'.format(x.shape))
  248. print('output shape: {}'.format(y.shape))
  249. print(fsmn.to_kaldi_net())