| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336 |
- import math
- import torch
- from typing import Sequence
- from typing import Union
def mask_along_axis(
    spec: torch.Tensor,
    spec_lengths: torch.Tensor,
    mask_width_range: Sequence[int] = (0, 30),
    dim: int = 1,
    num_mask: int = 2,
    replace_with_zero: bool = True,
):
    """Randomly mask ``num_mask`` bands of the spectrogram along one axis.

    Args:
        spec: (Batch, Length, Freq) or (Batch, Channel, Length, Freq)
        spec_lengths: accepted for interface compatibility but not used here
        mask_width_range: (low, high); each mask width is drawn from [low, high)
        dim: 1 masks along Length (time), 2 masks along Freq
        num_mask: number of masks drawn independently per example
        replace_with_zero: fill masked bins with 0.0, otherwise with spec.mean()

    Returns:
        Tuple of (masked spec with the original shape, spec_lengths unchanged).
    """
    original_shape = spec.size()
    if spec.dim() == 4:
        # Fold channels into the batch: (B, C, L, F) -> (B * C, L, F)
        spec = spec.view(-1, spec.size(2), spec.size(3))

    batch = spec.shape[0]
    axis_len = spec.shape[dim]

    # One width per mask: (batch, num_mask, 1)
    widths = torch.randint(
        mask_width_range[0],
        mask_width_range[1],
        (batch, num_mask),
        device=spec.device,
    ).unsqueeze(2)
    # One start position per mask, kept inside the axis: (batch, num_mask, 1)
    starts = torch.randint(
        0, max(1, axis_len - widths.max()), (batch, num_mask), device=spec.device
    ).unsqueeze(2)

    # Index grid broadcast against (starts, widths): (1, 1, axis_len)
    grid = torch.arange(axis_len, device=spec.device)[None, None, :]
    # True inside any [start, start + width) interval: (batch, axis_len)
    band = ((starts <= grid) * (grid < starts + widths)).any(dim=1)

    # Orient the mask so it broadcasts over the axis that is NOT masked.
    if dim == 1:
        band = band.unsqueeze(2)  # (batch, Length, 1)
    elif dim == 2:
        band = band.unsqueeze(1)  # (batch, 1, Freq)

    fill = 0.0 if replace_with_zero else spec.mean()
    if spec.requires_grad:
        # Out-of-place fill keeps the autograd graph intact.
        spec = spec.masked_fill(band, fill)
    else:
        spec = spec.masked_fill_(band, fill)
    return spec.view(*original_shape), spec_lengths
def mask_along_axis_lfr(
    spec: torch.Tensor,
    spec_lengths: torch.Tensor,
    mask_width_range: Sequence[int] = (0, 30),
    dim: int = 1,
    num_mask: int = 2,
    replace_with_zero: bool = True,
    lfr_rate: int = 1,
):
    """Randomly mask bands along one axis, replicated per low-frame-rate chunk.

    The masked axis is treated as ``lfr_rate`` equal chunks; the same
    (start, width) pattern is drawn once per chunk-relative position and
    repeated in every chunk (identical masking for each stacked sub-band).
    With ``lfr_rate == 1`` this behaves like plain ``mask_along_axis``.

    Args:
        spec: (Batch, Length, Freq) or (Batch, Channel, Length, Freq)
        spec_lengths: accepted for interface compatibility but not used here
        mask_width_range: (low, high); each mask width is drawn from [low, high)
        dim: 1 masks along Length (time), 2 masks along Freq
        num_mask: number of masks drawn independently per example
        replace_with_zero: fill masked bins with 0.0, otherwise with spec.mean()
        lfr_rate: low frame rate — number of stacked chunks along the axis

    Returns:
        Tuple of (masked spec with the original shape, spec_lengths unchanged).
    """
    original_shape = spec.size()
    if spec.dim() == 4:
        # Fold channels into the batch: (B, C, L, F) -> (B * C, L, F)
        spec = spec.view(-1, spec.size(2), spec.size(3))

    batch = spec.shape[0]
    # Width of one low-frame-rate chunk along the masked axis.
    chunk = spec.shape[dim] // lfr_rate

    # One width per mask, relative to a single chunk: (batch, num_mask, 1)
    widths = torch.randint(
        mask_width_range[0],
        mask_width_range[1],
        (batch, num_mask),
        device=spec.device,
    ).unsqueeze(2)
    if lfr_rate > 1:
        # Tile the widths so every chunk reuses the same per-mask widths.
        widths = widths.repeat(1, lfr_rate, 1)

    # One start per mask inside the first chunk: (batch, num_mask, 1)
    starts = torch.randint(
        0, max(1, chunk - widths.max()), (batch, num_mask), device=spec.device
    ).unsqueeze(2)
    if lfr_rate > 1:
        # Replicate each start into every chunk by offsetting with chunk * i.
        starts = torch.cat([starts + chunk * i for i in range(lfr_rate)], dim=1)

    # Index grid over the FULL axis: (1, 1, full_len)
    full_len = spec.shape[dim]
    grid = torch.arange(full_len, device=spec.device)[None, None, :]
    # True inside any [start, start + width) interval: (batch, full_len)
    band = ((starts <= grid) * (grid < starts + widths)).any(dim=1)

    # Orient the mask so it broadcasts over the axis that is NOT masked.
    if dim == 1:
        band = band.unsqueeze(2)  # (batch, Length, 1)
    elif dim == 2:
        band = band.unsqueeze(1)  # (batch, 1, Freq)

    fill = 0.0 if replace_with_zero else spec.mean()
    if spec.requires_grad:
        # Out-of-place fill keeps the autograd graph intact.
        spec = spec.masked_fill(band, fill)
    else:
        spec = spec.masked_fill_(band, fill)
    return spec.view(*original_shape), spec_lengths
