# label_aggregation.py
  1. import torch
  2. from typeguard import check_argument_types
  3. from typing import Optional
  4. from typing import Tuple
  5. from funasr.modules.nets_utils import make_pad_mask
  6. class LabelAggregate(torch.nn.Module):
  7. def __init__(
  8. self,
  9. win_length: int = 512,
  10. hop_length: int = 128,
  11. center: bool = True,
  12. ):
  13. assert check_argument_types()
  14. super().__init__()
  15. self.win_length = win_length
  16. self.hop_length = hop_length
  17. self.center = center
  18. def extra_repr(self):
  19. return (
  20. f"win_length={self.win_length}, "
  21. f"hop_length={self.hop_length}, "
  22. f"center={self.center}, "
  23. )
  24. def forward(
  25. self, input: torch.Tensor, ilens: torch.Tensor = None
  26. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  27. """LabelAggregate forward function.
  28. Args:
  29. input: (Batch, Nsamples, Label_dim)
  30. ilens: (Batch)
  31. Returns:
  32. output: (Batch, Frames, Label_dim)
  33. """
  34. bs = input.size(0)
  35. max_length = input.size(1)
  36. label_dim = input.size(2)
  37. # NOTE(jiatong):
  38. # The default behaviour of label aggregation is compatible with
  39. # torch.stft about framing and padding.
  40. # Step1: center padding
  41. if self.center:
  42. pad = self.win_length // 2
  43. max_length = max_length + 2 * pad
  44. input = torch.nn.functional.pad(input, (0, 0, pad, pad), "constant", 0)
  45. input[:, :pad, :] = input[:, pad : (2 * pad), :]
  46. input[:, (max_length - pad) : max_length, :] = input[
  47. :, (max_length - 2 * pad) : (max_length - pad), :
  48. ]
  49. nframe = (max_length - self.win_length) // self.hop_length + 1
  50. # Step2: framing
  51. output = input.as_strided(
  52. (bs, nframe, self.win_length, label_dim),
  53. (max_length * label_dim, self.hop_length * label_dim, label_dim, 1),
  54. )
  55. # Step3: aggregate label
  56. output = torch.gt(output.sum(dim=2, keepdim=False), self.win_length // 2)
  57. output = output.float()
  58. # Step4: process lengths
  59. if ilens is not None:
  60. if self.center:
  61. pad = self.win_length // 2
  62. ilens = ilens + 2 * pad
  63. olens = (ilens - self.win_length) // self.hop_length + 1
  64. output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
  65. else:
  66. olens = None
  67. return output, olens