fairseq_modules.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from typing import Optional, Tuple, List
  5. import numpy as np
  def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
      """Build a plain ``torch.nn.LayerNorm``.

      Args:
          normalized_shape: trailing shape to normalize over.
          eps: value added to the variance for numerical stability.
          elementwise_affine: whether the module learns a per-element weight/bias.
          export: accepted so callers can pass it; it has no effect here.

      Returns:
          A newly constructed ``torch.nn.LayerNorm`` module.
      """
      layer_norm_cls = torch.nn.LayerNorm
      return layer_norm_cls(
          normalized_shape,
          eps=eps,
          elementwise_affine=elementwise_affine,
      )
  class SamePad(nn.Module):
      """Trim trailing positions along dimension 2 after a padded convolution.

      With ``causal=True`` the last ``kernel_size - 1`` positions are removed;
      otherwise a single position is removed only when ``kernel_size`` is even.
      """

      def __init__(self, kernel_size, causal=False):
          super().__init__()
          # How many trailing entries of dim 2 forward() discards.
          self.remove = kernel_size - 1 if causal else int(kernel_size % 2 == 0)

      def forward(self, x):
          """Return ``x`` without its last ``self.remove`` entries on dim 2."""
          if self.remove <= 0:
              return x
          return x[:, :, : -self.remove]
  class TransposeLast(nn.Module):
      """Swap the two trailing dimensions of the input.

      When ``deconstruct_idx`` is given, the input is first indexed with it and
      the selected element is transposed instead.
      """

      def __init__(self, deconstruct_idx=None):
          super().__init__()
          self.deconstruct_idx = deconstruct_idx

      def forward(self, x):
          """Optionally pick ``x[deconstruct_idx]``, then transpose dims -2 and -1."""
          idx = self.deconstruct_idx
          source = x if idx is None else x[idx]
          return source.transpose(-2, -1)
  class Fp32LayerNorm(nn.LayerNorm):
      """LayerNorm whose arithmetic always runs in float32.

      Constructed exactly like ``nn.LayerNorm``. The input, weight and bias are
      upcast to float32 for ``F.layer_norm`` and the result is cast back to the
      input's dtype.
      """

      def forward(self, input):
          weight = None if self.weight is None else self.weight.float()
          bias = None if self.bias is None else self.bias.float()
          normalized = F.layer_norm(
              input.float(), self.normalized_shape, weight, bias, self.eps
          )
          return normalized.type_as(input)
  class Fp32GroupNorm(nn.GroupNorm):
      """GroupNorm whose arithmetic always runs in float32.

      Constructed exactly like ``nn.GroupNorm``. The input and any affine
      parameters are upcast to float32 for ``F.group_norm`` and the result is
      cast back to the input's dtype.
      """

      def forward(self, input):
          in_fp32 = input.float()
          weight = self.weight.float() if self.weight is not None else None
          bias = self.bias.float() if self.bias is not None else None
          normalized = F.group_norm(in_fp32, self.num_groups, weight, bias, self.eps)
          return normalized.type_as(input)
  class ConvFeatureExtractionModel(nn.Module):
      """Stack of 1-D convolution blocks applied to a single-channel input.

      Each block is conv -> dropout -> optional norm -> GELU.

      Args:
          conv_layers: one ``(dim, kernel_size, stride)`` triple per block;
              ``dim`` is the block's output channel count.
          dropout: dropout probability applied after every convolution.
          mode: ``"default"`` adds a group norm (one group per channel) after
              the first convolution only; ``"layer_norm"`` adds a layer norm
              over the channel dimension after every convolution.
          conv_bias: whether the convolutions have a bias term.
      """

      def __init__(
          self,
          conv_layers: List[Tuple[int, int, int]],
          dropout: float = 0.0,
          mode: str = "default",
          conv_bias: bool = False,
      ):
          super().__init__()
          assert mode in {"default", "layer_norm"}

          def block(
              n_in,
              n_out,
              k,
              stride,
              is_layer_norm=False,
              is_group_norm=False,
              conv_bias=False,
          ):
              """Build one conv -> dropout -> [norm] -> GELU block."""

              def make_conv():
                  conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                  nn.init.kaiming_normal_(conv.weight)
                  return conv

              assert not (
                  is_layer_norm and is_group_norm
              ), "layer norm and group norm are exclusive"

              # Norm layers are sized from n_out (this block's channel count)
              # rather than the enclosing loop variable, so block() is correct
              # regardless of where it is called from.
              if is_layer_norm:
                  # Move channels last, normalize over them, move them back.
                  return nn.Sequential(
                      make_conv(),
                      nn.Dropout(p=dropout),
                      nn.Sequential(
                          TransposeLast(),
                          Fp32LayerNorm(n_out, elementwise_affine=True),
                          TransposeLast(),
                      ),
                      nn.GELU(),
                  )
              elif is_group_norm:
                  # num_groups == num_channels: each channel is its own group.
                  return nn.Sequential(
                      make_conv(),
                      nn.Dropout(p=dropout),
                      Fp32GroupNorm(n_out, n_out, affine=True),
                      nn.GELU(),
                  )
              else:
                  return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

          in_d = 1  # forward() feeds a single input channel
          self.conv_layers = nn.ModuleList()
          for i, cl in enumerate(conv_layers):
              assert len(cl) == 3, "invalid conv definition: " + str(cl)
              (dim, k, stride) = cl
              self.conv_layers.append(
                  block(
                      in_d,
                      dim,
                      k,
                      stride,
                      is_layer_norm=(mode == "layer_norm"),
                      is_group_norm=(mode == "default" and i == 0),
                      conv_bias=conv_bias,
                  )
              )
              in_d = dim

      def forward(self, x):
          """Run every conv block in order.

          Args:
              x: input batch, assumed (batch, time) per the BxT note below.

          Returns:
              The last block's output, (batch, channels, frames).
          """
          # BxT -> BxCxT
          x = x.unsqueeze(1)
          for conv in self.conv_layers:
              x = conv(x)
          return x
  def compute_mask_indices(
      shape: Tuple[int, int],
      padding_mask: Optional[torch.Tensor],
      mask_prob: float,
      mask_length: int,
      mask_type: str = "static",
      mask_other: float = 0.0,
      min_masks: int = 0,
      no_overlap: bool = False,
      min_space: int = 0,
      require_same_masks: bool = True,
      mask_dropout: float = 0.0,
  ) -> np.ndarray:
      """
      Computes random mask spans for a given shape.

      Args:
          shape: the shape for which to compute masks.
              should be of size 2 where first element is batch size and 2nd is timesteps
          padding_mask: optional padding mask of the same size as shape. Row i is
              only masked within its first (timesteps - number of padded entries)
              positions, so padding is assumed to sit at the end of each row.
          mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
              number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
              however due to overlaps, the actual number will be smaller (unless no_overlap is True)
          mask_length: span length (static), or the parameter of the length distribution for the other mask_types
          mask_type: how to compute mask lengths
              static = fixed size
              uniform = sample from uniform distribution [mask_other, mask_length*2]
              normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
              poisson = sample from poisson distribution with lambda = mask length
          min_masks: minimum number of masked spans
          no_overlap: if True, use an alternative algorithm that places spans into
              disjoint free regions so spans do not overlap
          min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
          require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
          mask_dropout: randomly drop out this fraction of masks in each example

      Returns:
          Boolean array of shape (batch size, timesteps); True marks a masked position.

      Raises:
          Exception: if mask_type is not one of the values above.
      """
      bsz, all_sz = shape
      mask = np.full((bsz, all_sz), False)

      all_num_mask = int(
          # add a random number for probabilistic rounding
          mask_prob * all_sz / float(mask_length)
          + np.random.rand()
      )
      all_num_mask = max(min_masks, all_num_mask)

      mask_idcs = []
      for i in range(bsz):
          if padding_mask is not None:
              sz = all_sz - padding_mask[i].long().sum().item()
              num_mask = int(
                  # add a random number for probabilistic rounding
                  mask_prob * sz / float(mask_length)
                  + np.random.rand()
              )
              num_mask = max(min_masks, num_mask)
          else:
              sz = all_sz
              num_mask = all_num_mask

          if mask_type == "static":
              lengths = np.full(num_mask, mask_length)
          elif mask_type == "uniform":
              lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
          elif mask_type == "normal":
              lengths = np.random.normal(mask_length, mask_other, size=num_mask)
              lengths = [max(1, int(round(x))) for x in lengths]
          elif mask_type == "poisson":
              lengths = np.random.poisson(mask_length, size=num_mask)
              lengths = [int(round(x)) for x in lengths]
          else:
              raise Exception("unknown mask selection " + mask_type)

          if len(lengths) == 0:
              # num_mask == 0 (e.g. mask_prob == 0 or a very short row): no
              # spans for this row. Without this guard lengths[0] below would
              # raise IndexError on the empty sequence.
              mask_idcs.append(np.zeros(0, dtype=int))
              continue

          if sum(lengths) == 0:
              # every sampled span is empty: give the first one a real length
              lengths[0] = min(mask_length, sz - 1)

          if no_overlap:
              mask_idc = []

              def arrange(s, e, length, keep_length):
                  # Place one span in [s, e) and return the leftover free parts
                  # that are still large enough for a span of keep_length.
                  span_start = np.random.randint(s, e - length)
                  mask_idc.extend(span_start + k for k in range(length))

                  new_parts = []
                  if span_start - s - min_space >= keep_length:
                      new_parts.append((s, span_start - min_space + 1))
                  if e - span_start - length - min_space > keep_length:
                      new_parts.append((span_start + length + min_space, e))
                  return new_parts

              parts = [(0, sz)]
              min_length = min(lengths)
              for length in sorted(lengths, reverse=True):
                  # Weight each free part by its size. A part is eligible only if
                  # it leaves room for min_space AND randint(s, e - length) has a
                  # non-empty range (e - s > length); an exact fit would
                  # otherwise raise ValueError in arrange().
                  lens = np.fromiter(
                      (
                          e - s
                          if e - s >= length + min_space and e - s > length
                          else 0
                          for s, e in parts
                      ),
                      int,
                  )
                  l_sum = np.sum(lens)
                  if l_sum == 0:
                      break
                  probs = lens / l_sum
                  c = np.random.choice(len(parts), p=probs)
                  s, e = parts.pop(c)
                  parts.extend(arrange(s, e, length, min_length))
              # dtype=int so an empty result is still usable as an index array
              mask_idc = np.asarray(mask_idc, dtype=int)
          else:
              min_len = min(lengths)
              if sz - min_len <= num_mask:
                  min_len = sz - num_mask - 1
              mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
              mask_idc = np.asarray(
                  [
                      mask_idc[j] + offset
                      for j in range(len(mask_idc))
                      for offset in range(lengths[j])
                  ],
                  dtype=int,
              )

          mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

      # default=0 keeps an empty batch (bsz == 0) from crashing in min().
      min_len = min((len(m) for m in mask_idcs), default=0)
      for i, mask_idc in enumerate(mask_idcs):
          if len(mask_idc) > min_len and require_same_masks:
              mask_idc = np.random.choice(mask_idc, min_len, replace=False)
          if mask_dropout > 0:
              num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
              mask_idc = np.random.choice(
                  mask_idc, len(mask_idc) - num_holes, replace=False
              )
          mask[i, mask_idc] = True

      return mask
  class GradMultiply(torch.autograd.Function):
      """Pass values through unchanged, but scale the gradient on the way back.

      Forward returns a new tensor object built from ``x`` via ``x.new(x)``;
      backward multiplies the incoming gradient by ``scale``. ``scale`` itself
      receives no gradient.
      """

      @staticmethod
      def forward(ctx, x, scale):
          ctx.scale = scale  # stashed for backward()
          return x.new(x)

      @staticmethod
      def backward(ctx, grad):
          scaled_grad = grad * ctx.scale
          return scaled_grad, None
  def is_xla_tensor(tensor):
      """Return True only for a torch tensor whose device type is ``"xla"``."""
      if not torch.is_tensor(tensor):
          return False
      return tensor.device.type == "xla"
  def index_put(tensor, indices, value):
      """Write ``value`` into ``tensor`` at ``indices`` and return the result.

      Off XLA this is the in-place assignment ``tensor[indices] = value``. On XLA
      the result is computed out of place as ``tensor * ~indices + value * indices``
      after giving ``indices`` trailing singleton dims up to ``tensor``'s rank and,
      if its last dim is smaller, expanding it to ``tensor``'s shape.
      NOTE(review): the ``~indices`` blend presumes ``indices`` is a boolean mask —
      confirm with callers.
      """
      if not is_xla_tensor(tensor):
          tensor[indices] = value
          return tensor

      # Grow indices one trailing dim at a time until it matches tensor's rank.
      while indices.dim() < tensor.dim():
          indices = indices.unsqueeze(-1)
      if indices.size(-1) < tensor.size(-1):
          indices = indices.expand_as(tensor)
      kept = torch.mul(tensor, ~indices)
      replaced = torch.mul(value, indices)
      return kept + replaced