# repeat.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Repeat the same layer definition."""
from typing import Dict, List, Optional
from funasr.modules.layer_norm import LayerNorm
import torch
  9. class MultiSequential(torch.nn.Sequential):
  10. """Multi-input multi-output torch.nn.Sequential."""
  11. def forward(self, *args):
  12. """Repeat."""
  13. for m in self:
  14. args = m(*args)
  15. return args
  16. def repeat(N, fn):
  17. """Repeat module N times.
  18. Args:
  19. N (int): Number of repeat time.
  20. fn (Callable): Function to generate module.
  21. Returns:
  22. MultiSequential: Repeated model instance.
  23. """
  24. return MultiSequential(*[fn(n) for n in range(N)])
  25. class MultiBlocks(torch.nn.Module):
  26. """MultiBlocks definition.
  27. Args:
  28. block_list: Individual blocks of the encoder architecture.
  29. output_size: Architecture output size.
  30. norm_class: Normalization module class.
  31. norm_args: Normalization module arguments.
  32. """
  33. def __init__(
  34. self,
  35. block_list: List[torch.nn.Module],
  36. output_size: int,
  37. norm_class: torch.nn.Module = LayerNorm,
  38. ) -> None:
  39. """Construct a MultiBlocks object."""
  40. super().__init__()
  41. self.blocks = torch.nn.ModuleList(block_list)
  42. self.norm_blocks = norm_class(output_size)
  43. self.num_blocks = len(block_list)
  44. def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
  45. """Initialize/Reset encoder streaming cache.
  46. Args:
  47. left_context: Number of left frames during chunk-by-chunk inference.
  48. device: Device to use for cache tensor.
  49. """
  50. for idx in range(self.num_blocks):
  51. self.blocks[idx].reset_streaming_cache(left_context, device)
  52. def forward(
  53. self,
  54. x: torch.Tensor,
  55. pos_enc: torch.Tensor,
  56. mask: torch.Tensor,
  57. chunk_mask: Optional[torch.Tensor] = None,
  58. ) -> torch.Tensor:
  59. """Forward each block of the encoder architecture.
  60. Args:
  61. x: MultiBlocks input sequences. (B, T, D_block_1)
  62. pos_enc: Positional embedding sequences.
  63. mask: Source mask. (B, T)
  64. chunk_mask: Chunk mask. (T_2, T_2)
  65. Returns:
  66. x: Output sequences. (B, T, D_block_N)
  67. """
  68. for block_index, block in enumerate(self.blocks):
  69. x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask)
  70. x = self.norm_blocks(x)
  71. return x
  72. def chunk_forward(
  73. self,
  74. x: torch.Tensor,
  75. pos_enc: torch.Tensor,
  76. mask: torch.Tensor,
  77. chunk_size: int = 0,
  78. left_context: int = 0,
  79. right_context: int = 0,
  80. ) -> torch.Tensor:
  81. """Forward each block of the encoder architecture.
  82. Args:
  83. x: MultiBlocks input sequences. (B, T, D_block_1)
  84. pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att)
  85. mask: Source mask. (B, T_2)
  86. left_context: Number of frames in left context.
  87. right_context: Number of frames in right context.
  88. Returns:
  89. x: MultiBlocks output sequences. (B, T, D_block_N)
  90. """
  91. for block_idx, block in enumerate(self.blocks):
  92. x, pos_enc = block.chunk_forward(
  93. x,
  94. pos_enc,
  95. mask,
  96. chunk_size=chunk_size,
  97. left_context=left_context,
  98. right_context=right_context,
  99. )
  100. x = self.norm_blocks(x)
  101. return x