# repeat.py
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # Copyright 2019 Shigeki Karita
  4. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  5. """Repeat the same layer definition."""
  6. from typing import Dict, List, Optional
  7. from funasr.modules.layer_norm import LayerNorm
  8. import torch
  9. class MultiSequential(torch.nn.Sequential):
  10. """Multi-input multi-output torch.nn.Sequential."""
  11. def __init__(self, *args, layer_drop_rate=0.0):
  12. """Initialize MultiSequential with layer_drop.
  13. Args:
  14. layer_drop_rate (float): Probability of dropping out each fn (layer).
  15. """
  16. super(MultiSequential, self).__init__(*args)
  17. self.layer_drop_rate = layer_drop_rate
  18. def forward(self, *args):
  19. """Repeat."""
  20. _probs = torch.empty(len(self)).uniform_()
  21. for idx, m in enumerate(self):
  22. if not self.training or (_probs[idx] >= self.layer_drop_rate):
  23. args = m(*args)
  24. return args
  25. def repeat(N, fn, layer_drop_rate=0.0):
  26. """Repeat module N times.
  27. Args:
  28. N (int): Number of repeat time.
  29. fn (Callable): Function to generate module.
  30. layer_drop_rate (float): Probability of dropping out each fn (layer).
  31. Returns:
  32. MultiSequential: Repeated model instance.
  33. """
  34. return MultiSequential(*[fn(n) for n in range(N)], layer_drop_rate=layer_drop_rate)
  35. class MultiBlocks(torch.nn.Module):
  36. """MultiBlocks definition.
  37. Args:
  38. block_list: Individual blocks of the encoder architecture.
  39. output_size: Architecture output size.
  40. norm_class: Normalization module class.
  41. norm_args: Normalization module arguments.
  42. """
  43. def __init__(
  44. self,
  45. block_list: List[torch.nn.Module],
  46. output_size: int,
  47. norm_class: torch.nn.Module = LayerNorm,
  48. ) -> None:
  49. """Construct a MultiBlocks object."""
  50. super().__init__()
  51. self.blocks = torch.nn.ModuleList(block_list)
  52. self.norm_blocks = norm_class(output_size)
  53. self.num_blocks = len(block_list)
  54. def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
  55. """Initialize/Reset encoder streaming cache.
  56. Args:
  57. left_context: Number of left frames during chunk-by-chunk inference.
  58. device: Device to use for cache tensor.
  59. """
  60. for idx in range(self.num_blocks):
  61. self.blocks[idx].reset_streaming_cache(left_context, device)
  62. def forward(
  63. self,
  64. x: torch.Tensor,
  65. pos_enc: torch.Tensor,
  66. mask: torch.Tensor,
  67. chunk_mask: Optional[torch.Tensor] = None,
  68. ) -> torch.Tensor:
  69. """Forward each block of the encoder architecture.
  70. Args:
  71. x: MultiBlocks input sequences. (B, T, D_block_1)
  72. pos_enc: Positional embedding sequences.
  73. mask: Source mask. (B, T)
  74. chunk_mask: Chunk mask. (T_2, T_2)
  75. Returns:
  76. x: Output sequences. (B, T, D_block_N)
  77. """
  78. for block_index, block in enumerate(self.blocks):
  79. x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask)
  80. x = self.norm_blocks(x)
  81. return x
  82. def chunk_forward(
  83. self,
  84. x: torch.Tensor,
  85. pos_enc: torch.Tensor,
  86. mask: torch.Tensor,
  87. chunk_size: int = 0,
  88. left_context: int = 0,
  89. right_context: int = 0,
  90. ) -> torch.Tensor:
  91. """Forward each block of the encoder architecture.
  92. Args:
  93. x: MultiBlocks input sequences. (B, T, D_block_1)
  94. pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att)
  95. mask: Source mask. (B, T_2)
  96. left_context: Number of frames in left context.
  97. right_context: Number of frames in right context.
  98. Returns:
  99. x: MultiBlocks output sequences. (B, T, D_block_N)
  100. """
  101. for block_idx, block in enumerate(self.blocks):
  102. x, pos_enc = block.chunk_forward(
  103. x,
  104. pos_enc,
  105. mask,
  106. chunk_size=chunk_size,
  107. left_context=left_context,
  108. right_context=right_context,
  109. )
  110. x = self.norm_blocks(x)
  111. return x