# fsmn_encoder.py
  1. from typing import Tuple, Dict
  2. import copy
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. class LinearTransform(nn.Module):
  8. def __init__(self, input_dim, output_dim):
  9. super(LinearTransform, self).__init__()
  10. self.input_dim = input_dim
  11. self.output_dim = output_dim
  12. self.linear = nn.Linear(input_dim, output_dim, bias=False)
  13. def forward(self, input):
  14. output = self.linear(input)
  15. return output
  16. class AffineTransform(nn.Module):
  17. def __init__(self, input_dim, output_dim):
  18. super(AffineTransform, self).__init__()
  19. self.input_dim = input_dim
  20. self.output_dim = output_dim
  21. self.linear = nn.Linear(input_dim, output_dim)
  22. def forward(self, input):
  23. output = self.linear(input)
  24. return output
  25. class RectifiedLinear(nn.Module):
  26. def __init__(self, input_dim, output_dim):
  27. super(RectifiedLinear, self).__init__()
  28. self.dim = input_dim
  29. self.relu = nn.ReLU()
  30. self.dropout = nn.Dropout(0.1)
  31. def forward(self, input):
  32. out = self.relu(input)
  33. return out
  34. class FSMNBlock(nn.Module):
  35. def __init__(
  36. self,
  37. input_dim: int,
  38. output_dim: int,
  39. lorder=None,
  40. rorder=None,
  41. lstride=1,
  42. rstride=1,
  43. ):
  44. super(FSMNBlock, self).__init__()
  45. self.dim = input_dim
  46. if lorder is None:
  47. return
  48. self.lorder = lorder
  49. self.rorder = rorder
  50. self.lstride = lstride
  51. self.rstride = rstride
  52. self.conv_left = nn.Conv2d(
  53. self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
  54. if self.rorder > 0:
  55. self.conv_right = nn.Conv2d(
  56. self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
  57. else:
  58. self.conv_right = None
  59. def forward(self, input: torch.Tensor, in_cache=None):
  60. x = torch.unsqueeze(input, 1)
  61. x_per = x.permute(0, 3, 2, 1) # B D T C
  62. if in_cache is None: # offline
  63. y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
  64. else:
  65. y_left = torch.cat((in_cache, x_per), dim=2)
  66. in_cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
  67. y_left = self.conv_left(y_left)
  68. out = x_per + y_left
  69. if self.conv_right is not None:
  70. # maybe need to check
  71. y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
  72. y_right = y_right[:, :, self.rstride:, :]
  73. y_right = self.conv_right(y_right)
  74. out += y_right
  75. out_per = out.permute(0, 3, 2, 1)
  76. output = out_per.squeeze(1)
  77. return output, in_cache
  78. class BasicBlock(nn.Sequential):
  79. def __init__(self,
  80. linear_dim: int,
  81. proj_dim: int,
  82. lorder: int,
  83. rorder: int,
  84. lstride: int,
  85. rstride: int,
  86. stack_layer: int
  87. ):
  88. super(BasicBlock, self).__init__()
  89. self.lorder = lorder
  90. self.rorder = rorder
  91. self.lstride = lstride
  92. self.rstride = rstride
  93. self.stack_layer = stack_layer
  94. self.linear = LinearTransform(linear_dim, proj_dim)
  95. self.fsmn_block = FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride)
  96. self.affine = AffineTransform(proj_dim, linear_dim)
  97. self.relu = RectifiedLinear(linear_dim, linear_dim)
  98. def forward(self, input: torch.Tensor, in_cache=None):
  99. x1 = self.linear(input) # B T D
  100. if in_cache is not None: # Dict[str, tensor.Tensor]
  101. cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
  102. if cache_layer_name not in in_cache:
  103. in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
  104. x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
  105. else:
  106. x2, _ = self.fsmn_block(x1)
  107. x3 = self.affine(x2)
  108. x4 = self.relu(x3)
  109. return x4, in_cache
  110. class FsmnStack(nn.Sequential):
  111. def __init__(self, *args):
  112. super(FsmnStack, self).__init__(*args)
  113. def forward(self, input: torch.Tensor, in_cache=None):
  114. x = input
  115. for module in self._modules.values():
  116. x, in_cache = module(x, in_cache)
  117. return x
  118. '''
  119. FSMN net for keyword spotting
  120. input_dim: input dimension
  121. linear_dim: fsmn input dimension
  122. proj_dim: fsmn projection dimension
  123. lorder: fsmn left order
  124. rorder: fsmn right order
  125. num_syn: output dimension
  126. fsmn_layers: no. of sequential fsmn layers
  127. '''
  128. class FSMN(nn.Module):
  129. def __init__(
  130. self,
  131. input_dim: int,
  132. input_affine_dim: int,
  133. fsmn_layers: int,
  134. linear_dim: int,
  135. proj_dim: int,
  136. lorder: int,
  137. rorder: int,
  138. lstride: int,
  139. rstride: int,
  140. output_affine_dim: int,
  141. output_dim: int,
  142. streaming=False
  143. ):
  144. super(FSMN, self).__init__()
  145. self.input_dim = input_dim
  146. self.input_affine_dim = input_affine_dim
  147. self.fsmn_layers = fsmn_layers
  148. self.linear_dim = linear_dim
  149. self.proj_dim = proj_dim
  150. self.output_affine_dim = output_affine_dim
  151. self.output_dim = output_dim
  152. self.in_cache_original = dict() if streaming else None
  153. self.in_cache = copy.deepcopy(self.in_cache_original)
  154. self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
  155. self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
  156. self.relu = RectifiedLinear(linear_dim, linear_dim)
  157. self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
  158. range(fsmn_layers)])
  159. self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
  160. self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
  161. self.softmax = nn.Softmax(dim=-1)
  162. def fuse_modules(self):
  163. pass
  164. def cache_reset(self):
  165. self.in_cache = copy.deepcopy(self.in_cache_original)
  166. def forward(
  167. self,
  168. input: torch.Tensor,
  169. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
  170. """
  171. Args:
  172. input (torch.Tensor): Input tensor (B, T, D)
  173. in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
  174. {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
  175. """
  176. x1 = self.in_linear1(input)
  177. x2 = self.in_linear2(x1)
  178. x3 = self.relu(x2)
  179. x4 = self.fsmn(x3, self.in_cache) # if in_cache is not None, self.fsmn is streaming's format, it will update automatically in self.fsmn
  180. x5 = self.out_linear1(x4)
  181. x6 = self.out_linear2(x5)
  182. x7 = self.softmax(x6)
  183. return x7
  184. '''
  185. one deep fsmn layer
  186. dimproj: projection dimension, input and output dimension of memory blocks
  187. dimlinear: dimension of mapping layer
  188. lorder: left order
  189. rorder: right order
  190. lstride: left stride
  191. rstride: right stride
  192. '''
  193. class DFSMN(nn.Module):
  194. def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
  195. super(DFSMN, self).__init__()
  196. self.lorder = lorder
  197. self.rorder = rorder
  198. self.lstride = lstride
  199. self.rstride = rstride
  200. self.expand = AffineTransform(dimproj, dimlinear)
  201. self.shrink = LinearTransform(dimlinear, dimproj)
  202. self.conv_left = nn.Conv2d(
  203. dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
  204. if rorder > 0:
  205. self.conv_right = nn.Conv2d(
  206. dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
  207. else:
  208. self.conv_right = None
  209. def forward(self, input):
  210. f1 = F.relu(self.expand(input))
  211. p1 = self.shrink(f1)
  212. x = torch.unsqueeze(p1, 1)
  213. x_per = x.permute(0, 3, 2, 1)
  214. y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
  215. if self.conv_right is not None:
  216. y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
  217. y_right = y_right[:, :, self.rstride:, :]
  218. out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
  219. else:
  220. out = x_per + self.conv_left(y_left)
  221. out1 = out.permute(0, 3, 2, 1)
  222. output = input + out1.squeeze(1)
  223. return output
  224. '''
  225. build stacked dfsmn layers
  226. '''
  227. def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
  228. repeats = [
  229. nn.Sequential(
  230. DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
  231. for i in range(fsmn_layers)
  232. ]
  233. return nn.Sequential(*repeats)
  234. if __name__ == '__main__':
  235. fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
  236. print(fsmn)
  237. num_params = sum(p.numel() for p in fsmn.parameters())
  238. print('the number of model params: {}'.format(num_params))
  239. x = torch.zeros(128, 200, 400) # batch-size * time * dim
  240. y, _ = fsmn(x) # batch-size * time * dim
  241. print('input shape: {}'.format(x.shape))
  242. print('output shape: {}'.format(y.shape))
  243. print(fsmn.to_kaldi_net())