repeat.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # Copyright 2019 Shigeki Karita
  4. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  5. """Repeat the same layer definition."""
  6. from typing import Dict, List, Optional
  7. import torch
  8. class MultiSequential(torch.nn.Sequential):
  9. """Multi-input multi-output torch.nn.Sequential."""
  10. def forward(self, *args):
  11. """Repeat."""
  12. for m in self:
  13. args = m(*args)
  14. return args
  15. def repeat(N, fn):
  16. """Repeat module N times.
  17. Args:
  18. N (int): Number of repeat time.
  19. fn (Callable): Function to generate module.
  20. Returns:
  21. MultiSequential: Repeated model instance.
  22. """
  23. return MultiSequential(*[fn(n) for n in range(N)])
  24. class MultiBlocks(torch.nn.Module):
  25. """MultiBlocks definition.
  26. Args:
  27. block_list: Individual blocks of the encoder architecture.
  28. output_size: Architecture output size.
  29. norm_class: Normalization module class.
  30. norm_args: Normalization module arguments.
  31. """
  32. def __init__(
  33. self,
  34. block_list: List[torch.nn.Module],
  35. output_size: int,
  36. norm_class: torch.nn.Module = torch.nn.LayerNorm,
  37. ) -> None:
  38. """Construct a MultiBlocks object."""
  39. super().__init__()
  40. self.blocks = torch.nn.ModuleList(block_list)
  41. self.norm_blocks = norm_class(output_size)
  42. self.num_blocks = len(block_list)
  43. def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
  44. """Initialize/Reset encoder streaming cache.
  45. Args:
  46. left_context: Number of left frames during chunk-by-chunk inference.
  47. device: Device to use for cache tensor.
  48. """
  49. for idx in range(self.num_blocks):
  50. self.blocks[idx].reset_streaming_cache(left_context, device)
  51. def forward(
  52. self,
  53. x: torch.Tensor,
  54. pos_enc: torch.Tensor,
  55. mask: torch.Tensor,
  56. chunk_mask: Optional[torch.Tensor] = None,
  57. ) -> torch.Tensor:
  58. """Forward each block of the encoder architecture.
  59. Args:
  60. x: MultiBlocks input sequences. (B, T, D_block_1)
  61. pos_enc: Positional embedding sequences.
  62. mask: Source mask. (B, T)
  63. chunk_mask: Chunk mask. (T_2, T_2)
  64. Returns:
  65. x: Output sequences. (B, T, D_block_N)
  66. """
  67. for block_index, block in enumerate(self.blocks):
  68. x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask)
  69. x = self.norm_blocks(x)
  70. return x
  71. def chunk_forward(
  72. self,
  73. x: torch.Tensor,
  74. pos_enc: torch.Tensor,
  75. mask: torch.Tensor,
  76. chunk_size: int = 0,
  77. left_context: int = 0,
  78. right_context: int = 0,
  79. ) -> torch.Tensor:
  80. """Forward each block of the encoder architecture.
  81. Args:
  82. x: MultiBlocks input sequences. (B, T, D_block_1)
  83. pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att)
  84. mask: Source mask. (B, T_2)
  85. left_context: Number of frames in left context.
  86. right_context: Number of frames in right context.
  87. Returns:
  88. x: MultiBlocks output sequences. (B, T, D_block_N)
  89. """
  90. for block_idx, block in enumerate(self.blocks):
  91. x, pos_enc = block.chunk_forward(
  92. x,
  93. pos_enc,
  94. mask,
  95. chunk_size=chunk_size,
  96. left_context=left_context,
  97. right_context=right_context,
  98. )
  99. x = self.norm_blocks(x)
  100. return x