# speaker_utils.py
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. """ Some implementations are adapted from https://github.com/yuyq96/D-TDNN
  3. """
  4. import torch
  5. import torch.nn.functional as F
  6. import torch.utils.checkpoint as cp
  7. from torch import nn
  8. import io
  9. import os
  10. from typing import Any, Dict, List, Union
  11. import numpy as np
  12. import librosa as sf
  13. import torch
  14. import torchaudio
  15. import logging
  16. from funasr.utils.modelscope_file import File
  17. from collections import OrderedDict
  18. import torchaudio.compliance.kaldi as Kaldi
  19. def check_audio_list(audio: list):
  20. audio_dur = 0
  21. for i in range(len(audio)):
  22. seg = audio[i]
  23. assert seg[1] >= seg[0], 'modelscope error: Wrong time stamps.'
  24. assert isinstance(seg[2], np.ndarray), 'modelscope error: Wrong data type.'
  25. assert int(seg[1] * 16000) - int(
  26. seg[0] * 16000
  27. ) == seg[2].shape[
  28. 0], 'modelscope error: audio data in list is inconsistent with time length.'
  29. if i > 0:
  30. assert seg[0] >= audio[
  31. i - 1][1], 'modelscope error: Wrong time stamps.'
  32. audio_dur += seg[1] - seg[0]
  33. assert audio_dur > 5, 'modelscope error: The effective audio duration is too short.'
  34. def sv_preprocess(inputs: Union[np.ndarray, list]):
  35. output = []
  36. for i in range(len(inputs)):
  37. if isinstance(inputs[i], str):
  38. file_bytes = File.read(inputs[i])
  39. data, fs = sf.load(io.BytesIO(file_bytes), dtype='float32')
  40. if len(data.shape) == 2:
  41. data = data[:, 0]
  42. data = torch.from_numpy(data).unsqueeze(0)
  43. data = data.squeeze(0)
  44. elif isinstance(inputs[i], np.ndarray):
  45. assert len(
  46. inputs[i].shape
  47. ) == 1, 'modelscope error: Input array should be [N, T]'
  48. data = inputs[i]
  49. if data.dtype in ['int16', 'int32', 'int64']:
  50. data = (data / (1 << 15)).astype('float32')
  51. else:
  52. data = data.astype('float32')
  53. data = torch.from_numpy(data)
  54. else:
  55. raise ValueError(
  56. 'modelscope error: The input type is restricted to audio address and nump array.'
  57. )
  58. output.append(data)
  59. return output
  60. def sv_chunk(vad_segments: list, fs = 16000) -> list:
  61. config = {
  62. 'seg_dur': 1.5,
  63. 'seg_shift': 0.75,
  64. }
  65. def seg_chunk(seg_data):
  66. seg_st = seg_data[0]
  67. data = seg_data[2]
  68. chunk_len = int(config['seg_dur'] * fs)
  69. chunk_shift = int(config['seg_shift'] * fs)
  70. last_chunk_ed = 0
  71. seg_res = []
  72. for chunk_st in range(0, data.shape[0], chunk_shift):
  73. chunk_ed = min(chunk_st + chunk_len, data.shape[0])
  74. if chunk_ed <= last_chunk_ed:
  75. break
  76. last_chunk_ed = chunk_ed
  77. chunk_st = max(0, chunk_ed - chunk_len)
  78. chunk_data = data[chunk_st:chunk_ed]
  79. if chunk_data.shape[0] < chunk_len:
  80. chunk_data = np.pad(chunk_data,
  81. (0, chunk_len - chunk_data.shape[0]),
  82. 'constant')
  83. seg_res.append([
  84. chunk_st / fs + seg_st, chunk_ed / fs + seg_st,
  85. chunk_data
  86. ])
  87. return seg_res
  88. segs = []
  89. for i, s in enumerate(vad_segments):
  90. segs.extend(seg_chunk(s))
  91. return segs
  92. class BasicResBlock(nn.Module):
  93. expansion = 1
  94. def __init__(self, in_planes, planes, stride=1):
  95. super(BasicResBlock, self).__init__()
  96. self.conv1 = nn.Conv2d(
  97. in_planes,
  98. planes,
  99. kernel_size=3,
  100. stride=(stride, 1),
  101. padding=1,
  102. bias=False)
  103. self.bn1 = nn.BatchNorm2d(planes)
  104. self.conv2 = nn.Conv2d(
  105. planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
  106. self.bn2 = nn.BatchNorm2d(planes)
  107. self.shortcut = nn.Sequential()
  108. if stride != 1 or in_planes != self.expansion * planes:
  109. self.shortcut = nn.Sequential(
  110. nn.Conv2d(
  111. in_planes,
  112. self.expansion * planes,
  113. kernel_size=1,
  114. stride=(stride, 1),
  115. bias=False), nn.BatchNorm2d(self.expansion * planes))
  116. def forward(self, x):
  117. out = F.relu(self.bn1(self.conv1(x)))
  118. out = self.bn2(self.conv2(out))
  119. out += self.shortcut(x)
  120. out = F.relu(out)
  121. return out
  122. class FCM(nn.Module):
  123. def __init__(self,
  124. block=BasicResBlock,
  125. num_blocks=[2, 2],
  126. m_channels=32,
  127. feat_dim=80):
  128. super(FCM, self).__init__()
  129. self.in_planes = m_channels
  130. self.conv1 = nn.Conv2d(
  131. 1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
  132. self.bn1 = nn.BatchNorm2d(m_channels)
  133. self.layer1 = self._make_layer(
  134. block, m_channels, num_blocks[0], stride=2)
  135. self.layer2 = self._make_layer(
  136. block, m_channels, num_blocks[0], stride=2)
  137. self.conv2 = nn.Conv2d(
  138. m_channels,
  139. m_channels,
  140. kernel_size=3,
  141. stride=(2, 1),
  142. padding=1,
  143. bias=False)
  144. self.bn2 = nn.BatchNorm2d(m_channels)
  145. self.out_channels = m_channels * (feat_dim // 8)
  146. def _make_layer(self, block, planes, num_blocks, stride):
  147. strides = [stride] + [1] * (num_blocks - 1)
  148. layers = []
  149. for stride in strides:
  150. layers.append(block(self.in_planes, planes, stride))
  151. self.in_planes = planes * block.expansion
  152. return nn.Sequential(*layers)
  153. def forward(self, x):
  154. x = x.unsqueeze(1)
  155. out = F.relu(self.bn1(self.conv1(x)))
  156. out = self.layer1(out)
  157. out = self.layer2(out)
  158. out = F.relu(self.bn2(self.conv2(out)))
  159. shape = out.shape
  160. out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
  161. return out
class CAMPPlus(nn.Module):
    """CAM++ speaker-embedding backbone.

    Pipeline: FCM 2-D conv front-end -> initial TDNN layer (stride 2) ->
    three CAM dense TDNN blocks, each followed by a TransitLayer that
    halves the channel count -> output nonlinearity -> and, for
    output_level='segment', statistics pooling plus a dense projection to
    a fixed-size embedding.

    Args:
        feat_dim: number of input feature bins (mel filterbanks).
        embedding_size: dimension of the segment-level embedding.
        growth_rate: channels added by each dense TDNN layer.
        bn_channels bottleneck multiplier: bn_size * growth_rate.
        init_channels: channel count after the initial TDNN layer.
        config_str: nonlinearity spec consumed by get_nonlinear.
        memory_efficient: enable gradient checkpointing in dense layers.
        output_level: 'segment' (pooled embedding) or 'frame' (per-frame
            features); anything else fails the assertion below.
    """

    def __init__(self,
                 feat_dim=80,
                 embedding_size=192,
                 growth_rate=32,
                 bn_size=4,
                 init_channels=128,
                 config_str='batchnorm-relu',
                 memory_efficient=True,
                 output_level='segment'):
        super(CAMPPlus, self).__init__()

        self.head = FCM(feat_dim=feat_dim)
        channels = self.head.out_channels
        self.output_level = output_level

        self.xvector = nn.Sequential(
            OrderedDict([
                ('tdnn',
                 TDNNLayer(
                     channels,
                     init_channels,
                     5,
                     stride=2,
                     dilation=1,
                     # Negative padding asks TDNNLayer for 'same'-style padding.
                     padding=-1,
                     config_str=config_str)),
            ]))
        channels = init_channels
        # Three dense blocks configured by (num_layers, kernel_size, dilation);
        # each block grows channels by num_layers * growth_rate, then the
        # transit layer halves them. Module names ('block%d', 'transit%d')
        # must stay stable for checkpoint compatibility.
        for i, (num_layers, kernel_size, dilation) in enumerate(
                zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
            block = CAMDenseTDNNBlock(
                num_layers=num_layers,
                in_channels=channels,
                out_channels=growth_rate,
                bn_channels=bn_size * growth_rate,
                kernel_size=kernel_size,
                dilation=dilation,
                config_str=config_str,
                memory_efficient=memory_efficient)
            self.xvector.add_module('block%d' % (i + 1), block)
            channels = channels + num_layers * growth_rate
            self.xvector.add_module(
                'transit%d' % (i + 1),
                TransitLayer(
                    channels, channels // 2, bias=False,
                    config_str=config_str))
            channels //= 2

        self.xvector.add_module('out_nonlinear',
                                get_nonlinear(config_str, channels))

        if self.output_level == 'segment':
            # Pool mean+std over time (doubling channels), then project.
            self.xvector.add_module('stats', StatsPool())
            self.xvector.add_module(
                'dense',
                DenseLayer(
                    channels * 2, embedding_size, config_str='batchnorm_'))
        else:
            assert self.output_level == 'frame', '`output_level` should be set to \'segment\' or \'frame\'. '

        # Kaiming init for all conv/linear weights; zero the biases.
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = self.head(x)
        x = self.xvector(x)
        if self.output_level == 'frame':
            # Return (B, T', C) for frame-level consumers.
            x = x.transpose(1, 2)
        return x
  230. def get_nonlinear(config_str, channels):
  231. nonlinear = nn.Sequential()
  232. for name in config_str.split('-'):
  233. if name == 'relu':
  234. nonlinear.add_module('relu', nn.ReLU(inplace=True))
  235. elif name == 'prelu':
  236. nonlinear.add_module('prelu', nn.PReLU(channels))
  237. elif name == 'batchnorm':
  238. nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels))
  239. elif name == 'batchnorm_':
  240. nonlinear.add_module('batchnorm',
  241. nn.BatchNorm1d(channels, affine=False))
  242. else:
  243. raise ValueError('Unexpected module ({}).'.format(name))
  244. return nonlinear
  245. def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
  246. mean = x.mean(dim=dim)
  247. std = x.std(dim=dim, unbiased=unbiased)
  248. stats = torch.cat([mean, std], dim=-1)
  249. if keepdim:
  250. stats = stats.unsqueeze(dim=dim)
  251. return stats
  252. class StatsPool(nn.Module):
  253. def forward(self, x):
  254. return statistics_pooling(x)
  255. class TDNNLayer(nn.Module):
  256. def __init__(self,
  257. in_channels,
  258. out_channels,
  259. kernel_size,
  260. stride=1,
  261. padding=0,
  262. dilation=1,
  263. bias=False,
  264. config_str='batchnorm-relu'):
  265. super(TDNNLayer, self).__init__()
  266. if padding < 0:
  267. assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
  268. kernel_size)
  269. padding = (kernel_size - 1) // 2 * dilation
  270. self.linear = nn.Conv1d(
  271. in_channels,
  272. out_channels,
  273. kernel_size,
  274. stride=stride,
  275. padding=padding,
  276. dilation=dilation,
  277. bias=bias)
  278. self.nonlinear = get_nonlinear(config_str, out_channels)
  279. def forward(self, x):
  280. x = self.linear(x)
  281. x = self.nonlinear(x)
  282. return x
  283. def extract_feature(audio):
  284. features = []
  285. for au in audio:
  286. feature = Kaldi.fbank(
  287. au.unsqueeze(0), num_mel_bins=80)
  288. feature = feature - feature.mean(dim=0, keepdim=True)
  289. features.append(feature.unsqueeze(0))
  290. features = torch.cat(features)
  291. return features
  292. class CAMLayer(nn.Module):
  293. def __init__(self,
  294. bn_channels,
  295. out_channels,
  296. kernel_size,
  297. stride,
  298. padding,
  299. dilation,
  300. bias,
  301. reduction=2):
  302. super(CAMLayer, self).__init__()
  303. self.linear_local = nn.Conv1d(
  304. bn_channels,
  305. out_channels,
  306. kernel_size,
  307. stride=stride,
  308. padding=padding,
  309. dilation=dilation,
  310. bias=bias)
  311. self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
  312. self.relu = nn.ReLU(inplace=True)
  313. self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
  314. self.sigmoid = nn.Sigmoid()
  315. def forward(self, x):
  316. y = self.linear_local(x)
  317. context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
  318. context = self.relu(self.linear1(context))
  319. m = self.sigmoid(self.linear2(context))
  320. return y * m
  321. def seg_pooling(self, x, seg_len=100, stype='avg'):
  322. if stype == 'avg':
  323. seg = F.avg_pool1d(
  324. x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
  325. elif stype == 'max':
  326. seg = F.max_pool1d(
  327. x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
  328. else:
  329. raise ValueError('Wrong segment pooling type.')
  330. shape = seg.shape
  331. seg = seg.unsqueeze(-1).expand(*shape,
  332. seg_len).reshape(*shape[:-1], -1)
  333. seg = seg[..., :x.shape[-1]]
  334. return seg
  335. class CAMDenseTDNNLayer(nn.Module):
  336. def __init__(self,
  337. in_channels,
  338. out_channels,
  339. bn_channels,
  340. kernel_size,
  341. stride=1,
  342. dilation=1,
  343. bias=False,
  344. config_str='batchnorm-relu',
  345. memory_efficient=False):
  346. super(CAMDenseTDNNLayer, self).__init__()
  347. assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
  348. kernel_size)
  349. padding = (kernel_size - 1) // 2 * dilation
  350. self.memory_efficient = memory_efficient
  351. self.nonlinear1 = get_nonlinear(config_str, in_channels)
  352. self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
  353. self.nonlinear2 = get_nonlinear(config_str, bn_channels)
  354. self.cam_layer = CAMLayer(
  355. bn_channels,
  356. out_channels,
  357. kernel_size,
  358. stride=stride,
  359. padding=padding,
  360. dilation=dilation,
  361. bias=bias)
  362. def bn_function(self, x):
  363. return self.linear1(self.nonlinear1(x))
  364. def forward(self, x):
  365. if self.training and self.memory_efficient:
  366. x = cp.checkpoint(self.bn_function, x)
  367. else:
  368. x = self.bn_function(x)
  369. x = self.cam_layer(self.nonlinear2(x))
  370. return x
  371. class CAMDenseTDNNBlock(nn.ModuleList):
  372. def __init__(self,
  373. num_layers,
  374. in_channels,
  375. out_channels,
  376. bn_channels,
  377. kernel_size,
  378. stride=1,
  379. dilation=1,
  380. bias=False,
  381. config_str='batchnorm-relu',
  382. memory_efficient=False):
  383. super(CAMDenseTDNNBlock, self).__init__()
  384. for i in range(num_layers):
  385. layer = CAMDenseTDNNLayer(
  386. in_channels=in_channels + i * out_channels,
  387. out_channels=out_channels,
  388. bn_channels=bn_channels,
  389. kernel_size=kernel_size,
  390. stride=stride,
  391. dilation=dilation,
  392. bias=bias,
  393. config_str=config_str,
  394. memory_efficient=memory_efficient)
  395. self.add_module('tdnnd%d' % (i + 1), layer)
  396. def forward(self, x):
  397. for layer in self:
  398. x = torch.cat([x, layer(x)], dim=1)
  399. return x
  400. class TransitLayer(nn.Module):
  401. def __init__(self,
  402. in_channels,
  403. out_channels,
  404. bias=True,
  405. config_str='batchnorm-relu'):
  406. super(TransitLayer, self).__init__()
  407. self.nonlinear = get_nonlinear(config_str, in_channels)
  408. self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
  409. def forward(self, x):
  410. x = self.nonlinear(x)
  411. x = self.linear(x)
  412. return x
  413. class DenseLayer(nn.Module):
  414. def __init__(self,
  415. in_channels,
  416. out_channels,
  417. bias=False,
  418. config_str='batchnorm-relu'):
  419. super(DenseLayer, self).__init__()
  420. self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
  421. self.nonlinear = get_nonlinear(config_str, out_channels)
  422. def forward(self, x):
  423. if len(x.shape) == 2:
  424. x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
  425. else:
  426. x = self.linear(x)
  427. x = self.nonlinear(x)
  428. return x
  429. def postprocess(segments: list, vad_segments: list,
  430. labels: np.ndarray, embeddings: np.ndarray) -> list:
  431. assert len(segments) == len(labels)
  432. labels = correct_labels(labels)
  433. distribute_res = []
  434. for i in range(len(segments)):
  435. distribute_res.append([segments[i][0], segments[i][1], labels[i]])
  436. # merge the same speakers chronologically
  437. distribute_res = merge_seque(distribute_res)
  438. # accquire speaker center
  439. spk_embs = []
  440. for i in range(labels.max() + 1):
  441. spk_emb = embeddings[labels == i].mean(0)
  442. spk_embs.append(spk_emb)
  443. spk_embs = np.stack(spk_embs)
  444. def is_overlapped(t1, t2):
  445. if t1 > t2 + 1e-4:
  446. return True
  447. return False
  448. # distribute the overlap region
  449. for i in range(1, len(distribute_res)):
  450. if is_overlapped(distribute_res[i - 1][1], distribute_res[i][0]):
  451. p = (distribute_res[i][0] + distribute_res[i - 1][1]) / 2
  452. distribute_res[i][0] = p
  453. distribute_res[i - 1][1] = p
  454. # smooth the result
  455. distribute_res = smooth(distribute_res)
  456. return distribute_res
  457. def correct_labels(labels):
  458. labels_id = 0
  459. id2id = {}
  460. new_labels = []
  461. for i in labels:
  462. if i not in id2id:
  463. id2id[i] = labels_id
  464. labels_id += 1
  465. new_labels.append(id2id[i])
  466. return np.array(new_labels)
  467. def merge_seque(distribute_res):
  468. res = [distribute_res[0]]
  469. for i in range(1, len(distribute_res)):
  470. if distribute_res[i][2] != res[-1][2] or distribute_res[i][
  471. 0] > res[-1][1]:
  472. res.append(distribute_res[i])
  473. else:
  474. res[-1][1] = distribute_res[i][1]
  475. return res
  476. def smooth(res, mindur=1):
  477. # short segments are assigned to nearest speakers.
  478. for i in range(len(res)):
  479. res[i][0] = round(res[i][0], 2)
  480. res[i][1] = round(res[i][1], 2)
  481. if res[i][1] - res[i][0] < mindur:
  482. if i == 0:
  483. res[i][2] = res[i + 1][2]
  484. elif i == len(res) - 1:
  485. res[i][2] = res[i - 1][2]
  486. elif res[i][0] - res[i - 1][1] <= res[i + 1][0] - res[i][1]:
  487. res[i][2] = res[i - 1][2]
  488. else:
  489. res[i][2] = res[i + 1][2]
  490. # merge the speakers
  491. res = merge_seque(res)
  492. return res
  493. def distribute_spk(sentence_list, sd_time_list):
  494. sd_sentence_list = []
  495. for d in sentence_list:
  496. sentence_start = d['ts_list'][0][0]
  497. sentence_end = d['ts_list'][-1][1]
  498. sentence_spk = 0
  499. max_overlap = 0
  500. for sd_time in sd_time_list:
  501. spk_st, spk_ed, spk = sd_time
  502. spk_st = spk_st*1000
  503. spk_ed = spk_ed*1000
  504. overlap = max(
  505. min(sentence_end, spk_ed) - max(sentence_start, spk_st), 0)
  506. if overlap > max_overlap:
  507. max_overlap = overlap
  508. sentence_spk = spk
  509. d['spk'] = sentence_spk
  510. sd_sentence_list.append(d)
  511. return sd_sentence_list