# label_aggregation.py — frame-level label aggregation modules.
  1. import torch
  2. from typing import Optional
  3. from typing import Tuple
  4. from torch.nn import functional as F
  5. from funasr.models.transformer.utils.nets_utils import make_pad_mask
  6. class LabelAggregate(torch.nn.Module):
  7. def __init__(
  8. self,
  9. win_length: int = 512,
  10. hop_length: int = 128,
  11. center: bool = True,
  12. ):
  13. super().__init__()
  14. self.win_length = win_length
  15. self.hop_length = hop_length
  16. self.center = center
  17. def extra_repr(self):
  18. return (
  19. f"win_length={self.win_length}, "
  20. f"hop_length={self.hop_length}, "
  21. f"center={self.center}, "
  22. )
  23. def forward(
  24. self, input: torch.Tensor, ilens: torch.Tensor = None
  25. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  26. """LabelAggregate forward function.
  27. Args:
  28. input: (Batch, Nsamples, Label_dim)
  29. ilens: (Batch)
  30. Returns:
  31. output: (Batch, Frames, Label_dim)
  32. """
  33. bs = input.size(0)
  34. max_length = input.size(1)
  35. label_dim = input.size(2)
  36. # NOTE(jiatong):
  37. # The default behaviour of label aggregation is compatible with
  38. # torch.stft about framing and padding.
  39. # Step1: center padding
  40. if self.center:
  41. pad = self.win_length // 2
  42. max_length = max_length + 2 * pad
  43. input = torch.nn.functional.pad(input, (0, 0, pad, pad), "constant", 0)
  44. input[:, :pad, :] = input[:, pad : (2 * pad), :]
  45. input[:, (max_length - pad) : max_length, :] = input[
  46. :, (max_length - 2 * pad) : (max_length - pad), :
  47. ]
  48. nframe = (max_length - self.win_length) // self.hop_length + 1
  49. # Step2: framing
  50. output = input.as_strided(
  51. (bs, nframe, self.win_length, label_dim),
  52. (max_length * label_dim, self.hop_length * label_dim, label_dim, 1),
  53. )
  54. # Step3: aggregate label
  55. output = torch.gt(output.sum(dim=2, keepdim=False), self.win_length // 2)
  56. output = output.float()
  57. # Step4: process lengths
  58. if ilens is not None:
  59. if self.center:
  60. pad = self.win_length // 2
  61. ilens = ilens + 2 * pad
  62. olens = (ilens - self.win_length) // self.hop_length + 1
  63. output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
  64. else:
  65. olens = None
  66. return output.to(input.dtype), olens
  67. class LabelAggregateMaxPooling(torch.nn.Module):
  68. def __init__(
  69. self,
  70. hop_length: int = 8,
  71. ):
  72. super().__init__()
  73. self.hop_length = hop_length
  74. def extra_repr(self):
  75. return (
  76. f"hop_length={self.hop_length}, "
  77. )
  78. def forward(
  79. self, input: torch.Tensor, ilens: torch.Tensor = None
  80. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  81. """LabelAggregate forward function.
  82. Args:
  83. input: (Batch, Nsamples, Label_dim)
  84. ilens: (Batch)
  85. Returns:
  86. output: (Batch, Frames, Label_dim)
  87. """
  88. output = F.max_pool1d(input.transpose(1, 2), self.hop_length, self.hop_length).transpose(1, 2)
  89. olens = ilens // self.hop_length
  90. return output.to(input.dtype), olens