# transformer_decoder_sa_asr.py

from typing import Any
from typing import List
from typing import Sequence
from typing import Tuple

import torch
from typeguard import check_argument_types

from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.attention import CosineDistanceAttention
from funasr.models.decoder.transformer_decoder import DecoderLayer
from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeAsrDecoderFirstLayer
from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeSpkDecoderFirstLayer
from funasr.modules.dynamic_conv import DynamicConvolution
from funasr.modules.dynamic_conv2d import DynamicConvolution2D
from funasr.modules.embedding import PositionalEncoding
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.lightconv import LightweightConvolution
from funasr.modules.lightconv2d import LightweightConvolution2D
from funasr.modules.mask import subsequent_mask
from funasr.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
from funasr.models.decoder.abs_decoder import AbsDecoder


class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
    """Base class for the speaker-attributed ASR (SA-ASR) Transformer decoder.

    The decoder is split into a speaker branch (decoder1/decoder2), whose
    output is matched against the speaker profile inventory via cosine-distance
    attention, and an ASR branch (decoder3/decoder4), which is conditioned on
    the selected profile.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        attention_dim = encoder_output_size

        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_asr_output_layer:
            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.asr_output_layer = None
        if use_spk_output_layer:
            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
        else:
            self.spk_output_layer = None
        self.cos_distance_att = CosineDistanceAttention()

        # The concrete layer stacks are built by the subclass.
        self.decoder1 = None
        self.decoder2 = None
        self.decoder3 = None
        self.decoder4 = None
    def forward(
        self,
        asr_hs_pad: torch.Tensor,
        spk_hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        profile: torch.Tensor,
        profile_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m

        asr_memory = asr_hs_pad
        spk_memory = spk_hs_pad
        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)

        # Speaker decoder: estimate a speaker representation per target position.
        x = self.embed(tgt)
        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, memory_mask
        )
        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
            x, tgt_mask, spk_memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        # Select speaker profiles via cosine-distance attention.
        dn, weights = self.cos_distance_att(x, profile, profile_lens)

        # ASR decoder: decode tokens conditioned on the selected profile dn.
        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
            z, tgt_mask, asr_memory, memory_mask, dn
        )
        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
            x, tgt_mask, asr_memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.asr_output_layer is not None:
            x = self.asr_output_layer(x)

        olens = tgt_mask.sum(1)
        return x, weights, olens
    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        asr_memory: torch.Tensor,
        spk_memory: torch.Tensor,
        profile: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        x = self.embed(tgt)
        if cache is None:
            # One cache entry per layer: decoder1, decoder2 stack, decoder3, decoder4 stack.
            cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
        new_cache = []

        # Speaker decoder.
        x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
        )
        new_cache.append(x)
        for c, decoder in zip(cache[1 : len(self.decoder2) + 1], self.decoder2):
            x, tgt_mask, spk_memory, _ = decoder(
                x, tgt_mask, spk_memory, None, cache=c
            )
            new_cache.append(x)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        dn, weights = self.cos_distance_att(x, profile, None)

        # ASR decoder.
        x, tgt_mask, asr_memory, _ = self.decoder3(
            z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
        )
        new_cache.append(x)
        for c, decoder in zip(cache[len(self.decoder2) + 2 :], self.decoder4):
            x, tgt_mask, asr_memory, _ = decoder(
                x, tgt_mask, asr_memory, None, cache=c
            )
            new_cache.append(x)
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.asr_output_layer is not None:
            y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
        return y, weights, new_cache
    def score(self, ys, state, asr_enc, spk_enc, profile):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
        logp, weights, state = self.forward_one_step(
            ys.unsqueeze(0),
            ys_mask,
            asr_enc.unsqueeze(0),
            spk_enc.unsqueeze(0),
            profile.unsqueeze(0),
            cache=state,
        )
        return logp.squeeze(0), weights.squeeze(), state


class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        asr_num_blocks: int = 6,
        spk_num_blocks: int = 3,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            spker_embedding_dim=spker_embedding_dim,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_asr_output_layer=use_asr_output_layer,
            use_spk_output_layer=use_spk_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        # First layer of the speaker decoder (also returns the intermediate z
        # consumed by the ASR decoder).
        self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
            attention_dim,
            MultiHeadedAttention(
                attention_heads, attention_dim, self_attention_dropout_rate
            ),
            MultiHeadedAttention(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        # Remaining speaker-decoder layers.
        self.decoder2 = repeat(
            spk_num_blocks - 1,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        # First layer of the ASR decoder, which fuses the speaker profile dn.
        self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
            attention_dim,
            spker_embedding_dim,
            MultiHeadedAttention(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        # Remaining ASR-decoder layers.
        self.decoder4 = repeat(
            asr_num_blocks - 1,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
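

# The block below is not part of the original FunASR source; it is a minimal,
# hedged usage sketch of the decoder's training-time forward pass. All shapes
# (batch=2, T_enc=50, L_dec=10, 4 speaker profiles) and hyperparameters
# (encoder_output_size=256, spker_embedding_dim=256) are illustrative
# assumptions, and the ASR/speaker encoder outputs and profile inventory are
# replaced with random tensors.
if __name__ == "__main__":
    vocab_size = 100
    decoder = SAAsrTransformerDecoder(
        vocab_size=vocab_size,
        encoder_output_size=256,
        spker_embedding_dim=256,
        asr_num_blocks=6,
        spk_num_blocks=3,
    )
    batch, t_enc, l_dec, n_profiles = 2, 50, 10, 4
    asr_hs = torch.randn(batch, t_enc, 256)        # stand-in for ASR encoder output
    spk_hs = torch.randn(batch, t_enc, 256)        # stand-in for speaker encoder output
    hlens = torch.full((batch,), t_enc, dtype=torch.long)
    ys_in = torch.randint(0, vocab_size, (batch, l_dec))
    ys_in_lens = torch.full((batch,), l_dec, dtype=torch.long)
    profile = torch.randn(batch, n_profiles, 256)  # stand-in speaker profile inventory
    profile_lens = torch.full((batch,), n_profiles, dtype=torch.long)
    logits, spk_weights, olens = decoder(
        asr_hs, spk_hs, hlens, ys_in, ys_in_lens, profile, profile_lens
    )
    print(logits.shape, spk_weights.shape, olens.shape)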