# mossformer_encoder.py
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from rotary_embedding_torch import RotaryEmbedding
  5. from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
  6. from funasr.modules.embedding import ScaledSinuEmbedding
  7. from funasr.modules.mossformer import FLASH_ShareA_FFConvM
  8. def select_norm(norm, dim, shape):
  9. """Just a wrapper to select the normalization type.
  10. """
  11. if norm == "gln":
  12. return GlobalLayerNorm(dim, shape, elementwise_affine=True)
  13. if norm == "cln":
  14. return CumulativeLayerNorm(dim, elementwise_affine=True)
  15. if norm == "ln":
  16. return nn.GroupNorm(1, dim, eps=1e-8)
  17. else:
  18. return nn.BatchNorm1d(dim)
  19. class MossformerBlock(nn.Module):
  20. def __init__(
  21. self,
  22. *,
  23. dim,
  24. depth,
  25. group_size = 256,
  26. query_key_dim = 128,
  27. expansion_factor = 4.,
  28. causal = False,
  29. attn_dropout = 0.1,
  30. norm_type = 'scalenorm',
  31. shift_tokens = True
  32. ):
  33. super().__init__()
  34. assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'
  35. if norm_type == 'scalenorm':
  36. norm_klass = ScaleNorm
  37. elif norm_type == 'layernorm':
  38. norm_klass = nn.LayerNorm
  39. self.group_size = group_size
  40. rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))
  41. # max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
  42. self.layers = nn.ModuleList([FLASH_ShareA_FFConvM(dim = dim, group_size = group_size, query_key_dim = query_key_dim, expansion_factor = expansion_factor, causal = causal, dropout = attn_dropout, rotary_pos_emb = rotary_pos_emb, norm_klass = norm_klass, shift_tokens = shift_tokens) for _ in range(depth)])
  43. def forward(
  44. self,
  45. x,
  46. *,
  47. mask = None
  48. ):
  49. ii = 0
  50. for flash in self.layers:
  51. x = flash(x, mask = mask)
  52. ii = ii + 1
  53. return x
