lm_inference_launch.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import argparse
  6. import logging
  7. import os
  8. import sys
  9. from typing import Any
  10. from typing import List
  11. from typing import Optional
  12. from typing import Union
  13. import numpy as np
  14. import torch
  15. from torch.nn.parallel import data_parallel
  16. from funasr.build_utils.build_model_from_file import build_model_from_file
  17. from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
  18. from funasr.datasets.preprocessor import LMPreprocessor
  19. from funasr.fileio.datadir_writer import DatadirWriter
  20. from funasr.torch_utils.device_funcs import to_device
  21. from funasr.torch_utils.forward_adaptor import ForwardAdaptor
  22. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  23. from funasr.utils import config_argparse
  24. from funasr.utils.cli_utils import get_commandline_args
  25. from funasr.utils.types import float_or_none
  26. from funasr.utils.types import str2bool
  27. from funasr.utils.types import str2triple_str
  28. from funasr.utils.types import str_or_none
  29. def inference_lm(
  30. batch_size: int,
  31. dtype: str,
  32. ngpu: int,
  33. seed: int,
  34. num_workers: int,
  35. log_level: Union[int, str],
  36. key_file: Optional[str],
  37. train_config: Optional[str],
  38. model_file: Optional[str],
  39. log_base: Optional[float] = 10,
  40. allow_variable_data_keys: bool = False,
  41. split_with_space: Optional[bool] = False,
  42. seg_dict_file: Optional[str] = None,
  43. output_dir: Optional[str] = None,
  44. param_dict: dict = None,
  45. **kwargs,
  46. ):
  47. ncpu = kwargs.get("ncpu", 1)
  48. torch.set_num_threads(ncpu)
  49. if ngpu >= 1 and torch.cuda.is_available():
  50. device = "cuda"
  51. else:
  52. device = "cpu"
  53. # 1. Set random-seed
  54. set_all_random_seed(seed)
  55. # 2. Build Model
  56. model, train_args = build_model_from_file(
  57. train_config, model_file, None, device, "lm")
  58. wrapped_model = ForwardAdaptor(model, "nll")
  59. wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
  60. logging.info(f"Model:\n{model}")
  61. preprocessor = LMPreprocessor(
  62. train=False,
  63. token_type=train_args.token_type,
  64. token_list=train_args.token_list,
  65. bpemodel=train_args.bpemodel,
  66. text_cleaner=train_args.cleaner,
  67. g2p_type=train_args.g2p,
  68. text_name="text",
  69. non_linguistic_symbols=train_args.non_linguistic_symbols,
  70. split_with_space=split_with_space,
  71. seg_dict_file=seg_dict_file
  72. )
  73. def _forward(
  74. data_path_and_name_and_type,
  75. raw_inputs: Union[List[Any], bytes, str] = None,
  76. output_dir_v2: Optional[str] = None,
  77. param_dict: dict = None,
  78. ):
  79. results = []
  80. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  81. if output_path is not None:
  82. writer = DatadirWriter(output_path)
  83. else:
  84. writer = None
  85. if raw_inputs != None:
  86. line = raw_inputs.strip()
  87. key = "lm demo"
  88. if line == "":
  89. item = {'key': key, 'value': ""}
  90. results.append(item)
  91. return results
  92. batch = {}
  93. batch['text'] = line
  94. if preprocessor != None:
  95. batch = preprocessor(key, batch)
  96. # Force data-precision
  97. for name in batch:
  98. value = batch[name]
  99. if not isinstance(value, np.ndarray):
  100. raise RuntimeError(
  101. f"All values must be converted to np.ndarray object "
  102. f'by preprocessing, but "{name}" is still {type(value)}.'
  103. )
  104. # Cast to desired type
  105. if value.dtype.kind == "f":
  106. value = value.astype("float32")
  107. elif value.dtype.kind == "i":
  108. value = value.astype("long")
  109. else:
  110. raise NotImplementedError(f"Not supported dtype: {value.dtype}")
  111. batch[name] = value
  112. batch["text_lengths"] = torch.from_numpy(
  113. np.array([len(batch["text"])], dtype='int32'))
  114. batch["text"] = np.expand_dims(batch["text"], axis=0)
  115. with torch.no_grad():
  116. batch = to_device(batch, device)
  117. if ngpu <= 1:
  118. nll, lengths = wrapped_model(**batch)
  119. else:
  120. nll, lengths = data_parallel(
  121. wrapped_model, (), range(ngpu), module_kwargs=batch
  122. )
  123. ## compute ppl
  124. ppl_out_batch = ""
  125. ids2tokens = preprocessor.token_id_converter.ids2tokens
  126. for sent_ids, sent_nll in zip(batch['text'], nll):
  127. pre_word = "<s>"
  128. cur_word = None
  129. sent_lst = ids2tokens(sent_ids) + ['</s>']
  130. ppl_out = " ".join(sent_lst) + "\n"
  131. for word, word_nll in zip(sent_lst, sent_nll):
  132. cur_word = word
  133. word_nll = -word_nll.cpu()
  134. if log_base is None:
  135. word_prob = np.exp(word_nll)
  136. else:
  137. word_prob = log_base ** (word_nll / np.log(log_base))
  138. ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
  139. cur=cur_word,
  140. pre=pre_word,
  141. prob=round(word_prob.item(), 8),
  142. word_nll=round(word_nll.item(), 8)
  143. )
  144. pre_word = cur_word
  145. sent_nll_mean = sent_nll.mean().cpu().numpy()
  146. sent_nll_sum = sent_nll.sum().cpu().numpy()
  147. if log_base is None:
  148. sent_ppl = np.exp(sent_nll_mean)
  149. else:
  150. sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
  151. ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
  152. sent_nll=round(-sent_nll_sum.item(), 4),
  153. sent_ppl=round(sent_ppl.item(), 4)
  154. )
  155. ppl_out_batch += ppl_out
  156. item = {'key': key, 'value': ppl_out}
  157. if writer is not None:
  158. writer["ppl"][key + ":\n"] = ppl_out
  159. results.append(item)
  160. return results
  161. # 3. Build data-iterator
  162. loader = build_streaming_iterator(
  163. task_name="lm",
  164. preprocess_args=train_args,
  165. data_path_and_name_and_type=data_path_and_name_and_type,
  166. dtype=dtype,
  167. batch_size=batch_size,
  168. key_file=key_file,
  169. preprocess_fn=preprocessor,
  170. num_workers=num_workers,
  171. )
  172. # 4. Start for-loop
  173. total_nll = 0.0
  174. total_ntokens = 0
  175. ppl_out_all = ""
  176. for keys, batch in loader:
  177. assert isinstance(batch, dict), type(batch)
  178. assert all(isinstance(s, str) for s in keys), keys
  179. _bs = len(next(iter(batch.values())))
  180. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  181. ppl_out_batch = ""
  182. with torch.no_grad():
  183. batch = to_device(batch, device)
  184. if ngpu <= 1:
  185. # NOTE(kamo): data_parallel also should work with ngpu=1,
  186. # but for debuggability it's better to keep this block.
  187. nll, lengths = wrapped_model(**batch)
  188. else:
  189. nll, lengths = data_parallel(
  190. wrapped_model, (), range(ngpu), module_kwargs=batch
  191. )
  192. ## print ppl
  193. ids2tokens = preprocessor.token_id_converter.ids2tokens
  194. for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
  195. pre_word = "<s>"
  196. cur_word = None
  197. sent_lst = ids2tokens(sent_ids) + ['</s>']
  198. ppl_out = " ".join(sent_lst) + "\n"
  199. for word, word_nll in zip(sent_lst, sent_nll):
  200. cur_word = word
  201. word_nll = -word_nll.cpu()
  202. if log_base is None:
  203. word_prob = np.exp(word_nll)
  204. else:
  205. word_prob = log_base ** (word_nll / np.log(log_base))
  206. ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
  207. cur=cur_word,
  208. pre=pre_word,
  209. prob=round(word_prob.item(), 8),
  210. word_nll=round(word_nll.item(), 8)
  211. )
  212. pre_word = cur_word
  213. sent_nll_mean = sent_nll.mean().cpu().numpy()
  214. sent_nll_sum = sent_nll.sum().cpu().numpy()
  215. if log_base is None:
  216. sent_ppl = np.exp(sent_nll_mean)
  217. else:
  218. sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
  219. ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
  220. sent_nll=round(-sent_nll_sum.item(), 4),
  221. sent_ppl=round(sent_ppl.item(), 4)
  222. )
  223. ppl_out_batch += ppl_out
  224. utt2nll = round(-sent_nll_sum.item(), 5)
  225. item = {'key': key, 'value': ppl_out}
  226. if writer is not None:
  227. writer["ppl"][key + ":\n"] = ppl_out
  228. writer["utt2nll"][key] = str(utt2nll)
  229. results.append(item)
  230. ppl_out_all += ppl_out_batch
  231. assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
  232. # nll: (B, L) -> (B,)
  233. nll = nll.detach().cpu().numpy().sum(1)
  234. # lengths: (B,)
  235. lengths = lengths.detach().cpu().numpy()
  236. total_nll += nll.sum()
  237. total_ntokens += lengths.sum()
  238. if log_base is None:
  239. ppl = np.exp(total_nll / total_ntokens)
  240. else:
  241. ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
  242. avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
  243. total_nll=round(-total_nll.item(), 4),
  244. total_ppl=round(ppl.item(), 4)
  245. )
  246. item = {'key': 'AVG PPL', 'value': avg_ppl}
  247. ppl_out_all += avg_ppl
  248. if writer is not None:
  249. writer["ppl"]["AVG PPL : "] = avg_ppl
  250. results.append(item)
  251. return results
  252. return _forward
  253. def inference_launch(mode, **kwargs):
  254. if mode == "transformer":
  255. return inference_lm(**kwargs)
  256. else:
  257. logging.info("Unknown decoding mode: {}".format(mode))
  258. return None
  259. def get_parser():
  260. parser = config_argparse.ArgumentParser(
  261. description="Calc perplexity",
  262. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  263. )
  264. parser.add_argument(
  265. "--log_level",
  266. type=lambda x: x.upper(),
  267. default="INFO",
  268. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  269. help="The verbose level of logging",
  270. )
  271. parser.add_argument("--output_dir", type=str, required=True)
  272. parser.add_argument("--gpuid_list", type=str, required=True)
  273. parser.add_argument(
  274. "--ngpu",
  275. type=int,
  276. default=0,
  277. help="The number of gpus. 0 indicates CPU mode",
  278. )
  279. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  280. parser.add_argument("--njob", type=int, default=1, help="Random seed")
  281. parser.add_argument(
  282. "--dtype",
  283. default="float32",
  284. choices=["float16", "float32", "float64"],
  285. help="Data type",
  286. )
  287. parser.add_argument(
  288. "--num_workers",
  289. type=int,
  290. default=1,
  291. help="The number of workers used for DataLoader",
  292. )
  293. parser.add_argument(
  294. "--batch_size",
  295. type=int,
  296. default=1,
  297. help="The batch size for inference",
  298. )
  299. parser.add_argument(
  300. "--log_base",
  301. type=float_or_none,
  302. default=10,
  303. help="The base of logarithm for Perplexity. "
  304. "If None, napier's constant is used.",
  305. required=False
  306. )
  307. group = parser.add_argument_group("Input data related")
  308. group.add_argument(
  309. "--data_path_and_name_and_type",
  310. type=str2triple_str,
  311. action="append",
  312. required=False
  313. )
  314. group.add_argument(
  315. "--raw_inputs",
  316. type=str,
  317. required=False
  318. )
  319. group.add_argument("--key_file", type=str_or_none)
  320. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  321. group.add_argument("--split_with_space", type=str2bool, default=False)
  322. group.add_argument("--seg_dict_file", type=str_or_none)
  323. group = parser.add_argument_group("The model configuration related")
  324. group.add_argument("--train_config", type=str)
  325. group.add_argument("--model_file", type=str)
  326. group.add_argument("--mode", type=str, default="lm")
  327. return parser
  328. def main(cmd=None):
  329. print(get_commandline_args(), file=sys.stderr)
  330. parser = get_parser()
  331. args = parser.parse_args(cmd)
  332. kwargs = vars(args)
  333. kwargs.pop("config", None)
  334. # set logging messages
  335. logging.basicConfig(
  336. level=args.log_level,
  337. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  338. )
  339. logging.info("Decoding args: {}".format(kwargs))
  340. # gpu setting
  341. if args.ngpu > 0:
  342. jobid = int(args.output_dir.split(".")[-1])
  343. gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
  344. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  345. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  346. kwargs.pop("gpuid_list", None)
  347. kwargs.pop("njob", None)
  348. inference_pipeline = inference_launch(**kwargs)
  349. return inference_pipeline(kwargs["data_path_and_name_and_type"])
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()