# sv_inference.py
  1. #!/usr/bin/env python3
  2. import argparse
  3. import logging
  4. import os
  5. import sys
  6. from pathlib import Path
  7. from typing import Any
  8. from typing import List
  9. from typing import Optional
  10. from typing import Sequence
  11. from typing import Tuple
  12. from typing import Union
  13. import numpy as np
  14. import torch
  15. from kaldiio import WriteHelper
  16. from typeguard import check_argument_types
  17. from typeguard import check_return_type
  18. from funasr.utils.cli_utils import get_commandline_args
  19. from funasr.tasks.sv import SVTask
  20. from funasr.tasks.asr import ASRTask
  21. from funasr.torch_utils.device_funcs import to_device
  22. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  23. from funasr.utils import config_argparse
  24. from funasr.utils.types import str2bool
  25. from funasr.utils.types import str2triple_str
  26. from funasr.utils.types import str_or_none
  27. class Speech2Xvector:
  28. """Speech2Xvector class
  29. Examples:
  30. >>> import soundfile
  31. >>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pth")
  32. >>> audio, rate = soundfile.read("speech.wav")
  33. >>> speech2xvector(audio)
  34. [(text, token, token_int, hypothesis object), ...]
  35. """
  36. def __init__(
  37. self,
  38. sv_train_config: Union[Path, str] = None,
  39. sv_model_file: Union[Path, str] = None,
  40. device: str = "cpu",
  41. batch_size: int = 1,
  42. dtype: str = "float32",
  43. streaming: bool = False,
  44. embedding_node: str = "resnet1_dense",
  45. ):
  46. assert check_argument_types()
  47. # TODO: 1. Build SV model
  48. sv_model, sv_train_args = SVTask.build_model_from_file(
  49. config_file=sv_train_config,
  50. model_file=sv_model_file,
  51. device=device
  52. )
  53. logging.info("sv_model: {}".format(sv_model))
  54. logging.info("sv_train_args: {}".format(sv_train_args))
  55. sv_model.to(dtype=getattr(torch, dtype)).eval()
  56. self.sv_model = sv_model
  57. self.sv_train_args = sv_train_args
  58. self.device = device
  59. self.dtype = dtype
  60. self.embedding_node = embedding_node
  61. @torch.no_grad()
  62. def calculate_embedding(self, speech: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
  63. # Input as audio signal
  64. if isinstance(speech, np.ndarray):
  65. speech = torch.tensor(speech)
  66. # data: (Nsamples,) -> (1, Nsamples)
  67. speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
  68. # lengths: (1,)
  69. lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
  70. batch = {"speech": speech, "speech_lengths": lengths}
  71. # a. To device
  72. batch = to_device(batch, device=self.device)
  73. # b. Forward Encoder
  74. enc, ilens = self.sv_model.encode(**batch)
  75. # c. Forward Pooling
  76. pooling = self.sv_model.pooling_layer(enc)
  77. # d. Forward Decoder
  78. outputs, embeddings = self.sv_model.decoder(pooling)
  79. if self.embedding_node not in embeddings:
  80. raise ValueError("Required embedding node {} not in {}".format(
  81. self.embedding_node, embeddings.keys()))
  82. return embeddings[self.embedding_node]
  83. @torch.no_grad()
  84. def __call__(
  85. self, speech: Union[torch.Tensor, np.ndarray],
  86. ref_speech: Optional[Union[torch.Tensor, np.ndarray]] = None,
  87. ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Union[torch.Tensor, None]]:
  88. """Inference
  89. Args:
  90. speech: Input speech data
  91. ref_speech: Reference speech to compare
  92. Returns:
  93. embedding, ref_embedding, similarity_score
  94. """
  95. assert check_argument_types()
  96. self.sv_model.eval()
  97. embedding = self.calculate_embedding(speech)
  98. ref_emb, score = None, None
  99. if ref_speech is not None:
  100. ref_emb = self.calculate_embedding(ref_speech)
  101. score = torch.cosine_similarity(embedding, ref_emb)
  102. results = (embedding, ref_emb, score)
  103. assert check_return_type(results)
  104. return results
  105. @staticmethod
  106. def from_pretrained(
  107. model_tag: Optional[str] = None,
  108. **kwargs: Optional[Any],
  109. ):
  110. """Build Speech2Xvector instance from the pretrained model.
  111. Args:
  112. model_tag (Optional[str]): Model tag of the pretrained models.
  113. Currently, the tags of espnet_model_zoo are supported.
  114. Returns:
  115. Speech2Xvector: Speech2Xvector instance.
  116. """
  117. if model_tag is not None:
  118. try:
  119. from espnet_model_zoo.downloader import ModelDownloader
  120. except ImportError:
  121. logging.error(
  122. "`espnet_model_zoo` is not installed. "
  123. "Please install via `pip install -U espnet_model_zoo`."
  124. )
  125. raise
  126. d = ModelDownloader()
  127. kwargs.update(**d.download_and_unpack(model_tag))
  128. return Speech2Xvector(**kwargs)
def inference_modelscope(
    output_dir: Optional[str],
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    key_file: Optional[str],
    sv_train_config: Optional[str],
    sv_model_file: Optional[str],
    model_tag: Optional[str],
    allow_variable_data_keys: bool = True,
    streaming: bool = False,
    embedding_node: str = "resnet1_dense",
    sv_threshold: float = 0.9465,
    param_dict: Optional[dict] = None,
    **kwargs,
):
    """Build and return an SV inference closure (ModelScope-style).

    Constructs a ``Speech2Xvector`` once, then returns a ``_forward``
    function that extracts x-vectors over a dataset and, when a reference
    speech entry is present in a batch, emits a normalized verification
    score instead of the raw embedding.

    Args:
        output_dir: Default directory for ark/scp/score outputs; can be
            overridden per call via ``output_dir_v2``.
        batch_size: Must be 1; batch decoding is not implemented.
        dtype: Computation dtype name, e.g. "float32".
        ngpu: Number of GPUs; >1 is rejected, >=1 selects CUDA when available.
        seed: Random seed set before model construction.
        num_workers: DataLoader worker count.
        log_level: Logging verbosity.
        key_file: Optional utterance-key file for the data iterator.
        sv_train_config: SV training configuration path.
        sv_model_file: SV model parameter path.
        model_tag: Optional pretrained-model tag (espnet_model_zoo).
        allow_variable_data_keys: Passed through to the data iterator.
        streaming: Forwarded to ``Speech2Xvector``.
        embedding_node: Decoder output node holding the embedding.
        sv_threshold: Cosine-similarity threshold used to map scores to
            the [0, 100] range.
        param_dict: Extra parameters; only logged here.

    Returns:
        ``_forward(...)``, which returns a list of
        ``{"key": ..., "value": ...}`` dicts.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("param_dict: {}".format(param_dict))

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2xvector
    speech2xvector_kwargs = dict(
        sv_train_config=sv_train_config,
        sv_model_file=sv_model_file,
        device=device,
        dtype=dtype,
        streaming=streaming,
        embedding_node=embedding_node
    )
    logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
    speech2xvector = Speech2Xvector.from_pretrained(
        model_tag=model_tag,
        **speech2xvector_kwargs,
    )
    speech2xvector.sv_model.eval()

    def _forward(
        data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
        param_dict: Optional[dict] = None,
    ):
        # Run inference over the given data list (or a single raw waveform).
        # `fs` is accepted for interface compatibility but unused here.
        logging.info("param_dict: {}".format(param_dict))
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            # NOTE(review): a flat [array, name, type] list, not a list of
            # triples — presumably build_streaming_iterator special-cases
            # raw ndarray inputs; confirm against the FunASR task code.
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]

        # 3. Build data-iterator
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=None,
            collate_fn=None,
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )

        # 4. Start the decoding loop.
        # Per-call output dir takes precedence over the factory-level one.
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        embd_writer, ref_embd_writer, score_writer = None, None, None
        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)
            embd_writer = WriteHelper("ark,scp:{}/xvector.ark,{}/xvector.scp".format(output_path, output_path))

        sv_result_list = []
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # Drop *_lengths entries and strip the batch dimension so the
            # remaining tensors map onto Speech2Xvector.__call__ arguments.
            batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            embedding, ref_embedding, score = speech2xvector(**batch)

            # Only supporting batch_size==1
            key = keys[0]
            normalized_score = 0.0
            if score is not None:
                # Map cosine similarity into [0, 100]: anything at or below
                # the threshold becomes 0, the maximum (1.0) becomes 100.
                score = score.item()
                normalized_score = max(score - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
                item = {"key": key, "value": normalized_score}
            else:
                # No reference speech: return the embedding itself.
                item = {"key": key, "value": embedding.squeeze(0).cpu().numpy()}
            sv_result_list.append(item)

            if output_path is not None:
                embd_writer(key, embedding[0].cpu().numpy())
                if ref_embedding is not None:
                    # Reference writers are created lazily on the first
                    # batch that actually carries a reference utterance.
                    if ref_embd_writer is None:
                        ref_embd_writer = WriteHelper(
                            "ark,scp:{}/ref_xvector.ark,{}/ref_xvector.scp".format(output_path, output_path)
                        )
                        score_writer = open(os.path.join(output_path, "score.txt"), "w")
                    ref_embd_writer(key, ref_embedding[0].cpu().numpy())
                    score_writer.write("{} {:.6f}\n".format(key, normalized_score))

        if output_path is not None:
            embd_writer.close()
            if ref_embd_writer is not None:
                ref_embd_writer.close()
                score_writer.close()
        return sv_result_list

    return _forward
  244. def inference(
  245. output_dir: Optional[str],
  246. batch_size: int,
  247. dtype: str,
  248. ngpu: int,
  249. seed: int,
  250. num_workers: int,
  251. log_level: Union[int, str],
  252. data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
  253. key_file: Optional[str],
  254. sv_train_config: Optional[str],
  255. sv_model_file: Optional[str],
  256. model_tag: Optional[str],
  257. allow_variable_data_keys: bool = True,
  258. streaming: bool = False,
  259. embedding_node: str = "resnet1_dense",
  260. sv_threshold: float = 0.9465,
  261. **kwargs,
  262. ):
  263. inference_pipeline = inference_modelscope(
  264. output_dir=output_dir,
  265. batch_size=batch_size,
  266. dtype=dtype,
  267. ngpu=ngpu,
  268. seed=seed,
  269. num_workers=num_workers,
  270. log_level=log_level,
  271. key_file=key_file,
  272. sv_train_config=sv_train_config,
  273. sv_model_file=sv_model_file,
  274. model_tag=model_tag,
  275. allow_variable_data_keys=allow_variable_data_keys,
  276. streaming=streaming,
  277. embedding_node=embedding_node,
  278. sv_threshold=sv_threshold,
  279. **kwargs,
  280. )
  281. return inference_pipeline(data_path_and_name_and_type, raw_inputs=None)
  282. def get_parser():
  283. parser = config_argparse.ArgumentParser(
  284. description="Speaker verification/x-vector extraction",
  285. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  286. )
  287. # Note(kamo): Use '_' instead of '-' as separator.
  288. # '-' is confusing if written in yaml.
  289. parser.add_argument(
  290. "--log_level",
  291. type=lambda x: x.upper(),
  292. default="INFO",
  293. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  294. help="The verbose level of logging",
  295. )
  296. parser.add_argument("--output_dir", type=str, required=False)
  297. parser.add_argument(
  298. "--ngpu",
  299. type=int,
  300. default=0,
  301. help="The number of gpus. 0 indicates CPU mode",
  302. )
  303. parser.add_argument(
  304. "--gpuid_list",
  305. type=str,
  306. default="",
  307. help="The visible gpus",
  308. )
  309. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  310. parser.add_argument(
  311. "--dtype",
  312. default="float32",
  313. choices=["float16", "float32", "float64"],
  314. help="Data type",
  315. )
  316. parser.add_argument(
  317. "--num_workers",
  318. type=int,
  319. default=1,
  320. help="The number of workers used for DataLoader",
  321. )
  322. group = parser.add_argument_group("Input data related")
  323. group.add_argument(
  324. "--data_path_and_name_and_type",
  325. type=str2triple_str,
  326. required=False,
  327. action="append",
  328. )
  329. group.add_argument("--key_file", type=str_or_none)
  330. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  331. group = parser.add_argument_group("The model configuration related")
  332. group.add_argument(
  333. "--sv_train_config",
  334. type=str,
  335. help="SV training configuration",
  336. )
  337. group.add_argument(
  338. "--sv_model_file",
  339. type=str,
  340. help="SV model parameter file",
  341. )
  342. group.add_argument(
  343. "--sv_threshold",
  344. type=float,
  345. default=0.9465,
  346. help="The threshold for verification"
  347. )
  348. group.add_argument(
  349. "--model_tag",
  350. type=str,
  351. help="Pretrained model tag. If specify this option, *_train_config and "
  352. "*_file will be overwritten",
  353. )
  354. parser.add_argument(
  355. "--batch_size",
  356. type=int,
  357. default=1,
  358. help="The batch size for inference",
  359. )
  360. parser.add_argument("--streaming", type=str2bool, default=False)
  361. parser.add_argument("--embedding_node", type=str, default="resnet1_dense")
  362. return parser
  363. def main(cmd=None):
  364. print(get_commandline_args(), file=sys.stderr)
  365. parser = get_parser()
  366. args = parser.parse_args(cmd)
  367. kwargs = vars(args)
  368. kwargs.pop("config", None)
  369. logging.info("args: {}".format(kwargs))
  370. if args.output_dir is None:
  371. jobid, n_gpu = 1, 1
  372. gpuid = args.gpuid_list.split(",")[jobid-1]
  373. else:
  374. jobid = int(args.output_dir.split(".")[-1])
  375. n_gpu = len(args.gpuid_list.split(","))
  376. gpuid = args.gpuid_list.split(",")[(jobid - 1) % n_gpu]
  377. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  378. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  379. results_list = inference(**kwargs)
  380. for results in results_list:
  381. print("{} {}".format(results["key"], results["value"]))
  382. if __name__ == "__main__":
  383. main()