# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr.models.data2vec import utils
from funasr.models.data2vec.multihead_attention import MultiheadAttention
  14. class ConvFeatureExtractionModel(nn.Module):
  15. def __init__(
  16. self,
  17. conv_layers: List[Tuple[int, int, int]],
  18. dropout: float = 0.0,
  19. mode: str = "default",
  20. conv_bias: bool = False,
  21. in_d: int = 1
  22. ):
  23. super().__init__()
  24. assert mode in {"default", "layer_norm"}
  25. def block(
  26. n_in,
  27. n_out,
  28. k,
  29. stride,
  30. is_layer_norm=False,
  31. is_group_norm=False,
  32. conv_bias=False,
  33. ):
  34. def make_conv():
  35. conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
  36. nn.init.kaiming_normal_(conv.weight)
  37. return conv
  38. assert (
  39. is_layer_norm and is_group_norm
  40. ) == False, "layer norm and group norm are exclusive"
  41. if is_layer_norm:
  42. return nn.Sequential(
  43. make_conv(),
  44. nn.Dropout(p=dropout),
  45. nn.Sequential(
  46. utils.TransposeLast(),
  47. utils.Fp32LayerNorm(dim, elementwise_affine=True),
  48. utils.TransposeLast(),
  49. ),
  50. nn.GELU(),
  51. )
  52. elif is_group_norm:
  53. return nn.Sequential(
  54. make_conv(),
  55. nn.Dropout(p=dropout),
  56. utils.Fp32GroupNorm(dim, dim, affine=True),
  57. nn.GELU(),
  58. )
  59. else:
  60. return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
  61. self.conv_layers = nn.ModuleList()
  62. for i, cl in enumerate(conv_layers):
  63. assert len(cl) == 3, "invalid conv definition: " + str(cl)
  64. (dim, k, stride) = cl
  65. self.conv_layers.append(
  66. block(
  67. in_d,
  68. dim,
  69. k,
  70. stride,
  71. is_layer_norm=mode == "layer_norm",
  72. is_group_norm=mode == "default" and i == 0,
  73. conv_bias=conv_bias,
  74. )
  75. )
  76. in_d = dim
  77. def forward(self, x):
  78. if len(x.shape) == 2:
  79. x = x.unsqueeze(1)
  80. else:
  81. x = x.transpose(1, 2)
  82. for conv in self.conv_layers:
  83. x = conv(x)
  84. return x
  85. def make_conv_pos(e, k, g):
  86. pos_conv = nn.Conv1d(
  87. e,
  88. e,
  89. kernel_size=k,
  90. padding=k // 2,
  91. groups=g,
  92. )
  93. dropout = 0
  94. std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
  95. nn.init.normal_(pos_conv.weight, mean=0, std=std)
  96. nn.init.constant_(pos_conv.bias, 0)
  97. pos_conv = nn.utils.weight_norm(pos_conv, name="weight", dim=2)
  98. pos_conv = nn.Sequential(pos_conv, utils.SamePad(k), nn.GELU())
  99. return pos_conv
