mossformer_encoder.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. try:
  5. from rotary_embedding_torch import RotaryEmbedding
  6. except:
  7. print("Please install rotary_embedding_torch by: \n pip install -U rotary_embedding_torch")
  8. from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
  9. from funasr.modules.embedding import ScaledSinuEmbedding
  10. from funasr.modules.mossformer import FLASH_ShareA_FFConvM
  11. def select_norm(norm, dim, shape):
  12. """Just a wrapper to select the normalization type.
  13. """
  14. if norm == "gln":
  15. return GlobalLayerNorm(dim, shape, elementwise_affine=True)
  16. if norm == "cln":
  17. return CumulativeLayerNorm(dim, elementwise_affine=True)
  18. if norm == "ln":
  19. return nn.GroupNorm(1, dim, eps=1e-8)
  20. else:
  21. return nn.BatchNorm1d(dim)
  22. class MossformerBlock(nn.Module):
  23. def __init__(
  24. self,
  25. *,
  26. dim,
  27. depth,
  28. group_size = 256,
  29. query_key_dim = 128,
  30. expansion_factor = 4.,
  31. causal = False,
  32. attn_dropout = 0.1,
  33. norm_type = 'scalenorm',
  34. shift_tokens = True
  35. ):
  36. super().__init__()
  37. assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'
  38. if norm_type == 'scalenorm':
  39. norm_klass = ScaleNorm
  40. elif norm_type == 'layernorm':
  41. norm_klass = nn.LayerNorm
  42. self.group_size = group_size
  43. rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))
  44. # max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
  45. self.layers = nn.ModuleList([FLASH_ShareA_FFConvM(dim = dim, group_size = group_size, query_key_dim = query_key_dim, expansion_factor = expansion_factor, causal = causal, dropout = attn_dropout, rotary_pos_emb = rotary_pos_emb, norm_klass = norm_klass, shift_tokens = shift_tokens) for _ in range(depth)])
  46. def forward(
  47. self,
  48. x,
  49. *,
  50. mask = None
  51. ):
  52. ii = 0
  53. for flash in self.layers:
  54. x = flash(x, mask = mask)
  55. ii = ii + 1
  56. return x
