e2e_diar_eend_ola.py

from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr.frontends.wav_frontend import WavFrontendMel23
from funasr.models.eend.encoder import EENDOLATransformerEncoder
from funasr.models.eend.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.models.eend.utils.losses import standard_loss, cal_power_loss, fast_batch_pit_n_speaker_loss
from funasr.models.eend.utils.power import create_powerlabel
from funasr.models.eend.utils.power import generate_mapping_dict
from funasr.train_utils.device_funcs import force_gatherable

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # torch < 1.6.0 has no AMP; provide a no-op stand-in so that
    # `with autocast(...)` still works.
    @contextmanager
    def autocast(enabled=True):
        yield
def pad_attractor(att, max_n_speakers):
    """Zero-pad an attractor matrix (C, D) along the speaker axis to (max_n_speakers, D)."""
    C, D = att.shape
    if C < max_n_speakers:
        att = torch.cat(
            [att, torch.zeros(max_n_speakers - C, D).to(torch.float32).to(att.device)],
            dim=0,
        )
    return att
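# Illustration (assumed sizes): pad_attractor(att, 8) on a (3, 256) attractor
# matrix appends five zero rows and returns an (8, 256) tensor; inputs with
# 8 or more rows come back unchanged.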
def pad_labels(ts, out_size):
    """Zero-pad each label matrix (T, C) along the speaker axis to (T, out_size)."""
    for i, t in enumerate(ts):
        if t.shape[1] < out_size:
            ts[i] = F.pad(t, (0, out_size - t.shape[1], 0, 0), mode='constant', value=0.)
    return ts
def pad_results(ys, out_size):
    """Zero-pad each logit matrix (T, C) along the speaker axis to (T, out_size)."""
    ys_padded = []
    for y in ys:
        if y.shape[1] < out_size:
            ys_padded.append(
                torch.cat(
                    [y, torch.zeros(y.shape[0], out_size - y.shape[1]).to(torch.float32).to(y.device)],
                    dim=1,
                )
            )
        else:
            ys_padded.append(y)
    return ys_padded
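# Note: pad_labels pads its input list in place and returns it, while
# pad_results builds a new list; both widen only the speaker axis, never the
# time axis.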
class DiarEENDOLAModel(nn.Module):
    """EEND-OLA diarization model.

    Combines a transformer encoder, an encoder-decoder attractor (EDA), and an
    LSTM PostNet trained on power-set encoded (PSE) speaker-activity labels.
    """

    def __init__(
        self,
        frontend: Optional[WavFrontendMel23],
        encoder: EENDOLATransformerEncoder,
        encoder_decoder_attractor: EncoderDecoderAttractor,
        n_units: int = 256,
        max_n_speaker: int = 8,
        attractor_loss_weight: float = 1.0,
        mapping_dict=None,
        **kwargs,
    ):
        super().__init__()
        self.frontend = frontend
        self.enc = encoder
        self.encoder_decoder_attractor = encoder_decoder_attractor
        self.attractor_loss_weight = attractor_loss_weight
        self.max_n_speaker = max_n_speaker
        if mapping_dict is None:
            mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
        self.mapping_dict = mapping_dict
        # PostNet: an LSTM over per-frame speaker logits, followed by a linear
        # projection onto the PSE label vocabulary (including the OOV label).
        self.postnet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
        self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
    def forward_encoder(self, xs, ilens):
        """Encode padded feature sequences into per-utterance frame embeddings."""
        xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
        pad_shape = xs.shape
        xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
        xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
        emb = self.enc(xs, xs_mask)
        # Split the (B, T, D) output back into a list of unpadded (T_i, D) embeddings.
        emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
        emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
        return emb
    def forward_post_net(self, logits, ilens):
        """Run the LSTM PostNet over per-frame speaker logits and project onto PSE labels."""
        maxlen = torch.max(ilens).to(torch.int).item()
        logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
        logits = nn.utils.rnn.pack_padded_sequence(
            logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False
        )
        outputs, (_, _) = self.postnet(logits)
        outputs = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True, padding_value=-1, total_length=maxlen
        )[0]
        outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
        outputs = [self.output_layer(output) for output in outputs]
        return outputs
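    # Shape sketch (assumed sizes): with max_n_speaker=8 and a 500-frame
    # utterance, forward_post_net takes a list like [(500, 8)] of speaker
    # logits and returns [(500, V)] scores, where V = mapping_dict['oov'] + 1
    # is the size of the power-set label vocabulary.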
    def forward(
        self,
        speech: List[torch.Tensor],
        speaker_labels: List[torch.Tensor],
        orders: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        # Check that the batch sizes match
        assert len(speech) == len(speaker_labels), (len(speech), len(speaker_labels))
        speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
        speaker_labels_lengths = torch.tensor([spk.shape[-1] for spk in speaker_labels]).to(torch.int64)
        batch_size = len(speech)
        # Encoder
        encoder_out = self.forward_encoder(speech, speech_lengths)
        # Encoder-decoder attractor (frames are shuffled via `orders` before EDA)
        attractor_loss, attractors = self.encoder_decoder_attractor(
            [e[order] for e, order in zip(encoder_out, orders)],
            speaker_labels_lengths,
        )
        speaker_logits = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(encoder_out, attractors)]
        # PIT loss
        pit_speaker_labels = fast_batch_pit_n_speaker_loss(speaker_logits, speaker_labels)
        pit_loss = standard_loss(speaker_logits, pit_speaker_labels)
        # PSE loss: turn the best-permutation labels into power-set labels
        with torch.no_grad():
            power_ts = [
                create_powerlabel(label.cpu().numpy(), self.mapping_dict, self.max_n_speaker).to(
                    encoder_out[0].device, non_blocking=True
                )
                for label in pit_speaker_labels
            ]
        pad_attractors = [pad_attractor(att, self.max_n_speaker) for att in attractors]
        pse_speaker_logits = [
            torch.matmul(e, pad_att.permute(1, 0)) for e, pad_att in zip(encoder_out, pad_attractors)
        ]
        pse_speaker_logits = self.forward_post_net(pse_speaker_logits, speech_lengths)
        pse_loss = cal_power_loss(pse_speaker_logits, power_ts)
        loss = pse_loss + pit_loss + self.attractor_loss_weight * attractor_loss
        stats = dict()
        stats["pse_loss"] = pse_loss.detach()
        stats["pit_loss"] = pit_loss.detach()
        stats["attractor_loss"] = attractor_loss.detach()
        stats["batch_size"] = batch_size
        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
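    # Total objective (as computed in forward above):
    #     loss = pse_loss + pit_loss + attractor_loss_weight * attractor_loss
    # i.e. the power-set encoding loss from the PostNet, the permutation-
    # invariant loss on the raw attractor logits, and the weighted attractor
    # existence loss.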
    def estimate_sequential(
        self,
        speech: List[torch.Tensor],
        n_speakers: Optional[int] = None,
        shuffle: bool = True,
        threshold: float = 0.5,
        **kwargs,
    ):
        """Infer frame-level speaker activities for a list of feature sequences."""
        speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
        emb = self.forward_encoder(speech, speech_lengths)
        if shuffle:
            # Shuffle frames before EDA, matching the training-time order handling.
            orders = [np.arange(e.shape[0]) for e in emb]
            for order in orders:
                np.random.shuffle(order)
            attractors, probs = self.encoder_decoder_attractor.estimate(
                [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)]
            )
        else:
            attractors, probs = self.encoder_decoder_attractor.estimate(emb)
        attractors_active = []
        for p, att in zip(probs, attractors):
            if n_speakers and n_speakers >= 0:
                att = att[:n_speakers]
                attractors_active.append(att)
            elif threshold is not None:
                # The first attractor whose existence probability drops below
                # the threshold marks the end of the active speakers.
                silence = torch.nonzero(p < threshold).flatten()
                n_spk = silence[0].item() if silence.numel() > 0 else None
                att = att[:n_spk]
                attractors_active.append(att)
            else:
                raise NotImplementedError('n_speakers or threshold has to be given.')
        raw_n_speakers = [att.shape[0] for att in attractors_active]
        attractors = [
            pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker
            else att[:self.max_n_speaker]
            for att in attractors_active
        ]
        ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
        logits = self.forward_post_net(ys, speech_lengths)
        ys = [
            self.recover_y_from_powerlabel(logit, raw_n_speaker)
            for logit, raw_n_speaker in zip(logits, raw_n_speakers)
        ]
        return ys, emb, attractors, raw_n_speakers
    def recover_y_from_powerlabel(self, logit, n_speaker):
        """Decode PSE posteriors back into per-speaker binary activity decisions."""
        pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)
        oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
        # Replace OOV predictions with the previous frame's label (silence at t=0).
        for i in oov_index:
            if i > 0:
                pred[i] = pred[i - 1]
            else:
                pred[i] = 0
        pred = [self.inv_mapping_func(i) for i in pred]
        # Each label maps to an integer whose binary digits mark the active
        # speakers, least-significant bit first.
        decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
        decisions = torch.from_numpy(
            np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)
        ).to(logit.device).to(torch.float32)
        decisions = decisions[:, :n_speaker]
        return decisions
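    # Decoding example (verifiable by hand): with max_n_speaker=8, a decoded
    # integer of 5 is '00000101' in binary; reversed, it yields the row
    # [1, 0, 1, 0, 0, 0, 0, 0], i.e. speakers 0 and 2 are active in that frame.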
    def inv_mapping_func(self, label):
        """Map a PSE class index back to its integer speaker-set encoding (-1 if unknown)."""
        if not isinstance(label, int):
            label = int(label)
        if label in self.mapping_dict['label2dec']:
            num = self.mapping_dict['label2dec'][label]
        else:
            num = -1
        return num
    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Not used by EEND-OLA; kept as a stub for interface compatibility.
        pass
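
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal inference sketch, assuming `model` is an already-constructed
# DiarEENDOLAModel in eval mode and `feats` is a list of (T_i, 23) log-mel
# feature tensors such as WavFrontendMel23 produces; the names here are
# illustrative, not part of this module.
#
#     model.eval()
#     with torch.no_grad():
#         ys, emb, attractors, n_speakers = model.estimate_sequential(
#             feats, threshold=0.5, shuffle=True,
#         )
#     # ys[i] is a (T_i, n_speakers[i]) 0/1 matrix of frame-level speaker
#     # activity for utterance i.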