e2e_diar_eend_ola.py

# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
from typeguard import check_argument_types

from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield
def pad_attractor(att, max_n_speakers):
    """Pad an attractor matrix with zero rows up to max_n_speakers rows."""
    C, D = att.shape
    if C < max_n_speakers:
        att = torch.cat([att, torch.zeros(max_n_speakers - C, D).to(torch.float32).to(att.device)], dim=0)
    return att
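
# The model wired together below combines a Transformer encoder over Mel
# filterbank features (WavFrontendMel23), an encoder-decoder attractor (EDA)
# module that estimates one attractor per speaker, and an LSTM "PostNet" that
# classifies overlap-aware power-set speaker-activity labels.  This summary is
# inferred from the modules imported and constructed in this file.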
class DiarEENDOLAModel(AbsESPnetModel):
    """EEND-OLA diarization model"""

    def __init__(
        self,
        frontend: WavFrontendMel23,
        encoder: EENDOLATransformerEncoder,
        encoder_decoder_attractor: EncoderDecoderAttractor,
        n_units: int = 256,
        max_n_speaker: int = 8,
        attractor_loss_weight: float = 1.0,
        mapping_dict=None,
        **kwargs,
    ):
        assert check_argument_types()
        super().__init__()
        self.frontend = frontend
        self.encoder = encoder
        self.encoder_decoder_attractor = encoder_decoder_attractor
        self.attractor_loss_weight = attractor_loss_weight
        self.max_n_speaker = max_n_speaker
        if mapping_dict is None:
            mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
        self.mapping_dict = mapping_dict
        # PostNet
        self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
        self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
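
    # forward_encoder pads the per-utterance feature sequences into a single
    # batch, builds a binary attention mask from the true lengths, runs the
    # Transformer encoder, and re-splits the padded output into a list of
    # (T_i, D) embedding tensors, one per utterance.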
    def forward_encoder(self, xs, ilens):
        xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
        pad_shape = xs.shape
        xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
        xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
        emb = self.encoder(xs, xs_mask)
        emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
        emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
        return emb
    def forward_post_net(self, logits, ilens):
        maxlen = torch.max(ilens).to(torch.int).item()
        logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens, batch_first=True, enforce_sorted=False)
        outputs, (_, _) = self.PostNet(logits)
        outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
        outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
        outputs = [self.output_layer(output) for output in outputs]
        return outputs
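
    # NOTE: the training forward pass below follows the generic ESPnet-style
    # CTC/attention recipe.  It refers to self.encode, self.ctc_weight,
    # self.interctc_weight, self._calc_ctc_loss and self._calc_att_loss,
    # which are not defined in this file and are assumed to be provided
    # elsewhere in the framework.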
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()

        # 2a. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
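
    # Inference entry point: frontend features -> Transformer encoder
    # embeddings -> (optionally shuffled) EDA attractor estimation ->
    # selection of active attractors, either by a given n_speakers or by the
    # existence-probability threshold -> frame-vs-attractor dot products ->
    # LSTM PostNet producing power-set label logits -> per-frame speaker
    # decisions via recover_y_from_powerlabel.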
    def estimate_sequential(self,
                            speech: torch.Tensor,
                            speech_lengths: torch.Tensor,
                            n_speakers: int = None,
                            shuffle: bool = True,
                            threshold: float = 0.5,
                            **kwargs):
        if self.frontend is not None:
            speech = self.frontend(speech)
        speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
        emb = self.forward_encoder(speech, speech_lengths)
        if shuffle:
            orders = [np.arange(e.shape[0]) for e in emb]
            for order in orders:
                np.random.shuffle(order)
            attractors, probs = self.encoder_decoder_attractor.estimate(
                [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
        else:
            attractors, probs = self.encoder_decoder_attractor.estimate(emb)
        attractors_active = []
        for p, att, e in zip(probs, attractors, emb):
            if n_speakers and n_speakers >= 0:
                att = att[:n_speakers, ]
                attractors_active.append(att)
            elif threshold is not None:
                # The number of active speakers is the index of the first
                # attractor whose existence probability falls below threshold.
                silence = torch.nonzero(p < threshold)
                n_spk = silence[0, 0].item() if silence.numel() else None
                att = att[:n_spk, ]
                attractors_active.append(att)
            else:
                raise NotImplementedError('n_speakers or threshold has to be given.')
        raw_n_speakers = [att.shape[0] for att in attractors_active]
        attractors = [
            pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker]
            for att in attractors_active]
        ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
        logits = self.forward_post_net(ys, speech_lengths)
        ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker) for logit, raw_n_speaker in
              zip(logits, raw_n_speakers)]
        return ys, emb, attractors, raw_n_speakers
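
    # recover_y_from_powerlabel maps each predicted power-set class index back
    # to per-speaker binary decisions.  Worked example with max_n_speaker=8:
    # if inv_mapping_func returns the integer 5, bin(5)[2:] is '101', which is
    # zero-padded to '00000101' and reversed to '10100000', i.e. speakers 0
    # and 2 are active in that frame.  OOV predictions are first replaced by
    # the previous frame's label (or 0 at the first frame).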
    def recover_y_from_powerlabel(self, logit, n_speaker):
        pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)
        oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
        for i in oov_index:
            if i > 0:
                pred[i] = pred[i - 1]
            else:
                pred[i] = 0
        pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred]
        decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
        decisions = torch.from_numpy(
            np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
            torch.float32)
        decisions = decisions[:, :n_speaker]
        return decisions
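
# -----------------------------------------------------------------------------
# Illustrative call pattern (a sketch only; constructing the frontend, encoder
# and EDA modules is handled by the surrounding FunASR configuration/scripts
# and is not shown here):
#
#   model = DiarEENDOLAModel(frontend=frontend, encoder=encoder,
#                            encoder_decoder_attractor=eda)
#   ys, emb, attractors, raw_n_speakers = model.estimate_sequential(
#       speech, speech_lengths, threshold=0.5)
#   # each element of ys is a (T_i, n_speakers_i) matrix of 0/1 decisions
# -----------------------------------------------------------------------------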