  1. #!/usr/bin/env python3
  2. import argparse
  3. import logging
  4. from pathlib import Path
  5. import sys
  6. import os
  7. from typing import Optional
  8. from typing import Sequence
  9. from typing import Tuple
  10. from typing import Union
  11. from typing import Dict
  12. from typing import Any
  13. from typing import List
  14. import numpy as np
  15. import torch
  16. from torch.nn.parallel import data_parallel
  17. from typeguard import check_argument_types
  18. from funasr.tasks.lm import LMTask
  19. from funasr.datasets.preprocessor import LMPreprocessor
  20. from funasr.utils.cli_utils import get_commandline_args
  21. from funasr.fileio.datadir_writer import DatadirWriter
  22. from funasr.torch_utils.device_funcs import to_device
  23. from funasr.torch_utils.forward_adaptor import ForwardAdaptor
  24. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  25. from funasr.utils import config_argparse
  26. from funasr.utils.types import float_or_none
  27. from funasr.utils.types import str2bool
  28. from funasr.utils.types import str2triple_str
  29. from funasr.utils.types import str_or_none
  30. def inference(
  31. output_dir: str,
  32. batch_size: int,
  33. dtype: str,
  34. ngpu: int,
  35. seed: int,
  36. num_workers: int,
  37. log_level: Union[int, str],
  38. train_config: Optional[str],
  39. model_file: Optional[str],
  40. log_base: Optional[float],
  41. key_file: Optional[str] = None,
  42. allow_variable_data_keys: bool = False,
  43. split_with_space: Optional[bool] = False,
  44. seg_dict_file: Optional[str] = None,
  45. data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
  46. raw_inputs: Union[List[Any], bytes, str] = None,
  47. **kwargs,
  48. ):
  49. inference_pipeline = inference_modelscope(
  50. output_dir=output_dir,
  51. raw_inputs=raw_inputs,
  52. batch_size=batch_size,
  53. dtype=dtype,
  54. ngpu=ngpu,
  55. seed=seed,
  56. num_workers=num_workers,
  57. log_level=log_level,
  58. key_file=key_file,
  59. train_config=train_config,
  60. model_file=model_file,
  61. log_base = log_base,
  62. allow_variable_data_keys = allow_variable_data_keys,
  63. split_with_space=split_with_space,
  64. seg_dict_file=seg_dict_file,
  65. **kwargs,
  66. )
  67. return inference_pipeline(data_path_and_name_and_type, raw_inputs)
  68. def inference_modelscope(
  69. batch_size: int,
  70. dtype: str,
  71. ngpu: int,
  72. seed: int,
  73. num_workers: int,
  74. log_level: Union[int, str],
  75. key_file: Optional[str],
  76. train_config: Optional[str],
  77. model_file: Optional[str],
  78. log_base: Optional[float] = 10,
  79. allow_variable_data_keys: bool = False,
  80. split_with_space: Optional[bool] = False,
  81. seg_dict_file: Optional[str] = None,
  82. output_dir: Optional[str] = None,
  83. param_dict: dict = None,
  84. **kwargs,
  85. ):
  86. assert check_argument_types()
  87. logging.basicConfig(
  88. level=log_level,
  89. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  90. )
  91. if ngpu >= 1 and torch.cuda.is_available():
  92. device = "cuda"
  93. else:
  94. device = "cpu"
  95. # 1. Set random-seed
  96. set_all_random_seed(seed)
  97. # 2. Build Model
  98. model, train_args = LMTask.build_model_from_file(
  99. train_config, model_file, device)
  100. wrapped_model = ForwardAdaptor(model, "nll")
  101. wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
  102. logging.info(f"Model:\n{model}")
  103. preprocessor = LMPreprocessor(
  104. train=False,
  105. token_type=train_args.token_type,
  106. token_list=train_args.token_list,
  107. bpemodel=train_args.bpemodel,
  108. text_cleaner=train_args.cleaner,
  109. g2p_type=train_args.g2p,
  110. text_name="text",
  111. non_linguistic_symbols=train_args.non_linguistic_symbols,
  112. split_with_space=split_with_space,
  113. seg_dict_file=seg_dict_file
  114. )
  115. def _forward(
  116. data_path_and_name_and_type,
  117. raw_inputs: Union[List[Any], bytes, str] = None,
  118. output_dir_v2: Optional[str] = None,
  119. param_dict: dict = None,
  120. ):
  121. results = []
  122. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  123. if output_path is not None:
  124. writer = DatadirWriter(output_path)
  125. else:
  126. writer = None
  127. if raw_inputs != None:
  128. line = raw_inputs.strip()
  129. key = "lm demo"
  130. if line=="":
  131. item = {'key': key, 'value': ""}
  132. results.append(item)
  133. return results
  134. batch = {}
  135. batch['text'] = line
  136. if preprocessor != None:
  137. batch = preprocessor(key, batch)
  138. # Force data-precision
  139. for name in batch:
  140. value = batch[name]
  141. if not isinstance(value, np.ndarray):
  142. raise RuntimeError(
  143. f"All values must be converted to np.ndarray object "
  144. f'by preprocessing, but "{name}" is still {type(value)}.'
  145. )
  146. # Cast to desired type
  147. if value.dtype.kind == "f":
  148. value = value.astype("float32")
  149. elif value.dtype.kind == "i":
  150. value = value.astype("long")
  151. else:
  152. raise NotImplementedError(f"Not supported dtype: {value.dtype}")
  153. batch[name] = value
  154. batch["text_lengths"] = torch.from_numpy(
  155. np.array([len(batch["text"])], dtype='int32'))
  156. batch["text"] = np.expand_dims(batch["text"], axis=0)
  157. with torch.no_grad():
  158. batch = to_device(batch, device)
  159. if ngpu <= 1:
  160. nll, lengths = wrapped_model(**batch)
  161. else:
  162. nll, lengths = data_parallel(
  163. wrapped_model, (), range(ngpu), module_kwargs=batch
  164. )
  165. ## compute ppl
  166. ppl_out_batch = ""
  167. ids2tokens = preprocessor.token_id_converter.ids2tokens
  168. for sent_ids, sent_nll in zip(batch['text'], nll):
  169. pre_word = "<s>"
  170. cur_word = None
  171. sent_lst = ids2tokens(sent_ids) + ['</s>']
  172. ppl_out = " ".join(sent_lst) + "\n"
  173. for word, word_nll in zip(sent_lst, sent_nll):
  174. cur_word = word
  175. word_nll = -word_nll.cpu()
  176. if log_base is None:
  177. word_prob = np.exp(word_nll)
  178. else:
  179. word_prob = log_base ** (word_nll / np.log(log_base))
  180. ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
  181. cur=cur_word,
  182. pre=pre_word,
  183. prob=round(word_prob.item(), 8),
  184. word_nll=round(word_nll.item(), 8)
  185. )
  186. pre_word = cur_word
  187. sent_nll_mean = sent_nll.mean().cpu().numpy()
  188. sent_nll_sum = sent_nll.sum().cpu().numpy()
  189. if log_base is None:
  190. sent_ppl = np.exp(sent_nll_mean)
  191. else:
  192. sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
  193. ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
  194. sent_nll=round(-sent_nll_sum.item(), 4),
  195. sent_ppl=round(sent_ppl.item(), 4)
  196. )
  197. ppl_out_batch += ppl_out
  198. item = {'key': key, 'value': ppl_out}
  199. if writer is not None:
  200. writer["ppl"][key+":\n"] = ppl_out
  201. results.append(item)
  202. return results
  203. # 3. Build data-iterator
  204. loader = LMTask.build_streaming_iterator(
  205. data_path_and_name_and_type,
  206. dtype=dtype,
  207. batch_size=batch_size,
  208. key_file=key_file,
  209. num_workers=num_workers,
  210. preprocess_fn=preprocessor,
  211. collate_fn=LMTask.build_collate_fn(train_args, False),
  212. allow_variable_data_keys=allow_variable_data_keys,
  213. inference=True,
  214. )
  215. # 4. Start for-loop
  216. total_nll = 0.0
  217. total_ntokens = 0
  218. ppl_out_all = ""
  219. for keys, batch in loader:
  220. assert isinstance(batch, dict), type(batch)
  221. assert all(isinstance(s, str) for s in keys), keys
  222. _bs = len(next(iter(batch.values())))
  223. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  224. ppl_out_batch = ""
  225. with torch.no_grad():
  226. batch = to_device(batch, device)
  227. if ngpu <= 1:
  228. # NOTE(kamo): data_parallel also should work with ngpu=1,
  229. # but for debuggability it's better to keep this block.
  230. nll, lengths = wrapped_model(**batch)
  231. else:
  232. nll, lengths = data_parallel(
  233. wrapped_model, (), range(ngpu), module_kwargs=batch
  234. )
  235. ## print ppl
  236. ids2tokens = preprocessor.token_id_converter.ids2tokens
  237. for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
  238. pre_word = "<s>"
  239. cur_word = None
  240. sent_lst = ids2tokens(sent_ids) + ['</s>']
  241. ppl_out = " ".join(sent_lst) + "\n"
  242. for word, word_nll in zip(sent_lst, sent_nll):
  243. cur_word = word
  244. word_nll = -word_nll.cpu()
  245. if log_base is None:
  246. word_prob = np.exp(word_nll)
  247. else:
  248. word_prob = log_base ** (word_nll / np.log(log_base))
  249. ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
  250. cur=cur_word,
  251. pre=pre_word,
  252. prob=round(word_prob.item(), 8),
  253. word_nll=round(word_nll.item(), 8)
  254. )
  255. pre_word = cur_word
  256. sent_nll_mean = sent_nll.mean().cpu().numpy()
  257. sent_nll_sum = sent_nll.sum().cpu().numpy()
  258. if log_base is None:
  259. sent_ppl = np.exp(sent_nll_mean)
  260. else:
  261. sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
  262. ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
  263. sent_nll=round(-sent_nll_sum.item(), 4),
  264. sent_ppl=round(sent_ppl.item(), 4)
  265. )
  266. ppl_out_batch += ppl_out
  267. utt2nll = round(-sent_nll_sum.item(), 5)
  268. item = {'key': key, 'value': ppl_out}
  269. if writer is not None:
  270. writer["ppl"][key+":\n"] = ppl_out
  271. writer["utt2nll"][key] = str(utt2nll)
  272. results.append(item)
  273. ppl_out_all += ppl_out_batch
  274. assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
  275. # nll: (B, L) -> (B,)
  276. nll = nll.detach().cpu().numpy().sum(1)
  277. # lengths: (B,)
  278. lengths = lengths.detach().cpu().numpy()
  279. total_nll += nll.sum()
  280. total_ntokens += lengths.sum()
  281. if log_base is None:
  282. ppl = np.exp(total_nll / total_ntokens)
  283. else:
  284. ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
  285. avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
  286. total_nll=round(-total_nll.item(), 4),
  287. total_ppl=round(ppl.item(), 4)
  288. )
  289. item = {'key': 'AVG PPL', 'value': avg_ppl}
  290. ppl_out_all += avg_ppl
  291. if writer is not None:
  292. writer["ppl"]["AVG PPL : "] = avg_ppl
  293. results.append(item)
  294. return results
  295. return _forward
  296. def get_parser():
  297. parser = config_argparse.ArgumentParser(
  298. description="Calc perplexity",
  299. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  300. )
  301. parser.add_argument(
  302. "--log_level",
  303. type=lambda x: x.upper(),
  304. default="INFO",
  305. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  306. help="The verbose level of logging",
  307. )
  308. parser.add_argument("--output_dir", type=str, required=False)
  309. parser.add_argument(
  310. "--ngpu",
  311. type=int,
  312. default=0,
  313. help="The number of gpus. 0 indicates CPU mode",
  314. )
  315. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  316. parser.add_argument(
  317. "--dtype",
  318. default="float32",
  319. choices=["float16", "float32", "float64"],
  320. help="Data type",
  321. )
  322. parser.add_argument(
  323. "--num_workers",
  324. type=int,
  325. default=1,
  326. help="The number of workers used for DataLoader",
  327. )
  328. parser.add_argument(
  329. "--batch_size",
  330. type=int,
  331. default=1,
  332. help="The batch size for inference",
  333. )
  334. parser.add_argument(
  335. "--log_base",
  336. type=float_or_none,
  337. default=10,
  338. help="The base of logarithm for Perplexity. "
  339. "If None, napier's constant is used.",
  340. required=False
  341. )
  342. group = parser.add_argument_group("Input data related")
  343. group.add_argument(
  344. "--data_path_and_name_and_type",
  345. type=str2triple_str,
  346. action="append",
  347. required=False
  348. )
  349. group.add_argument(
  350. "--raw_inputs",
  351. type=str,
  352. required=False
  353. )
  354. group.add_argument("--key_file", type=str_or_none)
  355. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  356. group.add_argument("--split_with_space", type=str2bool, default=False)
  357. group.add_argument("--seg_dict_file", type=str_or_none)
  358. group = parser.add_argument_group("The model configuration related")
  359. group.add_argument("--train_config", type=str)
  360. group.add_argument("--model_file", type=str)
  361. return parser
  362. def main(cmd=None):
  363. print(get_commandline_args(), file=sys.stderr)
  364. parser = get_parser()
  365. args = parser.parse_args(cmd)
  366. kwargs = vars(args)
  367. inference(**kwargs)
  368. if __name__ == "__main__":
  369. main()