class MossFormer_MaskNet(nn.Module):
    """The MossFormer module for computing output masks.

    Arguments
    ---------
    in_channels : int
        Number of channels at the output of the encoder.
    out_channels : int
        Number of channels that would be inputted to the intra and inter blocks.
    num_blocks : int
        Number of layers of Dual Computation Block.
    norm : str
        Normalization type.
    num_spks : int
        Number of sources (speakers).
    skip_around_intra : bool
        Skip connection around intra.
    use_global_pos_enc : bool
        Global positional encodings.
    max_length : int
        Maximum sequence length.
        NOTE(review): accepted but not used anywhere in this class.

    Example
    ---------
    >>> mossformer_block = MossFormerM(1, 64, 8)
    >>> mossformer_masknet = MossFormer_MaskNet(64, 64, intra_block, num_spks=2)
    >>> x = torch.randn(10, 64, 2000)
    >>> x = mossformer_masknet(x)
    >>> x.shape
    torch.Size([2, 10, 64, 2000])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_blocks=24,
        norm="ln",
        num_spks=2,
        skip_around_intra=True,
        use_global_pos_enc=True,
        max_length=20000,
    ):
        super(MossFormer_MaskNet, self).__init__()
        self.num_spks = num_spks
        self.num_blocks = num_blocks
        # Input normalization over the encoder features.
        self.norm = select_norm(norm, in_channels, 3)
        # 1x1 conv projecting encoder channels to the model width.
        self.conv1d_encoder = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        self.use_global_pos_enc = use_global_pos_enc
        if self.use_global_pos_enc:
            self.pos_enc = ScaledSinuEmbedding(out_channels)
        # Main sequence model (defined later in this file).
        self.mdl = Computation_Block(
            num_blocks,
            out_channels,
            norm,
            skip_around_intra=skip_around_intra,
        )
        # Expands channels so the result can be split into num_spks masks.
        self.conv1d_out = nn.Conv1d(
            out_channels, out_channels * num_spks, kernel_size=1
        )
        # Projects back to the encoder channel count for each speaker.
        self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False)
        self.prelu = nn.PReLU()
        self.activation = nn.ReLU()
        # Gated output layer: tanh branch modulated by a sigmoid gate.
        self.output = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
        )
        self.output_gate = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
        )

    def forward(self, x):
        """Returns the output tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of dimension [B, N, S].

        Returns
        -------
        out : torch.Tensor
            Output tensor of dimension [spks, B, N, S]
            where, spks = Number of speakers
               B = Batchsize,
               N = number of filters
               S = the number of time frames
        """
        # Before each line we indicate the shape after executing the line.
        # [B, N, L]
        x = self.norm(x)
        # [B, N, L]
        x = self.conv1d_encoder(x)
        if self.use_global_pos_enc:
            # Earlier variant scaled the features before adding the embedding:
            #   x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1)
            #       + x * (x.size(1) ** 0.5)
            # The current code adds the embedding without scaling.
            base = x
            x = x.transpose(1, -1)
            emb = self.pos_enc(x)
            # NOTE(review): transpose(0, -1) swaps the first and last axes of
            # the embedding; presumably ScaledSinuEmbedding returns a shape
            # that becomes broadcastable with `base` after this — confirm
            # against its implementation.
            emb = emb.transpose(0, -1)
            x = base + emb
        # [B, N, S]
        x = self.mdl(x)
        x = self.prelu(x)
        # [B, N*spks, S]
        x = self.conv1d_out(x)
        B, _, S = x.shape
        # [B*spks, N, S] — fold the speaker dimension into the batch.
        x = x.view(B * self.num_spks, -1, S)
        # [B*spks, N, S] — gated output.
        x = self.output(x) * self.output_gate(x)
        # [B*spks, N, S]
        x = self.conv1_decoder(x)
        # [B, spks, N, S]
        _, N, L = x.shape
        x = x.view(B, self.num_spks, N, L)
        x = self.activation(x)
        # [spks, B, N, S]
        x = x.transpose(0, 1)
        return x
  174. class MossFormerEncoder(nn.Module):
  175. """Convolutional Encoder Layer.
  176. Arguments
  177. ---------
  178. kernel_size : int
  179. Length of filters.
  180. in_channels : int
  181. Number of input channels.
  182. out_channels : int
  183. Number of output channels.
  184. Example
  185. -------
  186. >>> x = torch.randn(2, 1000)
  187. >>> encoder = Encoder(kernel_size=4, out_channels=64)
  188. >>> h = encoder(x)
  189. >>> h.shape
  190. torch.Size([2, 64, 499])
  191. """
  192. def __init__(self, kernel_size=2, out_channels=64, in_channels=1):
  193. super(MossFormerEncoder, self).__init__()
  194. self.conv1d = nn.Conv1d(
  195. in_channels=in_channels,
  196. out_channels=out_channels,
  197. kernel_size=kernel_size,
  198. stride=kernel_size // 2,
  199. groups=1,
  200. bias=False,
  201. )
  202. self.in_channels = in_channels
  203. def forward(self, x):
  204. """Return the encoded output.
  205. Arguments
  206. ---------
  207. x : torch.Tensor
  208. Input tensor with dimensionality [B, L].
  209. Return
  210. ------
  211. x : torch.Tensor
  212. Encoded tensor with dimensionality [B, N, T_out].
  213. where B = Batchsize
  214. L = Number of timepoints
  215. N = Number of filters
  216. T_out = Number of timepoints at the output of the encoder
  217. """
  218. # B x L -> B x 1 x L
  219. if self.in_channels == 1:
  220. x = torch.unsqueeze(x, dim=1)
  221. # B x 1 x L -> B x N x T_out
  222. x = self.conv1d(x)
  223. x = F.relu(x)
  224. return x
  225. class MossFormerM(nn.Module):
  226. """This class implements the transformer encoder.
  227. Arguments
  228. ---------
  229. num_blocks : int
  230. Number of mossformer blocks to include.
  231. d_model : int
  232. The dimension of the input embedding.
  233. attn_dropout : float
  234. Dropout for the self-attention (Optional).
  235. group_size: int
  236. the chunk size
  237. query_key_dim: int
  238. the attention vector dimension
  239. expansion_factor: int
  240. the expansion factor for the linear projection in conv module
  241. causal: bool
  242. true for causal / false for non causal
  243. Example
  244. -------
  245. >>> import torch
  246. >>> x = torch.rand((8, 60, 512))
  247. >>> net = TransformerEncoder_MossFormerM(num_blocks=8, d_model=512)
  248. >>> output, _ = net(x)
  249. >>> output.shape
  250. torch.Size([8, 60, 512])
  251. """
  252. def __init__(
  253. self,
  254. num_blocks,
  255. d_model=None,
  256. causal=False,
  257. group_size = 256,
  258. query_key_dim = 128,
  259. expansion_factor = 4.,
  260. attn_dropout = 0.1
  261. ):
  262. super().__init__()
  263. self.mossformerM = MossformerBlock(
  264. dim=d_model,
  265. depth=num_blocks,
  266. group_size=group_size,
  267. query_key_dim=query_key_dim,
  268. expansion_factor=expansion_factor,
  269. causal=causal,
  270. attn_dropout=attn_dropout
  271. )
  272. self.norm = nn.LayerNorm(d_model, eps=1e-6)
  273. def forward(
  274. self,
  275. src,
  276. ):
  277. """
  278. Arguments
  279. ----------
  280. src : torch.Tensor
  281. Tensor shape [B, L, N],
  282. where, B = Batchsize,
  283. L = time points
  284. N = number of filters
  285. The sequence to the encoder layer (required).
  286. src_mask : tensor
  287. The mask for the src sequence (optional).
  288. src_key_padding_mask : tensor
  289. The mask for the src keys per batch (optional).
  290. """
  291. output = self.mossformerM(src)
  292. output = self.norm(output)
  293. return output
  294. class Computation_Block(nn.Module):
  295. """Computation block for dual-path processing.
  296. Arguments
  297. ---------
  298. out_channels : int
  299. Dimensionality of inter/intra model.
  300. norm : str
  301. Normalization type.
  302. skip_around_intra : bool
  303. Skip connection around the intra layer.
  304. Example
  305. ---------
  306. >>> comp_block = Computation_Block(64)
  307. >>> x = torch.randn(10, 64, 100)
  308. >>> x = comp_block(x)
  309. >>> x.shape
  310. torch.Size([10, 64, 100])
  311. """
  312. def __init__(
  313. self,
  314. num_blocks,
  315. out_channels,
  316. norm="ln",
  317. skip_around_intra=True,
  318. ):
  319. super(Computation_Block, self).__init__()
  320. ##MossFormer2M: MossFormer with recurrence
  321. #self.intra_mdl = MossFormer2M(num_blocks=num_blocks, d_model=out_channels)
  322. ##MossFormerM: the orignal MossFormer
  323. self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
  324. self.skip_around_intra = skip_around_intra
  325. # Norm
  326. self.norm = norm
  327. if norm is not None:
  328. self.intra_norm = select_norm(norm, out_channels, 3)
  329. def forward(self, x):
  330. """Returns the output tensor.
  331. Arguments
  332. ---------
  333. x : torch.Tensor
  334. Input tensor of dimension [B, N, S].
  335. Return
  336. ---------
  337. out: torch.Tensor
  338. Output tensor of dimension [B, N, S].
  339. where, B = Batchsize,
  340. N = number of filters
  341. S = sequence time index
  342. """
  343. B, N, S = x.shape
  344. # intra RNN
  345. # [B, S, N]
  346. intra = x.permute(0, 2, 1).contiguous() #.view(B, S, N)
  347. intra = self.intra_mdl(intra)
  348. # [B, N, S]
  349. intra = intra.permute(0, 2, 1).contiguous()
  350. if self.norm is not None:
  351. intra = self.intra_norm(intra)
  352. # [B, N, S]
  353. if self.skip_around_intra:
  354. intra = intra + x
  355. out = intra
  356. return out