sond_inference.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555
  1. #!/usr/bin/env python3
  2. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  3. # MIT License (https://opensource.org/licenses/MIT)
  4. import argparse
  5. import logging
  6. import os
  7. import sys
  8. from pathlib import Path
  9. from typing import Any
  10. from typing import List
  11. from typing import Optional
  12. from typing import Sequence
  13. from typing import Tuple
  14. from typing import Union
  15. from collections import OrderedDict
  16. import numpy as np
  17. import soundfile
  18. import torch
  19. from torch.nn import functional as F
  20. from typeguard import check_argument_types
  21. from typeguard import check_return_type
  22. from funasr.utils.cli_utils import get_commandline_args
  23. from funasr.tasks.diar import DiarTask
  24. from funasr.tasks.asr import ASRTask
  25. from funasr.torch_utils.device_funcs import to_device
  26. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  27. from funasr.utils import config_argparse
  28. from funasr.utils.types import str2bool
  29. from funasr.utils.types import str2triple_str
  30. from funasr.utils.types import str_or_none
  31. from scipy.ndimage import median_filter
  32. from funasr.utils.misc import statistic_model_parameters
  33. from funasr.datasets.iterable_dataset import load_bytes
  34. class Speech2Diarization:
  35. """Speech2Xvector class
  36. Examples:
  37. >>> import soundfile
  38. >>> import numpy as np
  39. >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pb")
  40. >>> profile = np.load("profiles.npy")
  41. >>> audio, rate = soundfile.read("speech.wav")
  42. >>> speech2diar(audio, profile)
  43. {"spk1": [(int, int), ...], ...}
  44. """
  45. def __init__(
  46. self,
  47. diar_train_config: Union[Path, str] = None,
  48. diar_model_file: Union[Path, str] = None,
  49. device: str = "cpu",
  50. batch_size: int = 1,
  51. dtype: str = "float32",
  52. streaming: bool = False,
  53. smooth_size: int = 83,
  54. dur_threshold: float = 10,
  55. ):
  56. assert check_argument_types()
  57. # TODO: 1. Build Diarization model
  58. diar_model, diar_train_args = DiarTask.build_model_from_file(
  59. config_file=diar_train_config,
  60. model_file=diar_model_file,
  61. device=device
  62. )
  63. logging.info("diar_model: {}".format(diar_model))
  64. logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model)))
  65. logging.info("diar_train_args: {}".format(diar_train_args))
  66. diar_model.to(dtype=getattr(torch, dtype)).eval()
  67. self.diar_model = diar_model
  68. self.diar_train_args = diar_train_args
  69. self.token_list = diar_train_args.token_list
  70. self.smooth_size = smooth_size
  71. self.dur_threshold = dur_threshold
  72. self.device = device
  73. self.dtype = dtype
  74. def smooth_multi_labels(self, multi_label):
  75. multi_label = median_filter(multi_label, (self.smooth_size, 1), mode="constant", cval=0.0).astype(int)
  76. return multi_label
  77. @staticmethod
  78. def calc_spk_turns(label_arr, spk_list):
  79. turn_list = []
  80. length = label_arr.shape[0]
  81. n_spk = label_arr.shape[1]
  82. for k in range(n_spk):
  83. if spk_list[k] == "None":
  84. continue
  85. in_utt = False
  86. start = 0
  87. for i in range(length):
  88. if label_arr[i, k] == 1 and in_utt is False:
  89. start = i
  90. in_utt = True
  91. if label_arr[i, k] == 0 and in_utt is True:
  92. turn_list.append([spk_list[k], start, i - start])
  93. in_utt = False
  94. if in_utt:
  95. turn_list.append([spk_list[k], start, length - start])
  96. return turn_list
  97. @staticmethod
  98. def seq2arr(seq, vec_dim=8):
  99. def int2vec(x, vec_dim=8, dtype=np.int):
  100. b = ('{:0' + str(vec_dim) + 'b}').format(x)
  101. # little-endian order: lower bit first
  102. return (np.array(list(b)[::-1]) == '1').astype(dtype)
  103. return np.row_stack([int2vec(int(x), vec_dim) for x in seq])
  104. def post_processing(self, raw_logits: torch.Tensor, spk_num: int):
  105. logits_idx = raw_logits.argmax(-1) # B, T, vocab_size -> B, T
  106. # upsampling outputs to match inputs
  107. ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
  108. logits_idx = F.upsample(
  109. logits_idx.unsqueeze(1).float(),
  110. size=(ut, ),
  111. mode="nearest",
  112. ).squeeze(1).long()
  113. logits_idx = logits_idx[0].tolist()
  114. pse_labels = [self.token_list[x] for x in logits_idx]
  115. multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num] # remove padding speakers
  116. multi_labels = self.smooth_multi_labels(multi_labels)
  117. spk_list = ["spk{}".format(i + 1) for i in range(spk_num)]
  118. spk_turns = self.calc_spk_turns(multi_labels, spk_list)
  119. results = OrderedDict()
  120. for spk, st, dur in spk_turns:
  121. if spk not in results:
  122. results[spk] = []
  123. if dur > self.dur_threshold:
  124. results[spk].append((st, st+dur))
  125. # sort segments in start time ascending
  126. for spk in results:
  127. results[spk] = sorted(results[spk], key=lambda x: x[0])
  128. return results, pse_labels
  129. @torch.no_grad()
  130. def __call__(
  131. self,
  132. speech: Union[torch.Tensor, np.ndarray],
  133. profile: Union[torch.Tensor, np.ndarray],
  134. ):
  135. """Inference
  136. Args:
  137. speech: Input speech data
  138. profile: Speaker profiles
  139. Returns:
  140. diarization results for each speaker
  141. """
  142. assert check_argument_types()
  143. # Input as audio signal
  144. if isinstance(speech, np.ndarray):
  145. speech = torch.tensor(speech)
  146. if isinstance(profile, np.ndarray):
  147. profile = torch.tensor(profile)
  148. # data: (Nsamples,) -> (1, Nsamples)
  149. speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
  150. profile = profile.unsqueeze(0).to(getattr(torch, self.dtype))
  151. # lengths: (1,)
  152. speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
  153. profile_lengths = profile.new_full([1], dtype=torch.long, fill_value=profile.size(1))
  154. batch = {"speech": speech, "speech_lengths": speech_lengths,
  155. "profile": profile, "profile_lengths": profile_lengths}
  156. # a. To device
  157. batch = to_device(batch, device=self.device)
  158. logits = self.diar_model.prediction_forward(**batch)
  159. results, pse_labels = self.post_processing(logits, profile.shape[1])
  160. return results, pse_labels
  161. @staticmethod
  162. def from_pretrained(
  163. model_tag: Optional[str] = None,
  164. **kwargs: Optional[Any],
  165. ):
  166. """Build Speech2Xvector instance from the pretrained model.
  167. Args:
  168. model_tag (Optional[str]): Model tag of the pretrained models.
  169. Currently, the tags of espnet_model_zoo are supported.
  170. Returns:
  171. Speech2Xvector: Speech2Xvector instance.
  172. """
  173. if model_tag is not None:
  174. try:
  175. from espnet_model_zoo.downloader import ModelDownloader
  176. except ImportError:
  177. logging.error(
  178. "`espnet_model_zoo` is not installed. "
  179. "Please install via `pip install -U espnet_model_zoo`."
  180. )
  181. raise
  182. d = ModelDownloader()
  183. kwargs.update(**d.download_and_unpack(model_tag))
  184. return Speech2Diarization(**kwargs)
  185. def inference_modelscope(
  186. diar_train_config: str,
  187. diar_model_file: str,
  188. output_dir: Optional[str] = None,
  189. batch_size: int = 1,
  190. dtype: str = "float32",
  191. ngpu: int = 0,
  192. seed: int = 0,
  193. num_workers: int = 0,
  194. log_level: Union[int, str] = "INFO",
  195. key_file: Optional[str] = None,
  196. model_tag: Optional[str] = None,
  197. allow_variable_data_keys: bool = True,
  198. streaming: bool = False,
  199. smooth_size: int = 83,
  200. dur_threshold: int = 10,
  201. out_format: str = "vad",
  202. param_dict: Optional[dict] = None,
  203. mode: str = "sond",
  204. **kwargs,
  205. ):
  206. assert check_argument_types()
  207. if batch_size > 1:
  208. raise NotImplementedError("batch decoding is not implemented")
  209. if ngpu > 1:
  210. raise NotImplementedError("only single GPU decoding is supported")
  211. logging.basicConfig(
  212. level=log_level,
  213. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  214. )
  215. logging.info("param_dict: {}".format(param_dict))
  216. if ngpu >= 1 and torch.cuda.is_available():
  217. device = "cuda"
  218. else:
  219. device = "cpu"
  220. # 1. Set random-seed
  221. set_all_random_seed(seed)
  222. # 2a. Build speech2xvec [Optional]
  223. if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]:
  224. assert "sv_train_config" in param_dict, "sv_train_config must be provided param_dict."
  225. assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
  226. sv_train_config = param_dict["sv_train_config"]
  227. sv_model_file = param_dict["sv_model_file"]
  228. if "model_dir" in param_dict:
  229. sv_train_config = os.path.join(param_dict["model_dir"], sv_train_config)
  230. sv_model_file = os.path.join(param_dict["model_dir"], sv_model_file)
  231. from funasr.bin.sv_inference import Speech2Xvector
  232. speech2xvector_kwargs = dict(
  233. sv_train_config=sv_train_config,
  234. sv_model_file=sv_model_file,
  235. device=device,
  236. dtype=dtype,
  237. streaming=streaming,
  238. embedding_node="resnet1_dense"
  239. )
  240. logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
  241. speech2xvector = Speech2Xvector.from_pretrained(
  242. model_tag=model_tag,
  243. **speech2xvector_kwargs,
  244. )
  245. speech2xvector.sv_model.eval()
  246. # 2b. Build speech2diar
  247. speech2diar_kwargs = dict(
  248. diar_train_config=diar_train_config,
  249. diar_model_file=diar_model_file,
  250. device=device,
  251. dtype=dtype,
  252. streaming=streaming,
  253. smooth_size=smooth_size,
  254. dur_threshold=dur_threshold,
  255. )
  256. logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
  257. speech2diar = Speech2Diarization.from_pretrained(
  258. model_tag=model_tag,
  259. **speech2diar_kwargs,
  260. )
  261. speech2diar.diar_model.eval()
  262. def output_results_str(results: dict, uttid: str):
  263. rst = []
  264. mid = uttid.rsplit("-", 1)[0]
  265. for key in results:
  266. results[key] = [(x[0]/100, x[1]/100) for x in results[key]]
  267. if out_format == "vad":
  268. for spk, segs in results.items():
  269. rst.append("{} {}".format(spk, segs))
  270. else:
  271. template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>"
  272. for spk, segs in results.items():
  273. rst.extend([template.format(mid, st, ed, spk) for st, ed in segs])
  274. return "\n".join(rst)
  275. def _forward(
  276. data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
  277. raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None,
  278. output_dir_v2: Optional[str] = None,
  279. param_dict: Optional[dict] = None,
  280. ):
  281. logging.info("param_dict: {}".format(param_dict))
  282. if data_path_and_name_and_type is None and raw_inputs is not None:
  283. if isinstance(raw_inputs, (list, tuple)):
  284. if not isinstance(raw_inputs[0], List):
  285. raw_inputs = [raw_inputs]
  286. assert all([len(example) >= 2 for example in raw_inputs]), \
  287. "The length of test case in raw_inputs must larger than 1 (>=2)."
  288. def prepare_dataset():
  289. for idx, example in enumerate(raw_inputs):
  290. # read waveform file
  291. example = [load_bytes(x) if isinstance(x, bytes) else x
  292. for x in example]
  293. example = [soundfile.read(x)[0] if isinstance(x, str) else x
  294. for x in example]
  295. # convert torch tensor to numpy array
  296. example = [x.numpy() if isinstance(example[0], torch.Tensor) else x
  297. for x in example]
  298. speech = example[0]
  299. logging.info("Extracting profiles for {} waveforms".format(len(example)-1))
  300. profile = [speech2xvector.calculate_embedding(x) for x in example[1:]]
  301. profile = torch.cat(profile, dim=0)
  302. yield ["test{}".format(idx)], {"speech": [speech], "profile": [profile]}
  303. loader = prepare_dataset()
  304. else:
  305. raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ")
  306. else:
  307. # 3. Build data-iterator
  308. loader = ASRTask.build_streaming_iterator(
  309. data_path_and_name_and_type,
  310. dtype=dtype,
  311. batch_size=batch_size,
  312. key_file=key_file,
  313. num_workers=num_workers,
  314. preprocess_fn=None,
  315. collate_fn=None,
  316. allow_variable_data_keys=allow_variable_data_keys,
  317. inference=True,
  318. )
  319. # 7. Start for-loop
  320. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  321. if output_path is not None:
  322. os.makedirs(output_path, exist_ok=True)
  323. output_writer = open("{}/result.txt".format(output_path), "w")
  324. pse_label_writer = open("{}/labels.txt".format(output_path), "w")
  325. logging.info("Start to diarize...")
  326. result_list = []
  327. for keys, batch in loader:
  328. assert isinstance(batch, dict), type(batch)
  329. assert all(isinstance(s, str) for s in keys), keys
  330. _bs = len(next(iter(batch.values())))
  331. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  332. batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
  333. results, pse_labels = speech2diar(**batch)
  334. # Only supporting batch_size==1
  335. key, value = keys[0], output_results_str(results, keys[0])
  336. item = {"key": key, "value": value}
  337. result_list.append(item)
  338. if output_path is not None:
  339. output_writer.write(value)
  340. output_writer.flush()
  341. pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels)))
  342. pse_label_writer.flush()
  343. if output_path is not None:
  344. output_writer.close()
  345. pse_label_writer.close()
  346. return result_list
  347. return _forward
  348. def inference(
  349. data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
  350. diar_train_config: Optional[str],
  351. diar_model_file: Optional[str],
  352. output_dir: Optional[str] = None,
  353. batch_size: int = 1,
  354. dtype: str = "float32",
  355. ngpu: int = 0,
  356. seed: int = 0,
  357. num_workers: int = 1,
  358. log_level: Union[int, str] = "INFO",
  359. key_file: Optional[str] = None,
  360. model_tag: Optional[str] = None,
  361. allow_variable_data_keys: bool = True,
  362. streaming: bool = False,
  363. smooth_size: int = 83,
  364. dur_threshold: int = 10,
  365. out_format: str = "vad",
  366. **kwargs,
  367. ):
  368. inference_pipeline = inference_modelscope(
  369. diar_train_config=diar_train_config,
  370. diar_model_file=diar_model_file,
  371. output_dir=output_dir,
  372. batch_size=batch_size,
  373. dtype=dtype,
  374. ngpu=ngpu,
  375. seed=seed,
  376. num_workers=num_workers,
  377. log_level=log_level,
  378. key_file=key_file,
  379. model_tag=model_tag,
  380. allow_variable_data_keys=allow_variable_data_keys,
  381. streaming=streaming,
  382. smooth_size=smooth_size,
  383. dur_threshold=dur_threshold,
  384. out_format=out_format,
  385. **kwargs,
  386. )
  387. return inference_pipeline(data_path_and_name_and_type, raw_inputs=None)
  388. def get_parser():
  389. parser = config_argparse.ArgumentParser(
  390. description="Speaker verification/x-vector extraction",
  391. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  392. )
  393. # Note(kamo): Use '_' instead of '-' as separator.
  394. # '-' is confusing if written in yaml.
  395. parser.add_argument(
  396. "--log_level",
  397. type=lambda x: x.upper(),
  398. default="INFO",
  399. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  400. help="The verbose level of logging",
  401. )
  402. parser.add_argument("--output_dir", type=str, required=False)
  403. parser.add_argument(
  404. "--ngpu",
  405. type=int,
  406. default=0,
  407. help="The number of gpus. 0 indicates CPU mode",
  408. )
  409. parser.add_argument(
  410. "--gpuid_list",
  411. type=str,
  412. default="",
  413. help="The visible gpus",
  414. )
  415. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  416. parser.add_argument(
  417. "--dtype",
  418. default="float32",
  419. choices=["float16", "float32", "float64"],
  420. help="Data type",
  421. )
  422. parser.add_argument(
  423. "--num_workers",
  424. type=int,
  425. default=1,
  426. help="The number of workers used for DataLoader",
  427. )
  428. group = parser.add_argument_group("Input data related")
  429. group.add_argument(
  430. "--data_path_and_name_and_type",
  431. type=str2triple_str,
  432. required=False,
  433. action="append",
  434. )
  435. group.add_argument("--key_file", type=str_or_none)
  436. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  437. group = parser.add_argument_group("The model configuration related")
  438. group.add_argument(
  439. "--diar_train_config",
  440. type=str,
  441. help="diarization training configuration",
  442. )
  443. group.add_argument(
  444. "--diar_model_file",
  445. type=str,
  446. help="diarization model parameter file",
  447. )
  448. group.add_argument(
  449. "--dur_threshold",
  450. type=int,
  451. default=10,
  452. help="The threshold for short segments in number frames"
  453. )
  454. parser.add_argument(
  455. "--smooth_size",
  456. type=int,
  457. default=83,
  458. help="The smoothing window length in number frames"
  459. )
  460. group.add_argument(
  461. "--model_tag",
  462. type=str,
  463. help="Pretrained model tag. If specify this option, *_train_config and "
  464. "*_file will be overwritten",
  465. )
  466. parser.add_argument(
  467. "--batch_size",
  468. type=int,
  469. default=1,
  470. help="The batch size for inference",
  471. )
  472. parser.add_argument("--streaming", type=str2bool, default=False)
  473. return parser
  474. def main(cmd=None):
  475. print(get_commandline_args(), file=sys.stderr)
  476. parser = get_parser()
  477. args = parser.parse_args(cmd)
  478. kwargs = vars(args)
  479. kwargs.pop("config", None)
  480. logging.info("args: {}".format(kwargs))
  481. if args.output_dir is None:
  482. jobid, n_gpu = 1, 1
  483. gpuid = args.gpuid_list.split(",")[jobid-1]
  484. else:
  485. jobid = int(args.output_dir.split(".")[-1])
  486. n_gpu = len(args.gpuid_list.split(","))
  487. gpuid = args.gpuid_list.split(",")[(jobid - 1) % n_gpu]
  488. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  489. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  490. results_list = inference(**kwargs)
  491. for results in results_list:
  492. print("{} {}".format(results["key"], results["value"]))
# Run the command-line interface only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()