# e2e_diar_eend_ola.py

from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr.models.base_model import FunASRModel
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.modules.eend_ola.utils.losses import cal_power_loss, fast_batch_pit_n_speaker_loss, standard_loss
from funasr.modules.eend_ola.utils.power import create_powerlabel, generate_mapping_dict
from funasr.torch_utils.device_funcs import force_gatherable

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch < 1.6.0: fall back to a no-op context manager
    @contextmanager
    def autocast(enabled=True):
        yield


def pad_attractor(att, max_n_speakers):
    """Zero-pad an attractor matrix of shape (C, D) along the speaker axis to (max_n_speakers, D)."""
    C, D = att.shape
    if C < max_n_speakers:
        att = torch.cat([att, torch.zeros(max_n_speakers - C, D).to(torch.float32).to(att.device)], dim=0)
    return att
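
# Example (illustrative sketch; shapes chosen for the demo):
#
#     att = torch.randn(3, 256)        # attractors for 3 detected speakers
#     padded = pad_attractor(att, 8)   # -> (8, 256); rows 3..7 are zeros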


def pad_labels(ts, out_size):
    """Zero-pad each (T, C) label tensor in ts along the speaker axis to out_size, in place."""
    for i, t in enumerate(ts):
        if t.shape[1] < out_size:
            ts[i] = F.pad(
                t,
                (0, out_size - t.shape[1], 0, 0),
                mode='constant',
                value=0.
            )
    return ts
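
# Example (illustrative sketch): widening a (T, C) label tensor from 2 to 4
# speaker slots; the new columns are zero (inactive):
#
#     ts = [torch.ones(100, 2)]
#     ts = pad_labels(ts, 4)           # ts[0].shape -> (100, 4)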


def pad_results(ys, out_size):
    """Zero-pad each (T, C) posterior tensor in ys along the speaker axis to out_size."""
    ys_padded = []
    for y in ys:
        if y.shape[1] < out_size:
            ys_padded.append(
                torch.cat([y, torch.zeros(y.shape[0], out_size - y.shape[1]).to(torch.float32).to(y.device)], dim=1))
        else:
            ys_padded.append(y)
    return ys_padded
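
# Example (illustrative sketch): unlike pad_labels, pad_results builds a new
# list instead of modifying its argument in place:
#
#     ys = [torch.rand(100, 2), torch.rand(80, 4)]
#     ys = pad_results(ys, 4)          # shapes -> (100, 4) and (80, 4)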


class DiarEENDOLAModel(FunASRModel):
    """EEND-OLA diarization model"""

    def __init__(
            self,
            frontend: Optional[WavFrontendMel23],
            encoder: EENDOLATransformerEncoder,
            encoder_decoder_attractor: EncoderDecoderAttractor,
            n_units: int = 256,
            max_n_speaker: int = 8,
            attractor_loss_weight: float = 1.0,
            mapping_dict=None,
            **kwargs,
    ):
        super().__init__()
        self.frontend = frontend
        self.enc = encoder
        self.encoder_decoder_attractor = encoder_decoder_attractor
        self.attractor_loss_weight = attractor_loss_weight
        self.max_n_speaker = max_n_speaker
        if mapping_dict is None:
            mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
        self.mapping_dict = mapping_dict
        # PostNet
        self.postnet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
        self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
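        # Note (descriptive, inferred from usage below): mapping_dict is
        # expected to provide an integer 'oov' label and a 'label2dec' dict
        # mapping each power-set label to the decimal encoding of its
        # active-speaker bit pattern.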

    def forward_encoder(self, xs, ilens):
        xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
        pad_shape = xs.shape
        xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
        xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
        emb = self.enc(xs, xs_mask)
        emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
        emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
        return emb
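
    # Shape sketch (dims assumed for illustration): each element of xs is a
    # (T_i, F) feature tensor; after padding, self.enc sees a (B, T_max, F)
    # batch with a (B, 1, T_max) mask, and the final split/slice restores a
    # list of per-utterance (T_i, D) embeddings.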

    def forward_post_net(self, logits, ilens):
        maxlen = torch.max(ilens).to(torch.int).item()
        logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True,
                                                   enforce_sorted=False)
        outputs, (_, _) = self.postnet(logits)
        outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
        outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
        outputs = [self.output_layer(output) for output in outputs]
        return outputs
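
    # PostNet flow (descriptive): each (T_i, max_n_speaker) logit sequence is
    # packed, run through the single-layer LSTM, unpacked, and projected to
    # mapping_dict['oov'] + 1 power-set classes, giving one (T_i, n_classes)
    # tensor per utterance.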

    def forward(
            self,
            speech: List[torch.Tensor],
            speaker_labels: List[torch.Tensor],
            orders: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        # Check that batch_size is unified
        assert (len(speech) == len(speaker_labels)), (len(speech), len(speaker_labels))
        speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
        speaker_labels_lengths = torch.tensor([spk.shape[-1] for spk in speaker_labels]).to(torch.int64)
        batch_size = len(speech)

        # Encoder
        encoder_out = self.forward_encoder(speech, speech_lengths)

        # Encoder-decoder attractor
        attractor_loss, attractors = self.encoder_decoder_attractor(
            [e[order] for e, order in zip(encoder_out, orders)],
            speaker_labels_lengths)
        speaker_logits = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(encoder_out, attractors)]

        # PIT (permutation-invariant training) loss
        pit_speaker_labels = fast_batch_pit_n_speaker_loss(speaker_logits, speaker_labels)
        pit_loss = standard_loss(speaker_logits, pit_speaker_labels)

        # PSE (power-set encoding) loss
        with torch.no_grad():
            power_ts = [create_powerlabel(label.cpu().numpy(), self.mapping_dict, self.max_n_speaker).
                        to(encoder_out[0].device, non_blocking=True) for label in pit_speaker_labels]
        pad_attractors = [pad_attractor(att, self.max_n_speaker) for att in attractors]
        pse_speaker_logits = [torch.matmul(e, pad_att.permute(1, 0))
                              for e, pad_att in zip(encoder_out, pad_attractors)]
        pse_speaker_logits = self.forward_post_net(pse_speaker_logits, speech_lengths)
        pse_loss = cal_power_loss(pse_speaker_logits, power_ts)

        loss = pse_loss + pit_loss + self.attractor_loss_weight * attractor_loss

        stats = dict()
        stats["pse_loss"] = pse_loss.detach()
        stats["pit_loss"] = pit_loss.detach()
        stats["attractor_loss"] = attractor_loss.detach()
        stats["batch_size"] = batch_size
        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
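
    # Example training call (illustrative sketch; `model` and all inputs are
    # assumed, with 23-dim features to match WavFrontendMel23):
    #
    #     speech = [torch.randn(500, 23), torch.randn(420, 23)]   # (T_i, F)
    #     labels = [torch.randint(0, 2, (500, 2)).float(),
    #               torch.randint(0, 2, (420, 3)).float()]        # (T_i, C_i)
    #     orders = [np.random.permutation(500), np.random.permutation(420)]
    #     loss, stats, weight = model(speech, labels, orders)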

    def estimate_sequential(self,
                            speech: torch.Tensor,
                            n_speakers: int = None,
                            shuffle: bool = True,
                            threshold: float = 0.5,
                            **kwargs):
        speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
        emb = self.forward_encoder(speech, speech_lengths)
        if shuffle:
            orders = [np.arange(e.shape[0]) for e in emb]
            for order in orders:
                np.random.shuffle(order)
            attractors, probs = self.encoder_decoder_attractor.estimate(
                [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
        else:
            attractors, probs = self.encoder_decoder_attractor.estimate(emb)
        attractors_active = []
        for p, att, e in zip(probs, attractors, emb):
            if n_speakers and n_speakers >= 0:
                att = att[:n_speakers, ]
                attractors_active.append(att)
            elif threshold is not None:
                # Keep attractors up to the first one whose existence
                # probability drops below the threshold
                silence = torch.nonzero(p < threshold).squeeze(-1)
                n_spk = silence[0].item() if silence.numel() > 0 else None
                att = att[:n_spk, ]
                attractors_active.append(att)
            else:
                raise NotImplementedError('n_speakers or threshold has to be given.')
        raw_n_speakers = [att.shape[0] for att in attractors_active]
        attractors = [
            pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker]
            for att in attractors_active]
        ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
        logits = self.forward_post_net(ys, speech_lengths)
        ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker)
              for logit, raw_n_speaker in zip(logits, raw_n_speakers)]
        return ys, emb, attractors, raw_n_speakers
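
    # Example inference call (illustrative sketch; `model` and the input are
    # assumed):
    #
    #     with torch.no_grad():
    #         ys, emb, attractors, n_spk = model.estimate_sequential(
    #             [torch.randn(500, 23)], threshold=0.5)
    #     # ys[0] is a (500, n_spk[0]) tensor of binary speaker decisions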

    def recover_y_from_powerlabel(self, logit, n_speaker):
        pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)
        oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
        # Replace out-of-vocabulary predictions with the previous frame's
        # label (or silence for the first frame)
        for i in oov_index:
            if i > 0:
                pred[i] = pred[i - 1]
            else:
                pred[i] = 0
        pred = [self.inv_mapping_func(i) for i in pred]
        # Expand each power-set label into a per-speaker binary decision row
        decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
        decisions = torch.from_numpy(
            np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
            torch.float32)
        decisions = decisions[:, :n_speaker]
        return decisions
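
    # Worked decoding example: if inv_mapping_func yields num = 5, then
    # bin(5)[2:] = '101', zfill(8) = '00000101', reversed = '10100000',
    # i.e. speakers 0 and 2 are active in that frame (max_n_speaker = 8).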

    def inv_mapping_func(self, label):
        if not isinstance(label, int):
            label = int(label)
        if label in self.mapping_dict['label2dec']:
            num = self.mapping_dict['label2dec'][label]
        else:
            num = -1
        return num

    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        pass