fsmn_encoder.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. from typing import Tuple, Dict
  2. import copy
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. class LinearTransform(nn.Module):
  8. def __init__(self, input_dim, output_dim):
  9. super(LinearTransform, self).__init__()
  10. self.input_dim = input_dim
  11. self.output_dim = output_dim
  12. self.linear = nn.Linear(input_dim, output_dim, bias=False)
  13. def forward(self, input):
  14. output = self.linear(input)
  15. return output
  16. class AffineTransform(nn.Module):
  17. def __init__(self, input_dim, output_dim):
  18. super(AffineTransform, self).__init__()
  19. self.input_dim = input_dim
  20. self.output_dim = output_dim
  21. self.linear = nn.Linear(input_dim, output_dim)
  22. def forward(self, input):
  23. output = self.linear(input)
  24. return output
  25. class RectifiedLinear(nn.Module):
  26. def __init__(self, input_dim, output_dim):
  27. super(RectifiedLinear, self).__init__()
  28. self.dim = input_dim
  29. self.relu = nn.ReLU()
  30. self.dropout = nn.Dropout(0.1)
  31. def forward(self, input):
  32. out = self.relu(input)
  33. return out
  34. class FSMNBlock(nn.Module):
  35. def __init__(
  36. self,
  37. input_dim: int,
  38. output_dim: int,
  39. lorder=None,
  40. rorder=None,
  41. lstride=1,
  42. rstride=1,
  43. ):
  44. super(FSMNBlock, self).__init__()
  45. self.dim = input_dim
  46. if lorder is None:
  47. return
  48. self.lorder = lorder
  49. self.rorder = rorder
  50. self.lstride = lstride
  51. self.rstride = rstride
  52. self.conv_left = nn.Conv2d(
  53. self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
  54. if self.rorder > 0:
  55. self.conv_right = nn.Conv2d(
  56. self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
  57. else:
  58. self.conv_right = None
  59. def forward(self, input: torch.Tensor, cache: torch.Tensor):
  60. x = torch.unsqueeze(input, 1)
  61. x_per = x.permute(0, 3, 2, 1) # B D T C
  62. cache = cache.to(x_per.device)
  63. y_left = torch.cat((cache, x_per), dim=2)
  64. cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
  65. y_left = self.conv_left(y_left)
  66. out = x_per + y_left
  67. if self.conv_right is not None:
  68. # maybe need to check
  69. y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
  70. y_right = y_right[:, :, self.rstride:, :]
  71. y_right = self.conv_right(y_right)
  72. out += y_right
  73. out_per = out.permute(0, 3, 2, 1)
  74. output = out_per.squeeze(1)
  75. return output, cache
  76. class BasicBlock(nn.Sequential):
  77. def __init__(self,
  78. linear_dim: int,
  79. proj_dim: int,
  80. lorder: int,
  81. rorder: int,
  82. lstride: int,
  83. rstride: int,
  84. stack_layer: int
  85. ):
  86. super(BasicBlock, self).__init__()
  87. self.lorder = lorder
  88. self.rorder = rorder
  89. self.lstride = lstride
  90. self.rstride = rstride
  91. self.stack_layer = stack_layer
  92. self.linear = LinearTransform(linear_dim, proj_dim)
  93. self.fsmn_block = FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride)
  94. self.affine = AffineTransform(proj_dim, linear_dim)
  95. self.relu = RectifiedLinear(linear_dim, linear_dim)
  96. def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
  97. x1 = self.linear(input) # B T D
  98. cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
  99. if cache_layer_name not in in_cache:
  100. in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
  101. x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
  102. x3 = self.affine(x2)
  103. x4 = self.relu(x3)
  104. return x4
  105. class FsmnStack(nn.Sequential):
  106. def __init__(self, *args):
  107. super(FsmnStack, self).__init__(*args)
  108. def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
  109. x = input
  110. for module in self._modules.values():
  111. x = module(x, in_cache)
  112. return x
  113. '''
  114. FSMN net for keyword spotting
  115. input_dim: input dimension
linear_dim: fsmn input dimension
  117. proj_dim: fsmn projection dimension
  118. lorder: fsmn left order
  119. rorder: fsmn right order
  120. num_syn: output dimension
  121. fsmn_layers: no. of sequential fsmn layers
  122. '''
  123. class FSMN(nn.Module):
  124. def __init__(
  125. self,
  126. input_dim: int,
  127. input_affine_dim: int,
  128. fsmn_layers: int,
  129. linear_dim: int,
  130. proj_dim: int,
  131. lorder: int,
  132. rorder: int,
  133. lstride: int,
  134. rstride: int,
  135. output_affine_dim: int,
  136. output_dim: int
  137. ):
  138. super(FSMN, self).__init__()
  139. self.input_dim = input_dim
  140. self.input_affine_dim = input_affine_dim
  141. self.fsmn_layers = fsmn_layers
  142. self.linear_dim = linear_dim
  143. self.proj_dim = proj_dim
  144. self.output_affine_dim = output_affine_dim
  145. self.output_dim = output_dim
  146. self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
  147. self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
  148. self.relu = RectifiedLinear(linear_dim, linear_dim)
  149. self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
  150. range(fsmn_layers)])
  151. self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
  152. self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
  153. self.softmax = nn.Softmax(dim=-1)
  154. def fuse_modules(self):
  155. pass
  156. def forward(
  157. self,
  158. input: torch.Tensor,
  159. in_cache: Dict[str, torch.Tensor]
  160. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
  161. """
  162. Args:
  163. input (torch.Tensor): Input tensor (B, T, D)
  164. in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
  165. {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
  166. """
  167. x1 = self.in_linear1(input)
  168. x2 = self.in_linear2(x1)
  169. x3 = self.relu(x2)
  170. x4 = self.fsmn(x3, in_cache) # self.in_cache will update automatically in self.fsmn
  171. x5 = self.out_linear1(x4)
  172. x6 = self.out_linear2(x5)
  173. x7 = self.softmax(x6)
  174. return x7
  175. '''
  176. one deep fsmn layer
  177. dimproj: projection dimension, input and output dimension of memory blocks
  178. dimlinear: dimension of mapping layer
  179. lorder: left order
  180. rorder: right order
  181. lstride: left stride
  182. rstride: right stride
  183. '''
  184. class DFSMN(nn.Module):
  185. def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
  186. super(DFSMN, self).__init__()
  187. self.lorder = lorder
  188. self.rorder = rorder
  189. self.lstride = lstride
  190. self.rstride = rstride
  191. self.expand = AffineTransform(dimproj, dimlinear)
  192. self.shrink = LinearTransform(dimlinear, dimproj)
  193. self.conv_left = nn.Conv2d(
  194. dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
  195. if rorder > 0:
  196. self.conv_right = nn.Conv2d(
  197. dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
  198. else:
  199. self.conv_right = None
  200. def forward(self, input):
  201. f1 = F.relu(self.expand(input))
  202. p1 = self.shrink(f1)
  203. x = torch.unsqueeze(p1, 1)
  204. x_per = x.permute(0, 3, 2, 1)
  205. y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
  206. if self.conv_right is not None:
  207. y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
  208. y_right = y_right[:, :, self.rstride:, :]
  209. out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
  210. else:
  211. out = x_per + self.conv_left(y_left)
  212. out1 = out.permute(0, 3, 2, 1)
  213. output = input + out1.squeeze(1)
  214. return output
  215. '''
  216. build stacked dfsmn layers
  217. '''
  218. def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
  219. repeats = [
  220. nn.Sequential(
  221. DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
  222. for i in range(fsmn_layers)
  223. ]
  224. return nn.Sequential(*repeats)
  225. if __name__ == '__main__':
  226. fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
  227. print(fsmn)
  228. num_params = sum(p.numel() for p in fsmn.parameters())
  229. print('the number of model params: {}'.format(num_params))
  230. x = torch.zeros(128, 200, 400) # batch-size * time * dim
  231. y, _ = fsmn(x) # batch-size * time * dim
  232. print('input shape: {}'.format(x.shape))
  233. print('output shape: {}'.format(y.shape))
  234. print(fsmn.to_kaldi_net())