class TransformerEncoder(nn.Module):
    """data2vec/wav2vec 2.0 Transformer encoder.

    Adds a convolutional positional embedding to the input, then runs a
    stack of self-attention layers (with optional LayerDrop), padding the
    sequence length to ``required_seq_len_multiple`` around the stack.
    """

    def build_encoder_layer(self):
        # Build one self-attention layer from hyper-parameters stored on
        # this instance by __init__.
        if self.layer_type == "transformer":
            layer = TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=self.encoder_ffn_embed_dim,
                num_attention_heads=self.encoder_attention_heads,
                dropout=self.dropout,
                attention_dropout=self.attention_dropout,
                activation_dropout=self.activation_dropout,
                activation_fn=self.activation_fn,
                layer_norm_first=self.layer_norm_first,
            )
        else:
            # NOTE(review): `layer` is never assigned on this path, so any
            # other layer_type raises UnboundLocalError right after logging.
            logging.error("Only transformer is supported for data2vec now")
        return layer

    def __init__(
        self,
        # position
        dropout,
        encoder_embed_dim,
        required_seq_len_multiple,
        pos_conv_depth,
        conv_pos,
        conv_pos_groups,
        # transformer layers
        layer_type,
        encoder_layers,
        encoder_ffn_embed_dim,
        encoder_attention_heads,
        attention_dropout,
        activation_dropout,
        activation_fn,
        layer_norm_first,
        encoder_layerdrop,
        max_positions,
    ):
        super().__init__()
        # position
        self.dropout = dropout
        self.embedding_dim = encoder_embed_dim
        self.required_seq_len_multiple = required_seq_len_multiple
        if pos_conv_depth > 1:
            # Deep positional encoder: several smaller conv blocks instead
            # of one wide convolution; kernel size is split across depth.
            num_layers = pos_conv_depth
            k = max(3, conv_pos // num_layers)

            def make_conv_block(e, k, g, l):
                # `l` grouped-conv blocks, each followed by a non-affine
                # LayerNorm (applied over channels via the two transposes)
                # and a GELU.
                return nn.Sequential(
                    *[
                        nn.Sequential(
                            nn.Conv1d(
                                e,
                                e,
                                kernel_size=k,
                                padding=k // 2,
                                groups=g,
                            ),
                            utils.SamePad(k),
                            utils.TransposeLast(),
                            torch.nn.LayerNorm(e, elementwise_affine=False),
                            utils.TransposeLast(),
                            nn.GELU(),
                        )
                        for _ in range(l)
                    ]
                )

            self.pos_conv = make_conv_block(
                self.embedding_dim, k, conv_pos_groups, num_layers
            )
        else:
            self.pos_conv = make_conv_pos(
                self.embedding_dim,
                conv_pos,
                conv_pos_groups,
            )
        # transformer layers
        self.layer_type = layer_type
        self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
        self.encoder_attention_heads = encoder_attention_heads
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_fn = activation_fn
        self.layer_norm_first = layer_norm_first
        self.layerdrop = encoder_layerdrop
        # NOTE(review): this instance attribute shadows the max_positions()
        # method defined below — `enc.max_positions` is the int and the
        # method is uncallable on instances.
        self.max_positions = max_positions
        self.layers = nn.ModuleList(
            [self.build_encoder_layer() for _ in range(encoder_layers)]
        )
        self.layer_norm = torch.nn.LayerNorm(self.embedding_dim)
        self.apply(utils.init_bert_params)

    def forward(self, x, padding_mask=None, layer=None):
        # x: (B, T, C); returns (features, per-layer results).
        x, layer_results = self.extract_features(x, padding_mask, layer)
        # Pre-norm mode applies the final LayerNorm here, but only when the
        # full stack was run (no early exit at `layer`).
        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)
        return x, layer_results

    def extract_features(
        self,
        x,
        padding_mask=None,
        tgt_layer=None,
        min_layer=0,
    ):
        # Zero padded positions in place before the positional convolution.
        if padding_mask is not None:
            x[padding_mask] = 0
        # Positional conv operates on (B, C, T).
        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv
        # Post-norm mode normalizes before the transformer stack instead of
        # in forward().
        if not self.layer_norm_first:
            x = self.layer_norm(x)
        # pad to the sequence length dimension
        x, pad_length = utils.pad_to_multiple(
            x, self.required_seq_len_multiple, dim=-2, value=0
        )
        if pad_length > 0 and padding_mask is None:
            # Fresh mask: only the appended frames are padding.
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True
        else:
            # Extend an existing mask to the padded length (True = pad).
            padding_mask, _ = utils.pad_to_multiple(
                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
            )
        x = F.dropout(x, p=self.dropout, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        layer_results = []
        r = None
        for i, layer in enumerate(self.layers):
            # LayerDrop: randomly skip whole layers during training.
            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
            if not self.training or (dropout_probability > self.layerdrop):
                x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask)
                if i >= min_layer:
                    layer_results.append((x, z, lr))
            if i == tgt_layer:
                # Early exit: return this layer's output.
                r = x
                break
        if r is not None:
            x = r
        # T x B x C -> B x T x C
        x = x.transpose(0, 1)
        # undo paddding
        if pad_length > 0:
            x = x[:, :-pad_length]

            def undo_pad(a, b, c):
                # Trim padded frames along the leading (time) dim of each
                # stored per-layer tensor.
                return (
                    a[:-pad_length],
                    b[:-pad_length] if b is not None else b,
                    c[:-pad_length],
                )

            layer_results = [undo_pad(*u) for u in layer_results]
        return x, layer_results

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        # NOTE(review): unreachable as a method — shadowed by the
        # self.max_positions attribute set in __init__ (see note there).
        return self.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        # No migration needed for this encoder; returned unchanged.
        return state_dict
class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
    ) -> None:
        super().__init__()
        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout
        # Initialize blocks
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
        )
        # dropout1: after attention; dropout2: after the FFN activation;
        # dropout3: after the second FFN projection.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.layer_norm_first = layer_norm_first
        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = torch.nn.LayerNorm(self.embedding_dim)
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = torch.nn.LayerNorm(self.embedding_dim)

    def forward(
        self,
        x: torch.Tensor,  # (T, B, C)
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules similar to the original Transformer implementation.

        Returns ``(x, (attn, layer_result))`` where ``layer_result`` is the
        FFN output captured before the final dropout and residual add.
        """
        residual = x
        if self.layer_norm_first:
            # Pre-norm: LN -> attention -> residual; LN -> FFN -> residual.
            x = self.self_attn_layer_norm(x)
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                attn_mask=self_attn_mask,
                need_weights=False,
            )
            x = self.dropout1(x)
            x = residual + x
            residual = x
            x = self.final_layer_norm(x)
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            layer_result = x
            x = self.dropout3(x)
            x = residual + x
        else:
            # Post-norm: attention -> residual -> LN; FFN -> residual -> LN.
            # NOTE(review): self_attn_mask is NOT forwarded on this branch —
            # confirm whether that asymmetry with the pre-norm path is
            # intentional.
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=False,
            )
            x = self.dropout1(x)
            x = residual + x
            x = self.self_attn_layer_norm(x)
            residual = x
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            layer_result = x
            x = self.dropout3(x)
            x = residual + x
            x = self.final_layer_norm(x)
        return x, (attn, layer_result)