encoder.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. from typing import Tuple, Dict
  2. import copy
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from funasr.register import tables
  8. class LinearTransform(nn.Module):
  9. def __init__(self, input_dim, output_dim):
  10. super(LinearTransform, self).__init__()
  11. self.input_dim = input_dim
  12. self.output_dim = output_dim
  13. self.linear = nn.Linear(input_dim, output_dim, bias=False)
  14. def forward(self, input):
  15. output = self.linear(input)
  16. return output
  17. class AffineTransform(nn.Module):
  18. def __init__(self, input_dim, output_dim):
  19. super(AffineTransform, self).__init__()
  20. self.input_dim = input_dim
  21. self.output_dim = output_dim
  22. self.linear = nn.Linear(input_dim, output_dim)
  23. def forward(self, input):
  24. output = self.linear(input)
  25. return output
  26. class RectifiedLinear(nn.Module):
  27. def __init__(self, input_dim, output_dim):
  28. super(RectifiedLinear, self).__init__()
  29. self.dim = input_dim
  30. self.relu = nn.ReLU()
  31. self.dropout = nn.Dropout(0.1)
  32. def forward(self, input):
  33. out = self.relu(input)
  34. return out
class FSMNBlock(nn.Module):
    """Single FSMN memory block.

    Depthwise 2-D convolutions over the time axis implement the left
    (history) and optional right (lookahead) memories, which are added to
    the input as a residual. Supports streaming via an explicit cache of
    past frames.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        lorder=None,
        rorder=None,
        lstride=1,
        rstride=1,
    ):
        super(FSMNBlock, self).__init__()
        self.dim = input_dim
        # NOTE(review): with lorder=None the block is left half-initialized
        # (no lorder/conv_left attributes) and forward() would raise —
        # callers are expected to always supply lorder.
        if lorder is None:
            return
        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride
        # Depthwise (groups == channels) conv with kernel [lorder, 1]:
        # each feature channel gets its own temporal filter.
        self.conv_left = nn.Conv2d(
            self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
        if self.rorder > 0:
            self.conv_right = nn.Conv2d(
                self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
        else:
            self.conv_right = None

    def forward(self, input: torch.Tensor, cache: torch.Tensor):
        """Apply the memory block to one (possibly streaming) chunk.

        Args:
            input: assumed (B, T, D) — TODO confirm against BasicBlock caller.
            cache: past frames, (B, D, (lorder-1)*lstride, 1); zeros on the
                first chunk.

        Returns:
            (output, new_cache): output has input's shape; new_cache holds
            the trailing history frames for the next chunk.
        """
        x = torch.unsqueeze(input, 1)
        x_per = x.permute(0, 3, 2, 1)  # B D T C
        cache = cache.to(x_per.device)
        # Prepend cached history so the left conv sees full left context.
        y_left = torch.cat((cache, x_per), dim=2)
        # Keep the last (lorder-1)*lstride frames as next chunk's history.
        # NOTE(review): for lorder == 1 this slice is [-0:] i.e. the whole
        # tensor, so the cache would grow — presumably lorder > 1 in practice.
        cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
        y_left = self.conv_left(y_left)
        out = x_per + y_left
        if self.conv_right is not None:
            # maybe need to check
            # Right memory: pad future frames, then drop the first rstride
            # frames so the lookahead window excludes the current frame.
            y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
            y_right = y_right[:, :, self.rstride:, :]
            y_right = self.conv_right(y_right)
            out += y_right
        out_per = out.permute(0, 3, 2, 1)
        output = out_per.squeeze(1)
        return output, cache
  77. class BasicBlock(nn.Module):
  78. def __init__(self,
  79. linear_dim: int,
  80. proj_dim: int,
  81. lorder: int,
  82. rorder: int,
  83. lstride: int,
  84. rstride: int,
  85. stack_layer: int
  86. ):
  87. super(BasicBlock, self).__init__()
  88. self.lorder = lorder
  89. self.rorder = rorder
  90. self.lstride = lstride
  91. self.rstride = rstride
  92. self.stack_layer = stack_layer
  93. self.linear = LinearTransform(linear_dim, proj_dim)
  94. self.fsmn_block = FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride)
  95. self.affine = AffineTransform(proj_dim, linear_dim)
  96. self.relu = RectifiedLinear(linear_dim, linear_dim)
  97. def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
  98. x1 = self.linear(input) # B T D
  99. cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
  100. if cache_layer_name not in cache:
  101. cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
  102. x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name])
  103. x3 = self.affine(x2)
  104. x4 = self.relu(x3)
  105. return x4
  106. class FsmnStack(nn.Sequential):
  107. def __init__(self, *args):
  108. super(FsmnStack, self).__init__(*args)
  109. def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
  110. x = input
  111. for module in self._modules.values():
  112. x = module(x, cache)
  113. return x
'''
FSMN net for keyword spotting
input_dim: input dimension
linear_dim: fsmn input dimension
proj_dim: fsmn projection dimension
lorder: fsmn left order
rorder: fsmn right order
num_syn: output dimension
fsmn_layers: no. of sequential fsmn layers
'''
  124. @tables.register("encoder_classes", "FSMN")
  125. class FSMN(nn.Module):
  126. def __init__(
  127. self,
  128. input_dim: int,
  129. input_affine_dim: int,
  130. fsmn_layers: int,
  131. linear_dim: int,
  132. proj_dim: int,
  133. lorder: int,
  134. rorder: int,
  135. lstride: int,
  136. rstride: int,
  137. output_affine_dim: int,
  138. output_dim: int
  139. ):
  140. super(FSMN, self).__init__()
  141. self.input_dim = input_dim
  142. self.input_affine_dim = input_affine_dim
  143. self.fsmn_layers = fsmn_layers
  144. self.linear_dim = linear_dim
  145. self.proj_dim = proj_dim
  146. self.output_affine_dim = output_affine_dim
  147. self.output_dim = output_dim
  148. self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
  149. self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
  150. self.relu = RectifiedLinear(linear_dim, linear_dim)
  151. self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
  152. range(fsmn_layers)])
  153. self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
  154. self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
  155. self.softmax = nn.Softmax(dim=-1)
  156. def fuse_modules(self):
  157. pass
  158. def forward(
  159. self,
  160. input: torch.Tensor,
  161. cache: Dict[str, torch.Tensor]
  162. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
  163. """
  164. Args:
  165. input (torch.Tensor): Input tensor (B, T, D)
  166. cache: when cache is not None, the forward is in streaming. The type of cache is a dict, egs,
  167. {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
  168. """
  169. x1 = self.in_linear1(input)
  170. x2 = self.in_linear2(x1)
  171. x3 = self.relu(x2)
  172. x4 = self.fsmn(x3, cache) # self.cache will update automatically in self.fsmn
  173. x5 = self.out_linear1(x4)
  174. x6 = self.out_linear2(x5)
  175. x7 = self.softmax(x6)
  176. return x7
'''
One deep FSMN (DFSMN) layer.
dimproj: projection dimension, input and output dimension of memory blocks
dimlinear: dimension of mapping layer
lorder: left order
rorder: right order
lstride: left stride
rstride: right stride
'''
  186. @tables.register("encoder_classes", "DFSMN")
  187. class DFSMN(nn.Module):
  188. def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
  189. super(DFSMN, self).__init__()
  190. self.lorder = lorder
  191. self.rorder = rorder
  192. self.lstride = lstride
  193. self.rstride = rstride
  194. self.expand = AffineTransform(dimproj, dimlinear)
  195. self.shrink = LinearTransform(dimlinear, dimproj)
  196. self.conv_left = nn.Conv2d(
  197. dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
  198. if rorder > 0:
  199. self.conv_right = nn.Conv2d(
  200. dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
  201. else:
  202. self.conv_right = None
  203. def forward(self, input):
  204. f1 = F.relu(self.expand(input))
  205. p1 = self.shrink(f1)
  206. x = torch.unsqueeze(p1, 1)
  207. x_per = x.permute(0, 3, 2, 1)
  208. y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
  209. if self.conv_right is not None:
  210. y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
  211. y_right = y_right[:, :, self.rstride:, :]
  212. out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
  213. else:
  214. out = x_per + self.conv_left(y_left)
  215. out1 = out.permute(0, 3, 2, 1)
  216. output = input + out1.squeeze(1)
  217. return output
'''
Build a stack of sequential DFSMN layers.
'''
  221. def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
  222. repeats = [
  223. nn.Sequential(
  224. DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
  225. for i in range(fsmn_layers)
  226. ]
  227. return nn.Sequential(*repeats)
  228. if __name__ == '__main__':
  229. fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
  230. print(fsmn)
  231. num_params = sum(p.numel() for p in fsmn.parameters())
  232. print('the number of model params: {}'.format(num_params))
  233. x = torch.zeros(128, 200, 400) # batch-size * time * dim
  234. y, _ = fsmn(x) # batch-size * time * dim
  235. print('input shape: {}'.format(x.shape))
  236. print('output shape: {}'.format(y.shape))
  237. print(fsmn.to_kaldi_net())