# lm_inference.py
  1. #!/usr/bin/env python3
  2. import argparse
  3. import logging
  4. from pathlib import Path
  5. import sys
  6. import os
  7. from typing import Optional
  8. from typing import Sequence
  9. from typing import Tuple
  10. from typing import Union
  11. from typing import Dict
  12. from typing import Any
  13. from typing import List
  14. import numpy as np
  15. import torch
  16. from torch.nn.parallel import data_parallel
  17. from typeguard import check_argument_types
  18. from funasr.tasks.lm import LMTask
  19. from funasr.datasets.preprocessor import LMPreprocessor
  20. from funasr.utils.cli_utils import get_commandline_args
  21. from funasr.fileio.datadir_writer import DatadirWriter
  22. from funasr.torch_utils.device_funcs import to_device
  23. from funasr.torch_utils.forward_adaptor import ForwardAdaptor
  24. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  25. from funasr.utils import config_argparse
  26. from funasr.utils.types import float_or_none
  27. from funasr.utils.types import str2bool
  28. from funasr.utils.types import str2triple_str
  29. from funasr.utils.types import str_or_none
  30. def inference(
  31. output_dir: str,
  32. batch_size: int,
  33. dtype: str,
  34. ngpu: int,
  35. seed: int,
  36. num_workers: int,
  37. log_level: Union[int, str],
  38. train_config: Optional[str],
  39. model_file: Optional[str],
  40. log_base: Optional[float],
  41. key_file: Optional[str] = None,
  42. allow_variable_data_keys: bool = False,
  43. split_with_space: Optional[bool] = False,
  44. seg_dict_file: Optional[str] = None,
  45. data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
  46. raw_inputs: Union[List[Any], bytes, str] = None,
  47. **kwargs,
  48. ):
  49. inference_pipeline = inference_modelscope(
  50. output_dir=output_dir,
  51. raw_inputs=raw_inputs,
  52. batch_size=batch_size,
  53. dtype=dtype,
  54. ngpu=ngpu,
  55. seed=seed,
  56. num_workers=num_workers,
  57. log_level=log_level,
  58. key_file=key_file,
  59. train_config=train_config,
  60. model_file=model_file,
  61. log_base = log_base,
  62. allow_variable_data_keys = allow_variable_data_keys,
  63. split_with_space=split_with_space,
  64. seg_dict_file=seg_dict_file,
  65. **kwargs,
  66. )
  67. return inference_pipeline(data_path_and_name_and_type, raw_inputs)
  68. def inference_modelscope(
  69. batch_size: int,
  70. dtype: str,
  71. ngpu: int,
  72. seed: int,
  73. num_workers: int,
  74. log_level: Union[int, str],
  75. key_file: Optional[str],
  76. train_config: Optional[str],
  77. model_file: Optional[str],
  78. log_base: Optional[float] = 10,
  79. allow_variable_data_keys: bool = False,
  80. split_with_space: Optional[bool] = False,
  81. seg_dict_file: Optional[str] = None,
  82. output_dir: Optional[str] = None,
  83. param_dict: dict = None,
  84. **kwargs,
  85. ):
  86. assert check_argument_types()
  87. ncpu = kwargs.get("ncpu", 1)
  88. torch.set_num_threads(ncpu)
  89. if ngpu >= 1 and torch.cuda.is_available():
  90. device = "cuda"
  91. else:
  92. device = "cpu"
  93. # 1. Set random-seed
  94. set_all_random_seed(seed)
  95. # 2. Build Model
  96. model, train_args = LMTask.build_model_from_file(
  97. train_config, model_file, device)
  98. wrapped_model = ForwardAdaptor(model, "nll")
  99. wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
  100. logging.info(f"Model:\n{model}")
  101. preprocessor = LMPreprocessor(
  102. train=False,
  103. token_type=train_args.token_type,
  104. token_list=train_args.token_list,
  105. bpemodel=train_args.bpemodel,
  106. text_cleaner=train_args.cleaner,
  107. g2p_type=train_args.g2p,
  108. text_name="text",
  109. non_linguistic_symbols=train_args.non_linguistic_symbols,
  110. split_with_space=split_with_space,
  111. seg_dict_file=seg_dict_file
  112. )
  113. def _forward(
  114. data_path_and_name_and_type,
  115. raw_inputs: Union[List[Any], bytes, str] = None,
  116. output_dir_v2: Optional[str] = None,
  117. param_dict: dict = None,
  118. ):
  119. results = []
  120. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  121. if output_path is not None:
  122. writer = DatadirWriter(output_path)
  123. else:
  124. writer = None
  125. if raw_inputs != None:
  126. line = raw_inputs.strip()
  127. key = "lm demo"
  128. if line=="":
  129. item = {'key': key, 'value': ""}
  130. results.append(item)
  131. return results
  132. batch = {}
  133. batch['text'] = line
  134. if preprocessor != None:
  135. batch = preprocessor(key, batch)
  136. # Force data-precision
  137. for name in batch:
  138. value = batch[name]
  139. if not isinstance(value, np.ndarray):
  140. raise RuntimeError(
  141. f"All values must be converted to np.ndarray object "
  142. f'by preprocessing, but "{name}" is still {type(value)}.'
  143. )
  144. # Cast to desired type
  145. if value.dtype.kind == "f":
  146. value = value.astype("float32")
  147. elif value.dtype.kind == "i":
  148. value = value.astype("long")
  149. else:
  150. raise NotImplementedError(f"Not supported dtype: {value.dtype}")
  151. batch[name] = value
  152. batch["text_lengths"] = torch.from_numpy(
  153. np.array([len(batch["text"])], dtype='int32'))
  154. batch["text"] = np.expand_dims(batch["text"], axis=0)
  155. with torch.no_grad():
  156. batch = to_device(batch, device)
  157. if ngpu <= 1:
  158. nll, lengths = wrapped_model(**batch)
  159. else:
  160. nll, lengths = data_parallel(
  161. wrapped_model, (), range(ngpu), module_kwargs=batch
  162. )
  163. ## compute ppl
  164. ppl_out_batch = ""
  165. ids2tokens = preprocessor.token_id_converter.ids2tokens
  166. for sent_ids, sent_nll in zip(batch['text'], nll):
  167. pre_word = "<s>"
  168. cur_word = None
  169. sent_lst = ids2tokens(sent_ids) + ['</s>']
  170. ppl_out = " ".join(sent_lst) + "\n"
  171. for word, word_nll in zip(sent_lst, sent_nll):
  172. cur_word = word
  173. word_nll = -word_nll.cpu()
  174. if log_base is None:
  175. word_prob = np.exp(word_nll)
  176. else:
  177. word_prob = log_base ** (word_nll / np.log(log_base))
  178. ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
  179. cur=cur_word,
  180. pre=pre_word,
  181. prob=round(word_prob.item(), 8),
  182. word_nll=round(word_nll.item(), 8)
  183. )
  184. pre_word = cur_word
  185. sent_nll_mean = sent_nll.mean().cpu().numpy()
  186. sent_nll_sum = sent_nll.sum().cpu().numpy()
  187. if log_base is None:
  188. sent_ppl = np.exp(sent_nll_mean)
  189. else:
  190. sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
  191. ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
  192. sent_nll=round(-sent_nll_sum.item(), 4),
  193. sent_ppl=round(sent_ppl.item(), 4)
  194. )
  195. ppl_out_batch += ppl_out
  196. item = {'key': key, 'value': ppl_out}
  197. if writer is not None:
  198. writer["ppl"][key+":\n"] = ppl_out
  199. results.append(item)
  200. return results
  201. # 3. Build data-iterator
  202. loader = LMTask.build_streaming_iterator(
  203. data_path_and_name_and_type,
  204. dtype=dtype,
  205. batch_size=batch_size,
  206. key_file=key_file,
  207. num_workers=num_workers,
  208. preprocess_fn=preprocessor,
  209. collate_fn=LMTask.build_collate_fn(train_args, False),
  210. allow_variable_data_keys=allow_variable_data_keys,
  211. inference=True,
  212. )
  213. # 4. Start for-loop
  214. total_nll = 0.0
  215. total_ntokens = 0
  216. ppl_out_all = ""
  217. for keys, batch in loader:
  218. assert isinstance(batch, dict), type(batch)
  219. assert all(isinstance(s, str) for s in keys), keys
  220. _bs = len(next(iter(batch.values())))
  221. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  222. ppl_out_batch = ""
  223. with torch.no_grad():
  224. batch = to_device(batch, device)
  225. if ngpu <= 1:
  226. # NOTE(kamo): data_parallel also should work with ngpu=1,
  227. # but for debuggability it's better to keep this block.
  228. nll, lengths = wrapped_model(**batch)
  229. else:
  230. nll, lengths = data_parallel(
  231. wrapped_model, (), range(ngpu), module_kwargs=batch
  232. )
  233. ## print ppl
  234. ids2tokens = preprocessor.token_id_converter.ids2tokens
  235. for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
  236. pre_word = "<s>"
  237. cur_word = None
  238. sent_lst = ids2tokens(sent_ids) + ['</s>']
  239. ppl_out = " ".join(sent_lst) + "\n"
  240. for word, word_nll in zip(sent_lst, sent_nll):
  241. cur_word = word
  242. word_nll = -word_nll.cpu()
  243. if log_base is None:
  244. word_prob = np.exp(word_nll)
  245. else:
  246. word_prob = log_base ** (word_nll / np.log(log_base))
  247. ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
  248. cur=cur_word,
  249. pre=pre_word,
  250. prob=round(word_prob.item(), 8),
  251. word_nll=round(word_nll.item(), 8)
  252. )
  253. pre_word = cur_word
  254. sent_nll_mean = sent_nll.mean().cpu().numpy()
  255. sent_nll_sum = sent_nll.sum().cpu().numpy()
  256. if log_base is None:
  257. sent_ppl = np.exp(sent_nll_mean)
  258. else:
  259. sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
  260. ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
  261. sent_nll=round(-sent_nll_sum.item(), 4),
  262. sent_ppl=round(sent_ppl.item(), 4)
  263. )
  264. ppl_out_batch += ppl_out
  265. utt2nll = round(-sent_nll_sum.item(), 5)
  266. item = {'key': key, 'value': ppl_out}
  267. if writer is not None:
  268. writer["ppl"][key+":\n"] = ppl_out
  269. writer["utt2nll"][key] = str(utt2nll)
  270. results.append(item)
  271. ppl_out_all += ppl_out_batch
  272. assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
  273. # nll: (B, L) -> (B,)
  274. nll = nll.detach().cpu().numpy().sum(1)
  275. # lengths: (B,)
  276. lengths = lengths.detach().cpu().numpy()
  277. total_nll += nll.sum()
  278. total_ntokens += lengths.sum()
  279. if log_base is None:
  280. ppl = np.exp(total_nll / total_ntokens)
  281. else:
  282. ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
  283. avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
  284. total_nll=round(-total_nll.item(), 4),
  285. total_ppl=round(ppl.item(), 4)
  286. )
  287. item = {'key': 'AVG PPL', 'value': avg_ppl}
  288. ppl_out_all += avg_ppl
  289. if writer is not None:
  290. writer["ppl"]["AVG PPL : "] = avg_ppl
  291. results.append(item)
  292. return results
  293. return _forward
  294. def get_parser():
  295. parser = config_argparse.ArgumentParser(
  296. description="Calc perplexity",
  297. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  298. )
  299. parser.add_argument(
  300. "--log_level",
  301. type=lambda x: x.upper(),
  302. default="INFO",
  303. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  304. help="The verbose level of logging",
  305. )
  306. parser.add_argument("--output_dir", type=str, required=False)
  307. parser.add_argument(
  308. "--ngpu",
  309. type=int,
  310. default=0,
  311. help="The number of gpus. 0 indicates CPU mode",
  312. )
  313. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  314. parser.add_argument(
  315. "--dtype",
  316. default="float32",
  317. choices=["float16", "float32", "float64"],
  318. help="Data type",
  319. )
  320. parser.add_argument(
  321. "--num_workers",
  322. type=int,
  323. default=1,
  324. help="The number of workers used for DataLoader",
  325. )
  326. parser.add_argument(
  327. "--batch_size",
  328. type=int,
  329. default=1,
  330. help="The batch size for inference",
  331. )
  332. parser.add_argument(
  333. "--log_base",
  334. type=float_or_none,
  335. default=10,
  336. help="The base of logarithm for Perplexity. "
  337. "If None, napier's constant is used.",
  338. required=False
  339. )
  340. group = parser.add_argument_group("Input data related")
  341. group.add_argument(
  342. "--data_path_and_name_and_type",
  343. type=str2triple_str,
  344. action="append",
  345. required=False
  346. )
  347. group.add_argument(
  348. "--raw_inputs",
  349. type=str,
  350. required=False
  351. )
  352. group.add_argument("--key_file", type=str_or_none)
  353. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  354. group.add_argument("--split_with_space", type=str2bool, default=False)
  355. group.add_argument("--seg_dict_file", type=str_or_none)
  356. group = parser.add_argument_group("The model configuration related")
  357. group.add_argument("--train_config", type=str)
  358. group.add_argument("--model_file", type=str)
  359. return parser
  360. def main(cmd=None):
  361. print(get_commandline_args(), file=sys.stderr)
  362. parser = get_parser()
  363. args = parser.parse_args(cmd)
  364. kwargs = vars(args)
  365. inference(**kwargs)
  366. if __name__ == "__main__":
  367. main()