class MaskAlongAxis(torch.nn.Module):
    """Mask random bands of the input spectrogram along time or frequency.

    SpecAugment-style masking: per example, ``num_mask`` bands with widths
    drawn from ``mask_width_range`` are filled with zero (or the spec mean).

    Args:
        mask_width_range: int high or (low, high); widths drawn from [low, high)
        num_mask: number of masks per example
        dim: axis to mask — 1 / "time" or 2 / "freq"
        replace_with_zero: fill with 0.0 if True, else with the spec mean
    """

    def __init__(
        self,
        mask_width_range: Union[int, Sequence[int]] = (0, 30),
        num_mask: int = 2,
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
    ):
        # Initialize nn.Module first: assigning attributes on self before
        # super().__init__() relies on nn.Module.__setattr__ internals.
        super().__init__()
        if isinstance(mask_width_range, int):
            mask_width_range = (0, mask_width_range)
        if len(mask_width_range) != 2:
            raise TypeError(
                f"mask_width_range must be a tuple of int and int values: "
                f"{mask_width_range}",
            )
        if mask_width_range[1] <= mask_width_range[0]:
            # Raise a real exception: assert is stripped under `python -O`.
            raise ValueError(
                f"mask_width_range must be (low, high) with high > low: "
                f"{mask_width_range}"
            )
        if isinstance(dim, str):
            if dim == "time":
                dim = 1
            elif dim == "freq":
                dim = 2
            else:
                raise ValueError("dim must be int, 'time' or 'freq'")
        if dim == 1:
            self.mask_axis = "time"
        elif dim == 2:
            self.mask_axis = "freq"
        else:
            self.mask_axis = "unknown"
        self.mask_width_range = mask_width_range
        self.num_mask = num_mask
        self.dim = dim
        self.replace_with_zero = replace_with_zero

    def extra_repr(self):
        return (
            f"mask_width_range={self.mask_width_range}, "
            f"num_mask={self.num_mask}, axis={self.mask_axis}"
        )

    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
        """Apply the configured masking.

        Args:
            spec: (Batch, Length, Freq) or (Batch, Channel, Length, Freq)
            spec_lengths: passed through unchanged

        Returns:
            Tuple (masked spec, spec_lengths).
        """
        return mask_along_axis(
            spec,
            spec_lengths,
            mask_width_range=self.mask_width_range,
            dim=self.dim,
            num_mask=self.num_mask,
            replace_with_zero=self.replace_with_zero,
        )
