# label_aggregation.py
  1. import torch
  2. from typing import Optional
  3. from typing import Tuple
  4. from funasr.modules.nets_utils import make_pad_mask
  5. class LabelAggregate(torch.nn.Module):
  6. def __init__(
  7. self,
  8. win_length: int = 512,
  9. hop_length: int = 128,
  10. center: bool = True,
  11. ):
  12. super().__init__()
  13. self.win_length = win_length
  14. self.hop_length = hop_length
  15. self.center = center
  16. def extra_repr(self):
  17. return (
  18. f"win_length={self.win_length}, "
  19. f"hop_length={self.hop_length}, "
  20. f"center={self.center}, "
  21. )
  22. def forward(
  23. self, input: torch.Tensor, ilens: torch.Tensor = None
  24. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  25. """LabelAggregate forward function.
  26. Args:
  27. input: (Batch, Nsamples, Label_dim)
  28. ilens: (Batch)
  29. Returns:
  30. output: (Batch, Frames, Label_dim)
  31. """
  32. bs = input.size(0)
  33. max_length = input.size(1)
  34. label_dim = input.size(2)
  35. # NOTE(jiatong):
  36. # The default behaviour of label aggregation is compatible with
  37. # torch.stft about framing and padding.
  38. # Step1: center padding
  39. if self.center:
  40. pad = self.win_length // 2
  41. max_length = max_length + 2 * pad
  42. input = torch.nn.functional.pad(input, (0, 0, pad, pad), "constant", 0)
  43. input[:, :pad, :] = input[:, pad : (2 * pad), :]
  44. input[:, (max_length - pad) : max_length, :] = input[
  45. :, (max_length - 2 * pad) : (max_length - pad), :
  46. ]
  47. nframe = (max_length - self.win_length) // self.hop_length + 1
  48. # Step2: framing
  49. output = input.as_strided(
  50. (bs, nframe, self.win_length, label_dim),
  51. (max_length * label_dim, self.hop_length * label_dim, label_dim, 1),
  52. )
  53. # Step3: aggregate label
  54. output = torch.gt(output.sum(dim=2, keepdim=False), self.win_length // 2)
  55. output = output.float()
  56. # Step4: process lengths
  57. if ilens is not None:
  58. if self.center:
  59. pad = self.win_length // 2
  60. ilens = ilens + 2 * pad
  61. olens = (ilens - self.win_length) // self.hop_length + 1
  62. output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
  63. else:
  64. olens = None
  65. return output.to(input.dtype), olens