# fsmn_encoder.py — export-oriented FSMN modules (keyword spotting).
from typing import Tuple, Dict
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr.models.encoder.fsmn_encoder import BasicBlock
  8. class LinearTransform(nn.Module):
  9. def __init__(self, input_dim, output_dim):
  10. super(LinearTransform, self).__init__()
  11. self.input_dim = input_dim
  12. self.output_dim = output_dim
  13. self.linear = nn.Linear(input_dim, output_dim, bias=False)
  14. def forward(self, input):
  15. output = self.linear(input)
  16. return output
  17. class AffineTransform(nn.Module):
  18. def __init__(self, input_dim, output_dim):
  19. super(AffineTransform, self).__init__()
  20. self.input_dim = input_dim
  21. self.output_dim = output_dim
  22. self.linear = nn.Linear(input_dim, output_dim)
  23. def forward(self, input):
  24. output = self.linear(input)
  25. return output
  26. class RectifiedLinear(nn.Module):
  27. def __init__(self, input_dim, output_dim):
  28. super(RectifiedLinear, self).__init__()
  29. self.dim = input_dim
  30. self.relu = nn.ReLU()
  31. self.dropout = nn.Dropout(0.1)
  32. def forward(self, input):
  33. out = self.relu(input)
  34. return out
  35. class FSMNBlock(nn.Module):
  36. def __init__(
  37. self,
  38. input_dim: int,
  39. output_dim: int,
  40. lorder=None,
  41. rorder=None,
  42. lstride=1,
  43. rstride=1,
  44. ):
  45. super(FSMNBlock, self).__init__()
  46. self.dim = input_dim
  47. if lorder is None:
  48. return
  49. self.lorder = lorder
  50. self.rorder = rorder
  51. self.lstride = lstride
  52. self.rstride = rstride
  53. self.conv_left = nn.Conv2d(
  54. self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
  55. if self.rorder > 0:
  56. self.conv_right = nn.Conv2d(
  57. self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
  58. else:
  59. self.conv_right = None
  60. def forward(self, input: torch.Tensor, cache: torch.Tensor):
  61. x = torch.unsqueeze(input, 1)
  62. x_per = x.permute(0, 3, 2, 1) # B D T C
  63. cache = cache.to(x_per.device)
  64. y_left = torch.cat((cache, x_per), dim=2)
  65. cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
  66. y_left = self.conv_left(y_left)
  67. out = x_per + y_left
  68. if self.conv_right is not None:
  69. # maybe need to check
  70. y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
  71. y_right = y_right[:, :, self.rstride:, :]
  72. y_right = self.conv_right(y_right)
  73. out += y_right
  74. out_per = out.permute(0, 3, 2, 1)
  75. output = out_per.squeeze(1)
  76. return output, cache
  77. class BasicBlock_export(nn.Module):
  78. def __init__(self,
  79. model,
  80. ):
  81. super(BasicBlock_export, self).__init__()
  82. self.linear = model.linear
  83. self.fsmn_block = model.fsmn_block
  84. self.affine = model.affine
  85. self.relu = model.relu
  86. def forward(self, input: torch.Tensor, in_cache: torch.Tensor):
  87. x = self.linear(input) # B T D
  88. # cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
  89. # if cache_layer_name not in in_cache:
  90. # in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
  91. x, out_cache = self.fsmn_block(x, in_cache)
  92. x = self.affine(x)
  93. x = self.relu(x)
  94. return x, out_cache
# class FsmnStack(nn.Sequential):
#     def __init__(self, *args):
#         super(FsmnStack, self).__init__(*args)
#
#     def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
#         x = input
#         for module in self._modules.values():
#             x = module(x, in_cache)
#         return x
'''
FSMN net for keyword spotting (export wrapper around a trained model).
The wrapped model was built with:
input_dim: input dimension
linear_dim: fsmn input dimension
proj_dim: fsmn projection dimension
lorder: fsmn left order
rorder: fsmn right order
num_syn: output dimension
fsmn_layers: no. of sequential fsmn layers
'''
  114. class FSMN(nn.Module):
  115. def __init__(
  116. self, model,
  117. ):
  118. super(FSMN, self).__init__()
  119. # self.input_dim = input_dim
  120. # self.input_affine_dim = input_affine_dim
  121. # self.fsmn_layers = fsmn_layers
  122. # self.linear_dim = linear_dim
  123. # self.proj_dim = proj_dim
  124. # self.output_affine_dim = output_affine_dim
  125. # self.output_dim = output_dim
  126. #
  127. # self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
  128. # self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
  129. # self.relu = RectifiedLinear(linear_dim, linear_dim)
  130. # self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
  131. # range(fsmn_layers)])
  132. # self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
  133. # self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
  134. # self.softmax = nn.Softmax(dim=-1)
  135. self.in_linear1 = model.in_linear1
  136. self.in_linear2 = model.in_linear2
  137. self.relu = model.relu
  138. # self.fsmn = model.fsmn
  139. self.out_linear1 = model.out_linear1
  140. self.out_linear2 = model.out_linear2
  141. self.softmax = model.softmax
  142. self.fsmn = model.fsmn
  143. for i, d in enumerate(model.fsmn):
  144. if isinstance(d, BasicBlock):
  145. self.fsmn[i] = BasicBlock_export(d)
  146. def fuse_modules(self):
  147. pass
  148. def forward(
  149. self,
  150. input: torch.Tensor,
  151. *args,
  152. ):
  153. """
  154. Args:
  155. input (torch.Tensor): Input tensor (B, T, D)
  156. in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
  157. {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
  158. """
  159. x = self.in_linear1(input)
  160. x = self.in_linear2(x)
  161. x = self.relu(x)
  162. # x4 = self.fsmn(x3, in_cache) # self.in_cache will update automatically in self.fsmn
  163. out_caches = list()
  164. for i, d in enumerate(self.fsmn):
  165. in_cache = args[i]
  166. x, out_cache = d(x, in_cache)
  167. out_caches.append(out_cache)
  168. x = self.out_linear1(x)
  169. x = self.out_linear2(x)
  170. x = self.softmax(x)
  171. return x, out_caches
'''
one deep fsmn layer
dimproj: projection dimension, input and output dimension of memory blocks
dimlinear: dimension of mapping layer
lorder: left order
rorder: right order
lstride: left stride
rstride: right stride
'''
  181. class DFSMN(nn.Module):
  182. def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
  183. super(DFSMN, self).__init__()
  184. self.lorder = lorder
  185. self.rorder = rorder
  186. self.lstride = lstride
  187. self.rstride = rstride
  188. self.expand = AffineTransform(dimproj, dimlinear)
  189. self.shrink = LinearTransform(dimlinear, dimproj)
  190. self.conv_left = nn.Conv2d(
  191. dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
  192. if rorder > 0:
  193. self.conv_right = nn.Conv2d(
  194. dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
  195. else:
  196. self.conv_right = None
  197. def forward(self, input):
  198. f1 = F.relu(self.expand(input))
  199. p1 = self.shrink(f1)
  200. x = torch.unsqueeze(p1, 1)
  201. x_per = x.permute(0, 3, 2, 1)
  202. y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
  203. if self.conv_right is not None:
  204. y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
  205. y_right = y_right[:, :, self.rstride:, :]
  206. out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
  207. else:
  208. out = x_per + self.conv_left(y_left)
  209. out1 = out.permute(0, 3, 2, 1)
  210. output = input + out1.squeeze(1)
  211. return output
'''
build stacked dfsmn layers
'''
  215. def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
  216. repeats = [
  217. nn.Sequential(
  218. DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
  219. for i in range(fsmn_layers)
  220. ]
  221. return nn.Sequential(*repeats)
  222. if __name__ == '__main__':
  223. fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
  224. print(fsmn)
  225. num_params = sum(p.numel() for p in fsmn.parameters())
  226. print('the number of model params: {}'.format(num_params))
  227. x = torch.zeros(128, 200, 400) # batch-size * time * dim
  228. y, _ = fsmn(x) # batch-size * time * dim
  229. print('input shape: {}'.format(x.shape))
  230. print('output shape: {}'.format(y.shape))
  231. print(fsmn.to_kaldi_net())