#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import logging
import random
from contextlib import contextmanager
from distutils.version import LooseVersion
from itertools import permutations
from typing import Dict
from typing import Optional
from typing import Tuple, List, Union

import numpy as np
import torch
from torch.nn import functional as F

from funasr.models.transformer.utils.nets_utils import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.frontends.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.abs_profileaug import AbsProfileAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
from funasr.utils.misc import int2vec
from funasr.utils.hinter import hint_once

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # torch<1.6.0 has no native AMP; fall back to a no-op context manager.
    @contextmanager
    def autocast(enabled=True):
        yield


class DiarSondModel(FunASRModel):
    """Speaker overlap-aware neural diarization (SOND) model.

    Reference: https://arxiv.org/abs/2211.10243
    """

    def __init__(
        self,
        vocab_size: int,
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        profileaug: Optional[AbsProfileAug],
        normalize: Optional[AbsNormalize],
        encoder: torch.nn.Module,
        speaker_encoder: Optional[torch.nn.Module],
        ci_scorer: torch.nn.Module,
        cd_scorer: Optional[torch.nn.Module],
        decoder: torch.nn.Module,
        token_list: list,
        lsm_weight: float = 0.1,
        length_normalized_loss: bool = False,
        max_spk_num: int = 16,
        label_aggregator: Optional[torch.nn.Module] = None,
        normalize_speech_speaker: bool = False,
        ignore_id: int = -1,
        speaker_discrimination_loss_weight: float = 1.0,
        inter_score_loss_weight: float = 0.0,
        inputs_type: str = "raw",
        model_regularizer_weight: float = 0.0,
        freeze_encoder: bool = False,
        onfly_shuffle_speaker: bool = True,
    ):
        super().__init__()

        self.encoder = encoder
        self.speaker_encoder = speaker_encoder
        self.ci_scorer = ci_scorer
        self.cd_scorer = cd_scorer
        self.normalize = normalize
        self.frontend = frontend
        self.specaug = specaug
        self.profileaug = profileaug
        self.label_aggregator = label_aggregator
        self.decoder = decoder
        self.token_list = token_list
        self.max_spk_num = max_spk_num
        self.normalize_speech_speaker = normalize_speech_speaker
        self.ignore_id = ignore_id
        self.model_regularizer_weight = model_regularizer_weight
        self.freeze_encoder = freeze_encoder
        self.onfly_shuffle_speaker = onfly_shuffle_speaker

        self.criterion_diar = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
        self.pse_embedding = self.generate_pse_embedding()
        self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
        self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
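        # `power_weight` holds [1, 2, 4, ..., 2^(max_spk_num-1)] with shape
        # (1, 1, max_spk_num), so a binary speaker-activity vector collapses into a
        # single power-set-encoding (PSE) integer via a weighted sum; `int_token_arr`
        # holds the integer value of every PSE token for the reverse lookup.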
        self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
        self.inter_score_loss_weight = inter_score_loss_weight
        self.forward_steps = 0
        self.inputs_type = inputs_type
        self.to_regularize_parameters = None

    def get_regularize_parameters(self):
        to_regularize_parameters, normal_parameters = [], []
        for name, param in self.named_parameters():
            if ("encoder" in name and "weight" in name and "bn" not in name and
                    ("conv2" in name or "conv1" in name or "conv_sc" in name or "dense" in name)):
                to_regularize_parameters.append((name, param))
            else:
                normal_parameters.append((name, param))
        self.to_regularize_parameters = to_regularize_parameters
        return to_regularize_parameters, normal_parameters

    def generate_pse_embedding(self):
        embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float32)
        for idx, pse_label in enumerate(self.token_list):
            emb = int2vec(int(pse_label), vec_dim=self.max_spk_num, dtype=np.float32)
            embedding[idx] = emb
        return torch.from_numpy(embedding)
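    # Worked example (assuming int2vec emits bits least-significant first, which
    # makes it the inverse of the power_weight sum above): with max_spk_num=4, the
    # PSE token "5" (binary 0101) maps to the activity vector [1, 0, 1, 0], i.e.
    # speakers 0 and 2 are active. generate_pse_embedding thus builds a
    # (vocab_size, max_spk_num) lookup table from PSE token index to multi-label vector.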

    def rand_permute_speaker(self, raw_profile, raw_binary_labels):
        """Randomly permute the speaker order, consistently for profiles and labels.

        raw_profile: (B, N, D)
        raw_binary_labels: (B, T, N)
        """
        assert raw_profile.shape[1] == raw_binary_labels.shape[2], \
            "Num profile: {}, Num label: {}".format(raw_profile.shape[1], raw_binary_labels.shape[2])
        profile = torch.clone(raw_profile)
        binary_labels = torch.clone(raw_binary_labels)
        bsz, num_spk = profile.shape[0], profile.shape[1]
        for i in range(bsz):
            idx = list(range(num_spk))
            random.shuffle(idx)
            profile[i] = profile[i][idx, :]
            binary_labels[i] = binary_labels[i][:, idx]
        return profile, binary_labels
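    # Example: if idx comes out as [2, 0, 1] for a 3-speaker sample, profile row 0
    # becomes old row 2 and label column 0 becomes old column 2, so each profile
    # stays aligned with its own activity track; only the speaker ordering changes.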

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor = None,
        profile: torch.Tensor = None,
        profile_lengths: torch.Tensor = None,
        binary_labels: torch.Tensor = None,
        binary_labels_lengths: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss

        Args:
            speech: (Batch, samples) or (Batch, frames, input_size)
            speech_lengths: (Batch,). Defaults to None for the chunk iterator,
                because the chunk iterator does not return speech_lengths;
                see espnet2/iterators/chunk_iter_factory.py.
            profile: (Batch, N_spk, dim)
            profile_lengths: (Batch,)
            binary_labels: (Batch, frames, max_spk_num)
            binary_labels_lengths: (Batch,)
        """
        assert speech.shape[0] <= binary_labels.shape[0], (speech.shape, binary_labels.shape)
        batch_size = speech.shape[0]

        if self.freeze_encoder:
            hint_once("Freeze encoder", "freeze_encoder", rank=0)
            self.encoder.eval()
        self.forward_steps = self.forward_steps + 1

        if self.pse_embedding.device != speech.device:
            self.pse_embedding = self.pse_embedding.to(speech.device)
            self.power_weight = self.power_weight.to(speech.device)
            self.int_token_arr = self.int_token_arr.to(speech.device)

        if self.onfly_shuffle_speaker:
            hint_once("On-the-fly shuffle speaker permutation.", "onfly_shuffle_speaker", rank=0)
            profile, binary_labels = self.rand_permute_speaker(profile, binary_labels)

        # 0a. Aggregate time-domain labels to match forward outputs
        if self.label_aggregator is not None:
            binary_labels, binary_labels_lengths = self.label_aggregator(
                binary_labels, binary_labels_lengths
            )

        # 0b. Augment profiles
        if self.profileaug is not None and self.training:
            speech, profile, binary_labels = self.profileaug(
                speech, speech_lengths,
                profile, profile_lengths,
                binary_labels, binary_labels_lengths
            )

        # 1. Calculate power-set encoding (PSE) labels
        pad_bin_labels = F.pad(binary_labels, (0, self.max_spk_num - binary_labels.shape[2]), "constant", 0.0)
        raw_pse_labels = torch.sum(pad_bin_labels * self.power_weight, dim=2, keepdim=True)
        pse_labels = torch.argmax((raw_pse_labels.int() == self.int_token_arr).float(), dim=2)
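        # Worked example of the three lines above: a frame with binary activity
        # [1, 0, 1, 0, ...] sums to 1*1 + 0*2 + 1*4 = 5 under power_weight, and the
        # argmax over (raw_pse_labels == int_token_arr) recovers the index of the
        # token "5" in token_list. Each frame thereby gets a single class id over
        # the power set of speakers, which is what criterion_diar is trained on.
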
        # 2. Network forward
        pred, inter_outputs = self.prediction_forward(
            speech, speech_lengths,
            profile, profile_lengths,
            return_inter_outputs=True
        )
        (speech, speech_lengths), (profile, profile_lengths), (ci_score, cd_score) = inter_outputs

        # If the encoder uses conv* as input_layer (i.e., subsampling),
        # the sequence length of `pred` might be slightly less than the
        # length of `pse_labels`. Here we force them to be equal.
        length_diff_tolerance = 2
        length_diff = abs(pse_labels.shape[1] - pred.shape[1])
        if length_diff <= length_diff_tolerance:
            min_len = min(pred.shape[1], pse_labels.shape[1])
            pse_labels = pse_labels[:, :min_len]
            pred = pred[:, :min_len]
            cd_score = cd_score[:, :min_len]
            ci_score = ci_score[:, :min_len]

        loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths)
        loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
        loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, pse_labels, binary_labels_lengths)
        regularizer_loss = None
        if self.model_regularizer_weight > 0 and self.to_regularize_parameters is not None:
            regularizer_loss = self.calculate_regularizer_loss()
        label_mask = make_pad_mask(binary_labels_lengths, maxlen=pse_labels.shape[1]).to(pse_labels.device)
        loss = (loss_diar + self.speaker_discrimination_loss_weight * loss_spk_dis
                + self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd))
        # if regularizer_loss is not None:
        #     loss = loss + regularizer_loss * self.model_regularizer_weight

        (
            correct,
            num_frames,
            speech_scored,
            speech_miss,
            speech_falarm,
            speaker_scored,
            speaker_miss,
            speaker_falarm,
            speaker_error,
        ) = self.calc_diarization_error(
            pred=F.embedding(pred.argmax(dim=2) * (~label_mask), self.pse_embedding),
            label=F.embedding(pse_labels * (~label_mask), self.pse_embedding),
            length=binary_labels_lengths
        )

        if speech_scored > 0 and num_frames > 0:
            sad_mr, sad_fr, mi, fa, cf, acc, der = (
                speech_miss / speech_scored,
                speech_falarm / speech_scored,
                speaker_miss / speaker_scored,
                speaker_falarm / speaker_scored,
                speaker_error / speaker_scored,
                correct / num_frames,
                (speaker_miss + speaker_falarm + speaker_error) / speaker_scored,
            )
        else:
            sad_mr, sad_fr, mi, fa, cf, acc, der = 0, 0, 0, 0, 0, 0, 0

        stats = dict(
            loss=loss.detach(),
            loss_diar=loss_diar.detach() if loss_diar is not None else None,
            loss_spk_dis=loss_spk_dis.detach() if loss_spk_dis is not None else None,
            loss_inter_ci=loss_inter_ci.detach() if loss_inter_ci is not None else None,
            loss_inter_cd=loss_inter_cd.detach() if loss_inter_cd is not None else None,
            regularizer_loss=regularizer_loss.detach() if regularizer_loss is not None else None,
            sad_mr=sad_mr,
            sad_fr=sad_fr,
            mi=mi,
            fa=fa,
            cf=cf,
            acc=acc,
            der=der,
            forward_steps=self.forward_steps,
        )
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def calculate_regularizer_loss(self):
        # Sum of L2 norms over the selected encoder weights (weight-decay style).
        regularizer_loss = 0.0
        for name, param in self.to_regularize_parameters:
            regularizer_loss = regularizer_loss + torch.norm(param, p=2)
        return regularizer_loss

    def classification_loss(
        self,
        predictions: torch.Tensor,
        labels: torch.Tensor,
        prediction_lengths: torch.Tensor
    ) -> torch.Tensor:
        mask = make_pad_mask(prediction_lengths, maxlen=labels.shape[1])
        pad_labels = labels.masked_fill(
            mask.to(predictions.device),
            value=self.ignore_id
        )
        loss = self.criterion_diar(predictions.contiguous(), pad_labels)
        return loss
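    # Padded frames are overwritten with ignore_id, the same value criterion_diar
    # was constructed with as padding_idx, so they are excluded from the loss.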

    def speaker_discrimination_loss(
        self,
        profile: torch.Tensor,
        profile_lengths: torch.Tensor
    ) -> torch.Tensor:
        profile_mask = (torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0).float()  # (B, N, 1)
        mask = torch.matmul(profile_mask, profile_mask.transpose(1, 2))  # (B, N, N)
        # Exclude self-similarity on the diagonal.
        mask = mask * (1.0 - torch.eye(self.max_spk_num).unsqueeze(0).to(mask))
        eps = 1e-12
        coding_norm = torch.linalg.norm(
            profile * profile_mask + (1 - profile_mask) * eps,
            dim=2, keepdim=True
        ) * profile_mask
        # profile: (Batch, N, dim)
        cos_theta = F.cosine_similarity(profile.unsqueeze(2), profile.unsqueeze(1), dim=-1, eps=eps) * mask
        cos_theta = torch.clip(cos_theta, -1 + eps, 1 - eps)
        # clamp_min guards the corner case of at most one valid profile per sample,
        # where mask.sum() would otherwise be zero and the division would yield NaN.
        loss = (F.relu(mask * coding_norm * (cos_theta - 0.0))).sum() / mask.sum().clamp_min(1.0)
        return loss
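    # Intuition: for every pair of distinct, non-empty profiles, any positive
    # cosine similarity is penalized (F.relu keeps only cos_theta > 0), scaled by
    # the profile norm. E.g. two identical unit-norm profiles contribute ~1.0 each
    # way, while orthogonal or anti-correlated profiles contribute 0, pushing
    # speaker embeddings apart.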

    def calculate_multi_labels(self, pse_labels, pse_labels_lengths):
        mask = make_pad_mask(pse_labels_lengths, maxlen=pse_labels.shape[1])
        padding_labels = pse_labels.masked_fill(
            mask.to(pse_labels.device),
            value=0
        ).to(pse_labels)
        multi_labels = F.embedding(padding_labels, self.pse_embedding)
        return multi_labels

    def internal_score_loss(
        self,
        cd_score: torch.Tensor,
        ci_score: torch.Tensor,
        pse_labels: torch.Tensor,
        pse_labels_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        multi_labels = self.calculate_multi_labels(pse_labels, pse_labels_lengths)
        ci_loss = self.criterion_bce(ci_score, multi_labels, pse_labels_lengths)
        cd_loss = self.criterion_bce(cd_score, multi_labels, pse_labels_lengths)
        return ci_loss, cd_loss

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        profile: torch.Tensor = None,
        profile_lengths: torch.Tensor = None,
        binary_labels: torch.Tensor = None,
        binary_labels_lengths: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode_speaker(
        self,
        profile: torch.Tensor,
        profile_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        with autocast(False):
            if profile.shape[1] < self.max_spk_num:
                profile = F.pad(profile, [0, 0, 0, self.max_spk_num - profile.shape[1], 0, 0], "constant", 0.0)
            profile_mask = (torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0).float()
            profile = F.normalize(profile, dim=2)
            if self.speaker_encoder is not None:
                profile = self.speaker_encoder(profile, profile_lengths)[0]
                return profile * profile_mask, profile_lengths
            else:
                return profile, profile_lengths
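    # The F.pad spec above pads dimensions from the last one backwards: (0, 0)
    # leaves the embedding dim alone, (0, max_spk_num - N) appends all-zero speaker
    # slots, and (0, 0) leaves the batch dim alone; e.g. a (B, 3, D) profile
    # becomes (B, max_spk_num, D) with zero rows that profile_mask marks invalid.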

    def encode_speech(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.encoder is not None and self.inputs_type == "raw":
            speech, speech_lengths = self.encode(speech, speech_lengths)
            speech_mask = ~make_pad_mask(speech_lengths, maxlen=speech.shape[1])
            speech_mask = speech_mask.to(speech.device).unsqueeze(-1).float()
            return speech * speech_mask, speech_lengths
        else:
            return speech, speech_lengths

    @staticmethod
    def concate_speech_ivc(
        speech: torch.Tensor,
        ivc: torch.Tensor
    ) -> torch.Tensor:
        nn, tt = ivc.shape[1], speech.shape[1]
        speech = speech.unsqueeze(dim=1)         # B x 1 x T x D_sph
        speech = speech.expand(-1, nn, -1, -1)   # B x N x T x D_sph
        ivc = ivc.unsqueeze(dim=2)               # B x N x 1 x D_spk
        ivc = ivc.expand(-1, -1, tt, -1)         # B x N x T x D_spk
        sd_in = torch.cat([speech, ivc], dim=3)  # B x N x T x (D_sph + D_spk)
        return sd_in
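    # Shape walkthrough (illustrative dims, not from the original file): with B=2,
    # N=16 speakers, T=100 frames, D_sph=D_spk=256, the speech stream is tiled to
    # (2, 16, 100, 256), the profile stream to the same shape, and concatenation
    # yields (2, 16, 100, 512): one speech-plus-target-speaker sequence per
    # candidate speaker.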

    def calc_similarity(
        self,
        speech_encoder_outputs: torch.Tensor,
        speaker_encoder_outputs: torch.Tensor,
        seq_len: torch.Tensor = None,
        spk_len: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        bb, tt = speech_encoder_outputs.shape[0], speech_encoder_outputs.shape[1]
        d_sph, d_spk = speech_encoder_outputs.shape[2], speaker_encoder_outputs.shape[2]
        if self.normalize_speech_speaker:
            speech_encoder_outputs = F.normalize(speech_encoder_outputs, dim=2)
            speaker_encoder_outputs = F.normalize(speaker_encoder_outputs, dim=2)
        ge_in = self.concate_speech_ivc(speech_encoder_outputs, speaker_encoder_outputs)
        ge_in = torch.reshape(ge_in, [bb * self.max_spk_num, tt, d_sph + d_spk])
        ge_len = seq_len.unsqueeze(1).expand(-1, self.max_spk_num)
        ge_len = torch.reshape(ge_len, [bb * self.max_spk_num])
        cd_simi = self.cd_scorer(ge_in, ge_len)[0]
        cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1])
        cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1])
        if isinstance(self.ci_scorer, AbsEncoder):
            ci_simi = self.ci_scorer(ge_in, ge_len)[0]
            ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute([0, 2, 1])
        else:
            ci_simi = self.ci_scorer(speech_encoder_outputs, speaker_encoder_outputs)
        return ci_simi, cd_simi
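    # Both scores come back as (B, T, max_spk_num): the context-dependent (CD)
    # scorer sees each speaker's concatenated speech+profile sequence, while the
    # context-independent (CI) scorer is either run the same way (if it is an
    # AbsEncoder) or applied directly as a frame-vs-profile similarity function.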

    def post_net_forward(self, simi, seq_len):
        logits = self.decoder(simi, seq_len)[0]
        return logits

    def prediction_forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        profile: torch.Tensor,
        profile_lengths: torch.Tensor,
        return_inter_outputs: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
        # speech encoding
        speech, speech_lengths = self.encode_speech(speech, speech_lengths)
        # speaker encoding
        profile, profile_lengths = self.encode_speaker(profile, profile_lengths)
        # calculating similarity
        ci_simi, cd_simi = self.calc_similarity(speech, profile, speech_lengths, profile_lengths)
        similarity = torch.cat([cd_simi, ci_simi], dim=2)
        # post net forward
        logits = self.post_net_forward(similarity, speech_lengths)
        if return_inter_outputs:
            return logits, [(speech, speech_lengths), (profile, profile_lengths), (ci_simi, cd_simi)]
        return logits
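    # Pipeline summary: speech encoder -> speaker encoder -> CI/CD scorers ->
    # decoder ("post net"). The decoder input is (B, T, 2 * max_spk_num) because
    # the CD and CI scores are concatenated along the speaker axis, and its output
    # is a per-frame distribution over the PSE vocabulary.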

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim) -> encoder_out: (Batch, Length2, Dim)
        encoder_outputs = self.encoder(feats, feats_lengths)
        encoder_out, encoder_out_lens = encoder_outputs[:2]

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )
        return encoder_out, encoder_out_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = speech.shape[0]
        speech_lengths = (
            speech_lengths
            if speech_lengths is not None
            else torch.ones(batch_size).int() * speech.shape[1]
        )
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend, e.g. STFT and feature extraction.
            # The data loader may send a time-domain signal in this case:
            # speech (Batch, NSamples) -> feats (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extraction.
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    @staticmethod
    def calc_diarization_error(pred, label, length):
        # Note (jiatong): Credit to https://github.com/hitachi-speech/EEND
        (batch_size, max_len, num_output) = label.size()
        # mask the padding part (move to CPU first in case the mask is built on GPU)
        mask = ~make_pad_mask(length, maxlen=label.shape[1]).unsqueeze(-1).cpu().numpy()
        # pred and label have the shape (batch_size, max_len, num_output)
        label_np = label.data.cpu().numpy().astype(int)
        pred_np = (pred.data.cpu().numpy() > 0).astype(int)
        label_np = label_np * mask
        pred_np = pred_np * mask
        length = length.data.cpu().numpy()

        # compute speech activity detection error
        n_ref = np.sum(label_np, axis=2)
        n_sys = np.sum(pred_np, axis=2)
        speech_scored = float(np.sum(n_ref > 0))
        speech_miss = float(np.sum(np.logical_and(n_ref > 0, n_sys == 0)))
        speech_falarm = float(np.sum(np.logical_and(n_ref == 0, n_sys > 0)))

        # compute speaker diarization error
        speaker_scored = float(np.sum(n_ref))
        speaker_miss = float(np.sum(np.maximum(n_ref - n_sys, 0)))
        speaker_falarm = float(np.sum(np.maximum(n_sys - n_ref, 0)))
        n_map = np.sum(np.logical_and(label_np == 1, pred_np == 1), axis=2)
        speaker_error = float(np.sum(np.minimum(n_ref, n_sys) - n_map))
        correct = float(1.0 * np.sum((label_np == pred_np) * mask) / num_output)
        num_frames = np.sum(length)
        return (
            correct,
            num_frames,
            speech_scored,
            speech_miss,
            speech_falarm,
            speaker_scored,
            speaker_miss,
            speaker_falarm,
            speaker_error,
        )
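

if __name__ == "__main__":
    # Minimal sanity check for calc_diarization_error, a sketch added for
    # illustration (the toy tensors below are not part of the original file).
    # Scores > 0 count as "active"; this prediction matches the labels exactly.
    toy_pred = torch.tensor([[[0.9, -0.5], [0.8, 0.7], [-0.2, -0.3]]])  # (B=1, T=3, N=2) scores
    toy_label = torch.tensor([[[1.0, 0.0], [1.0, 1.0], [0.0, 0.0]]])    # binary activity labels
    toy_length = torch.tensor([3])                                      # all 3 frames are valid
    correct, num_frames, *err_stats = DiarSondModel.calc_diarization_error(
        toy_pred, toy_label, toy_length
    )
    # Expect num_frames == 3 and zero miss / false alarm / confusion counts.
    print(correct, num_frames, err_stats)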