# fsmn_encoder.py
import copy
from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
  7. class LinearTransform(nn.Module):
  8. def __init__(self, input_dim, output_dim):
  9. super(LinearTransform, self).__init__()
  10. self.input_dim = input_dim
  11. self.output_dim = output_dim
  12. self.linear = nn.Linear(input_dim, output_dim, bias=False)
  13. def forward(self, input):
  14. output = self.linear(input)
  15. return output
  16. class AffineTransform(nn.Module):
  17. def __init__(self, input_dim, output_dim):
  18. super(AffineTransform, self).__init__()
  19. self.input_dim = input_dim
  20. self.output_dim = output_dim
  21. self.linear = nn.Linear(input_dim, output_dim)
  22. def forward(self, input):
  23. output = self.linear(input)
  24. return output
  25. class RectifiedLinear(nn.Module):
  26. def __init__(self, input_dim, output_dim):
  27. super(RectifiedLinear, self).__init__()
  28. self.dim = input_dim
  29. self.relu = nn.ReLU()
  30. self.dropout = nn.Dropout(0.1)
  31. def forward(self, input):
  32. out = self.relu(input)
  33. return out
  34. class FSMNBlock(nn.Module):
  35. def __init__(
  36. self,
  37. input_dim: int,
  38. output_dim: int,
  39. lorder=None,
  40. rorder=None,
  41. lstride=1,
  42. rstride=1,
  43. ):
  44. super(FSMNBlock, self).__init__()
  45. self.dim = input_dim
  46. if lorder is None:
  47. return
  48. self.lorder = lorder
  49. self.rorder = rorder
  50. self.lstride = lstride
  51. self.rstride = rstride
  52. self.conv_left = nn.Conv2d(
  53. self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
  54. if self.rorder > 0:
  55. self.conv_right = nn.Conv2d(
  56. self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
  57. else:
  58. self.conv_right = None
  59. def forward(self, input: torch.Tensor, cache: torch.Tensor):
  60. x = torch.unsqueeze(input, 1)
  61. x_per = x.permute(0, 3, 2, 1) # B D T C
  62. y_left = torch.cat((cache, x_per), dim=2)
  63. cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
  64. y_left = self.conv_left(y_left)
  65. out = x_per + y_left
  66. if self.conv_right is not None:
  67. # maybe need to check
  68. y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
  69. y_right = y_right[:, :, self.rstride:, :]
  70. y_right = self.conv_right(y_right)
  71. out += y_right
  72. out_per = out.permute(0, 3, 2, 1)
  73. output = out_per.squeeze(1)
  74. return output, cache
  75. class BasicBlock(nn.Sequential):
  76. def __init__(self,
  77. linear_dim: int,
  78. proj_dim: int,
  79. lorder: int,
  80. rorder: int,
  81. lstride: int,
  82. rstride: int,
  83. stack_layer: int
  84. ):
  85. super(BasicBlock, self).__init__()
  86. self.lorder = lorder
  87. self.rorder = rorder
  88. self.lstride = lstride
  89. self.rstride = rstride
  90. self.stack_layer = stack_layer
  91. self.linear = LinearTransform(linear_dim, proj_dim)
  92. self.fsmn_block = FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride)
  93. self.affine = AffineTransform(proj_dim, linear_dim)
  94. self.relu = RectifiedLinear(linear_dim, linear_dim)
  95. def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
  96. x1 = self.linear(input) # B T D
  97. cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
  98. if cache_layer_name not in in_cache:
  99. in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
  100. x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
  101. x3 = self.affine(x2)
  102. x4 = self.relu(x3)
  103. return x4
  104. class FsmnStack(nn.Sequential):
  105. def __init__(self, *args):
  106. super(FsmnStack, self).__init__(*args)
  107. def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
  108. x = input
  109. for module in self._modules.values():
  110. x = module(x, in_cache)
  111. return x
'''
FSMN net for keyword spotting.

input_dim: input dimension
linear_dim: fsmn input dimension
proj_dim: fsmn projection dimension
lorder: fsmn left order
rorder: fsmn right order
num_syn: output dimension
fsmn_layers: number of sequential fsmn layers
'''
  122. class FSMN(nn.Module):
  123. def __init__(
  124. self,
  125. input_dim: int,
  126. input_affine_dim: int,
  127. fsmn_layers: int,
  128. linear_dim: int,
  129. proj_dim: int,
  130. lorder: int,
  131. rorder: int,
  132. lstride: int,
  133. rstride: int,
  134. output_affine_dim: int,
  135. output_dim: int
  136. ):
  137. super(FSMN, self).__init__()
  138. self.input_dim = input_dim
  139. self.input_affine_dim = input_affine_dim
  140. self.fsmn_layers = fsmn_layers
  141. self.linear_dim = linear_dim
  142. self.proj_dim = proj_dim
  143. self.output_affine_dim = output_affine_dim
  144. self.output_dim = output_dim
  145. self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
  146. self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
  147. self.relu = RectifiedLinear(linear_dim, linear_dim)
  148. self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
  149. range(fsmn_layers)])
  150. self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
  151. self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
  152. self.softmax = nn.Softmax(dim=-1)
  153. def fuse_modules(self):
  154. pass
  155. def forward(
  156. self,
  157. input: torch.Tensor,
  158. in_cache: Dict[str, torch.Tensor]
  159. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
  160. """
  161. Args:
  162. input (torch.Tensor): Input tensor (B, T, D)
  163. in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
  164. {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
  165. """
  166. x1 = self.in_linear1(input)
  167. x2 = self.in_linear2(x1)
  168. x3 = self.relu(x2)
  169. x4 = self.fsmn(x3, in_cache) # self.in_cache will update automatically in self.fsmn
  170. x5 = self.out_linear1(x4)
  171. x6 = self.out_linear2(x5)
  172. x7 = self.softmax(x6)
  173. return x7
'''
One deep FSMN (DFSMN) layer.

dimproj: projection dimension, input and output dimension of memory blocks
dimlinear: dimension of mapping layer
lorder: left order
rorder: right order
lstride: left stride
rstride: right stride
'''
  183. class DFSMN(nn.Module):
  184. def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
  185. super(DFSMN, self).__init__()
  186. self.lorder = lorder
  187. self.rorder = rorder
  188. self.lstride = lstride
  189. self.rstride = rstride
  190. self.expand = AffineTransform(dimproj, dimlinear)
  191. self.shrink = LinearTransform(dimlinear, dimproj)
  192. self.conv_left = nn.Conv2d(
  193. dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
  194. if rorder > 0:
  195. self.conv_right = nn.Conv2d(
  196. dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
  197. else:
  198. self.conv_right = None
  199. def forward(self, input):
  200. f1 = F.relu(self.expand(input))
  201. p1 = self.shrink(f1)
  202. x = torch.unsqueeze(p1, 1)
  203. x_per = x.permute(0, 3, 2, 1)
  204. y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
  205. if self.conv_right is not None:
  206. y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
  207. y_right = y_right[:, :, self.rstride:, :]
  208. out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
  209. else:
  210. out = x_per + self.conv_left(y_left)
  211. out1 = out.permute(0, 3, 2, 1)
  212. output = input + out1.squeeze(1)
  213. return output
'''
Build stacked DFSMN layers.
'''
  217. def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
  218. repeats = [
  219. nn.Sequential(
  220. DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
  221. for i in range(fsmn_layers)
  222. ]
  223. return nn.Sequential(*repeats)
  224. if __name__ == '__main__':
  225. fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
  226. print(fsmn)
  227. num_params = sum(p.numel() for p in fsmn.parameters())
  228. print('the number of model params: {}'.format(num_params))
  229. x = torch.zeros(128, 200, 400) # batch-size * time * dim
  230. y, _ = fsmn(x) # batch-size * time * dim
  231. print('input shape: {}'.format(x.shape))
  232. print('output shape: {}'.format(y.shape))
  233. print(fsmn.to_kaldi_net())