class MossFormer_MaskNet(nn.Module):
    """The MossFormer module for computing output masks.

    Arguments
    ---------
    in_channels : int
        Number of channels at the output of the encoder.
    out_channels : int
        Number of channels that would be inputted to the intra and inter blocks.
    num_blocks : int
        Number of layers of Dual Computation Block.
    norm : str
        Normalization type.
    num_spks : int
        Number of sources (speakers).
    skip_around_intra : bool
        Skip connection around intra.
    use_global_pos_enc : bool
        Global positional encodings.
    max_length : int
        Maximum sequence length.

    Example
    ---------
    >>> mossformer_block = MossFormerM(1, 64, 8)
    >>> mossformer_masknet = MossFormer_MaskNet(64, 64, intra_block, num_spks=2)
    >>> x = torch.randn(10, 64, 2000)
    >>> x = mossformer_masknet(x)
    >>> x.shape
    torch.Size([2, 10, 64, 2000])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_blocks=24,
        norm="ln",
        num_spks=2,
        skip_around_intra=True,
        use_global_pos_enc=True,
        max_length=20000,
    ):
        super(MossFormer_MaskNet, self).__init__()
        self.num_spks = num_spks
        self.num_blocks = num_blocks
        # Normalizes the raw encoder output over its channel dimension.
        self.norm = select_norm(norm, in_channels, 3)
        # 1x1 conv projecting encoder channels to the model width.
        self.conv1d_encoder = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        self.use_global_pos_enc = use_global_pos_enc
        if self.use_global_pos_enc:
            self.pos_enc = ScaledSinuEmbedding(out_channels)
        # Stack of MossFormer computation blocks (intra processing).
        self.mdl = Computation_Block(
            num_blocks,
            out_channels,
            norm,
            skip_around_intra=skip_around_intra,
        )
        # Widens channels by num_spks so each speaker gets its own slice.
        self.conv1d_out = nn.Conv1d(
            out_channels, out_channels * num_spks, kernel_size=1
        )
        # Maps back to the encoder channel count for mask application.
        self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False)
        self.prelu = nn.PReLU()
        self.activation = nn.ReLU()
        # gated output layer: tanh branch modulated by a sigmoid gate
        self.output = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
        )
        self.output_gate = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
        )

    def forward(self, x):
        """Returns the output tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of dimension [B, N, S].

        Returns
        -------
        out : torch.Tensor
            Output tensor of dimension [spks, B, N, S]
            where, spks = Number of speakers
               B = Batchsize,
               N = number of filters
               S = the number of time frames
        """
        # before each line we indicate the shape after executing the line
        # [B, N, L]
        x = self.norm(x)
        # [B, N, L]
        x = self.conv1d_encoder(x)
        if self.use_global_pos_enc:
            #x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
            #    x.size(1) ** 0.5)
            base = x
            x = x.transpose(1, -1)
            emb = self.pos_enc(x)
            # NOTE(review): transpose(0, -1) swaps the first and last axes of
            # the embedding; presumably ScaledSinuEmbedding returns a shape
            # that then broadcasts against `base` in the addition below —
            # confirm its output shape before relying on this.
            emb = emb.transpose(0, -1)
            #print('base: {}, emb: {}'.format(base.shape, emb.shape))
            x = base + emb
        # [B, N, S]
        #for i in range(self.num_modules):
        #    x = self.dual_mdl[i](x)
        x = self.mdl(x)
        x = self.prelu(x)
        # [B, N*spks, S]
        x = self.conv1d_out(x)
        B, _, S = x.shape
        # [B*spks, N, S] — fold the speaker axis into the batch so the gated
        # output layer and decoder process every speaker identically.
        x = x.view(B * self.num_spks, -1, S)
        # [B*spks, N, S]
        x = self.output(x) * self.output_gate(x)
        # [B*spks, N, S]
        x = self.conv1_decoder(x)
        # [B, spks, N, S]
        _, N, L = x.shape
        x = x.view(B, self.num_spks, N, L)
        x = self.activation(x)
        # [spks, B, N, S]
        x = x.transpose(0, 1)
        return x
  171. class MossFormerEncoder(nn.Module):
  172. """Convolutional Encoder Layer.
  173. Arguments
  174. ---------
  175. kernel_size : int
  176. Length of filters.
  177. in_channels : int
  178. Number of input channels.
  179. out_channels : int
  180. Number of output channels.
  181. Example
  182. -------
  183. >>> x = torch.randn(2, 1000)
  184. >>> encoder = Encoder(kernel_size=4, out_channels=64)
  185. >>> h = encoder(x)
  186. >>> h.shape
  187. torch.Size([2, 64, 499])
  188. """
  189. def __init__(self, kernel_size=2, out_channels=64, in_channels=1):
  190. super(MossFormerEncoder, self).__init__()
  191. self.conv1d = nn.Conv1d(
  192. in_channels=in_channels,
  193. out_channels=out_channels,
  194. kernel_size=kernel_size,
  195. stride=kernel_size // 2,
  196. groups=1,
  197. bias=False,
  198. )
  199. self.in_channels = in_channels
  200. def forward(self, x):
  201. """Return the encoded output.
  202. Arguments
  203. ---------
  204. x : torch.Tensor
  205. Input tensor with dimensionality [B, L].
  206. Return
  207. ------
  208. x : torch.Tensor
  209. Encoded tensor with dimensionality [B, N, T_out].
  210. where B = Batchsize
  211. L = Number of timepoints
  212. N = Number of filters
  213. T_out = Number of timepoints at the output of the encoder
  214. """
  215. # B x L -> B x 1 x L
  216. if self.in_channels == 1:
  217. x = torch.unsqueeze(x, dim=1)
  218. # B x 1 x L -> B x N x T_out
  219. x = self.conv1d(x)
  220. x = F.relu(x)
  221. return x
  222. class MossFormerM(nn.Module):
  223. """This class implements the transformer encoder.
  224. Arguments
  225. ---------
  226. num_blocks : int
  227. Number of mossformer blocks to include.
  228. d_model : int
  229. The dimension of the input embedding.
  230. attn_dropout : float
  231. Dropout for the self-attention (Optional).
  232. group_size: int
  233. the chunk size
  234. query_key_dim: int
  235. the attention vector dimension
  236. expansion_factor: int
  237. the expansion factor for the linear projection in conv module
  238. causal: bool
  239. true for causal / false for non causal
  240. Example
  241. -------
  242. >>> import torch
  243. >>> x = torch.rand((8, 60, 512))
  244. >>> net = TransformerEncoder_MossFormerM(num_blocks=8, d_model=512)
  245. >>> output, _ = net(x)
  246. >>> output.shape
  247. torch.Size([8, 60, 512])
  248. """
  249. def __init__(
  250. self,
  251. num_blocks,
  252. d_model=None,
  253. causal=False,
  254. group_size = 256,
  255. query_key_dim = 128,
  256. expansion_factor = 4.,
  257. attn_dropout = 0.1
  258. ):
  259. super().__init__()
  260. self.mossformerM = MossformerBlock(
  261. dim=d_model,
  262. depth=num_blocks,
  263. group_size=group_size,
  264. query_key_dim=query_key_dim,
  265. expansion_factor=expansion_factor,
  266. causal=causal,
  267. attn_dropout=attn_dropout
  268. )
  269. self.norm = nn.LayerNorm(d_model, eps=1e-6)
  270. def forward(
  271. self,
  272. src,
  273. ):
  274. """
  275. Arguments
  276. ----------
  277. src : torch.Tensor
  278. Tensor shape [B, L, N],
  279. where, B = Batchsize,
  280. L = time points
  281. N = number of filters
  282. The sequence to the encoder layer (required).
  283. src_mask : tensor
  284. The mask for the src sequence (optional).
  285. src_key_padding_mask : tensor
  286. The mask for the src keys per batch (optional).
  287. """
  288. output = self.mossformerM(src)
  289. output = self.norm(output)
  290. return output
  291. class Computation_Block(nn.Module):
  292. """Computation block for dual-path processing.
  293. Arguments
  294. ---------
  295. out_channels : int
  296. Dimensionality of inter/intra model.
  297. norm : str
  298. Normalization type.
  299. skip_around_intra : bool
  300. Skip connection around the intra layer.
  301. Example
  302. ---------
  303. >>> comp_block = Computation_Block(64)
  304. >>> x = torch.randn(10, 64, 100)
  305. >>> x = comp_block(x)
  306. >>> x.shape
  307. torch.Size([10, 64, 100])
  308. """
  309. def __init__(
  310. self,
  311. num_blocks,
  312. out_channels,
  313. norm="ln",
  314. skip_around_intra=True,
  315. ):
  316. super(Computation_Block, self).__init__()
  317. ##MossFormer2M: MossFormer with recurrence
  318. #self.intra_mdl = MossFormer2M(num_blocks=num_blocks, d_model=out_channels)
  319. ##MossFormerM: the orignal MossFormer
  320. self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
  321. self.skip_around_intra = skip_around_intra
  322. # Norm
  323. self.norm = norm
  324. if norm is not None:
  325. self.intra_norm = select_norm(norm, out_channels, 3)
  326. def forward(self, x):
  327. """Returns the output tensor.
  328. Arguments
  329. ---------
  330. x : torch.Tensor
  331. Input tensor of dimension [B, N, S].
  332. Return
  333. ---------
  334. out: torch.Tensor
  335. Output tensor of dimension [B, N, S].
  336. where, B = Batchsize,
  337. N = number of filters
  338. S = sequence time index
  339. """
  340. B, N, S = x.shape
  341. # intra RNN
  342. # [B, S, N]
  343. intra = x.permute(0, 2, 1).contiguous() #.view(B, S, N)
  344. intra = self.intra_mdl(intra)
  345. # [B, N, S]
  346. intra = intra.permute(0, 2, 1).contiguous()
  347. if self.norm is not None:
  348. intra = self.intra_norm(intra)
  349. # [B, N, S]
  350. if self.skip_around_intra:
  351. intra = intra + x
  352. out = intra
  353. return out