class MaskAlongAxisVariableMaxWidth(torch.nn.Module):
    """Mask input spec along a specified axis with variable maximum width.

    The absolute width bounds are recomputed per input from the axis length:

        max_width = max_width_ratio * seq_len

    Args:
        mask_width_ratio_range: number high or (low, high) ratios of the axis
            length; converted to absolute widths in ``forward``
        num_mask: number of masks per example
        dim: axis to mask — 1 / "time" or 2 / "freq"
        replace_with_zero: fill with 0.0 if True, else with the spec mean
    """

    def __init__(
        self,
        mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
        num_mask: int = 2,
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
    ):
        # Initialize nn.Module first: assigning attributes on self before
        # super().__init__() relies on nn.Module.__setattr__ internals.
        super().__init__()
        # Accept any bare number (int included) as the upper ratio.
        if isinstance(mask_width_ratio_range, (int, float)):
            mask_width_ratio_range = (0.0, float(mask_width_ratio_range))
        if len(mask_width_ratio_range) != 2:
            raise TypeError(
                f"mask_width_ratio_range must be a tuple of float and float values: "
                f"{mask_width_ratio_range}",
            )
        if mask_width_ratio_range[1] <= mask_width_ratio_range[0]:
            # Raise a real exception: assert is stripped under `python -O`.
            raise ValueError(
                f"mask_width_ratio_range must be (low, high) with high > low: "
                f"{mask_width_ratio_range}"
            )
        if isinstance(dim, str):
            if dim == "time":
                dim = 1
            elif dim == "freq":
                dim = 2
            else:
                raise ValueError("dim must be int, 'time' or 'freq'")
        if dim == 1:
            self.mask_axis = "time"
        elif dim == 2:
            self.mask_axis = "freq"
        else:
            self.mask_axis = "unknown"
        self.mask_width_ratio_range = mask_width_ratio_range
        self.num_mask = num_mask
        self.dim = dim
        self.replace_with_zero = replace_with_zero

    def extra_repr(self):
        return (
            f"mask_width_ratio_range={self.mask_width_ratio_range}, "
            f"num_mask={self.num_mask}, axis={self.mask_axis}"
        )

    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
        """Apply masking with width bounds scaled by the input length.

        Args:
            spec: (Batch, Length, Freq)
            spec_lengths: passed through unchanged

        Returns:
            Tuple (masked spec, spec_lengths); the input is returned untouched
            when the computed width interval is empty.
        """
        max_seq_len = spec.shape[self.dim]
        min_mask_width = math.floor(max_seq_len * self.mask_width_ratio_range[0])
        min_mask_width = max(0, min_mask_width)
        max_mask_width = math.floor(max_seq_len * self.mask_width_ratio_range[1])
        max_mask_width = min(max_seq_len, max_mask_width)
        if max_mask_width > min_mask_width:
            return mask_along_axis(
                spec,
                spec_lengths,
                mask_width_range=(min_mask_width, max_mask_width),
                dim=self.dim,
                num_mask=self.num_mask,
                replace_with_zero=self.replace_with_zero,
            )
        # Empty width interval: nothing to mask.
        return spec, spec_lengths
class MaskAlongAxisLFR(torch.nn.Module):
    """SpecAugment-style masking for low-frame-rate (LFR) features.

    For frequency masking of LFR features, the masked axis is treated as
    ``lfr_rate`` stacked chunks and the same mask pattern is replicated in
    every chunk (see ``mask_along_axis_lfr``). For time masking ``lfr_rate``
    is forced to 1, which reduces to plain axis masking.

    Args:
        mask_width_range: int high or (low, high); widths drawn from [low, high)
        num_mask: number of masks per example
        dim: axis to mask — 1 / "time" or 2 / "freq"
        replace_with_zero: fill with 0.0 if True, else with the spec mean
        lfr_rate: number of stacked chunks along the masked axis (freq only)
    """

    def __init__(
        self,
        mask_width_range: Union[int, Sequence[int]] = (0, 30),
        num_mask: int = 2,
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
        lfr_rate: int = 1,
    ):
        # Initialize nn.Module first: assigning attributes on self before
        # super().__init__() relies on nn.Module.__setattr__ internals.
        super().__init__()
        if isinstance(mask_width_range, int):
            mask_width_range = (0, mask_width_range)
        if len(mask_width_range) != 2:
            raise TypeError(
                f"mask_width_range must be a tuple of int and int values: "
                f"{mask_width_range}",
            )
        if mask_width_range[1] <= mask_width_range[0]:
            # Raise a real exception: assert is stripped under `python -O`.
            raise ValueError(
                f"mask_width_range must be (low, high) with high > low: "
                f"{mask_width_range}"
            )
        if isinstance(dim, str):
            if dim == "time":
                dim = 1
            elif dim == "freq":
                dim = 2
            else:
                raise ValueError("dim must be int, 'time' or 'freq'")
        if dim == 1:
            self.mask_axis = "time"
            # LFR replication only applies along frequency; the original code
            # reset lfr_rate in both the str branch and here — once suffices.
            lfr_rate = 1
        elif dim == 2:
            self.mask_axis = "freq"
        else:
            self.mask_axis = "unknown"
        self.mask_width_range = mask_width_range
        self.num_mask = num_mask
        self.dim = dim
        self.replace_with_zero = replace_with_zero
        self.lfr_rate = lfr_rate

    def extra_repr(self):
        return (
            f"mask_width_range={self.mask_width_range}, "
            f"num_mask={self.num_mask}, axis={self.mask_axis}"
        )

    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
        """Apply the configured LFR-aware masking.

        Args:
            spec: (Batch, Length, Freq) or (Batch, Channel, Length, Freq)
            spec_lengths: passed through unchanged

        Returns:
            Tuple (masked spec, spec_lengths).
        """
        return mask_along_axis_lfr(
            spec,
            spec_lengths,
            mask_width_range=self.mask_width_range,
            dim=self.dim,
            num_mask=self.num_mask,
            replace_with_zero=self.replace_with_zero,
            lfr_rate=self.lfr_rate,
        )
|