lm_inference_launch.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. # -*- encoding: utf-8 -*-
  2. #!/usr/bin/env python3
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import argparse
  6. import logging
  7. import os
  8. import sys
  9. from typing import Union, Dict, Any
  10. from funasr.utils import config_argparse
  11. from funasr.utils.cli_utils import get_commandline_args
  12. from funasr.utils.types import str2bool
  13. from funasr.utils.types import str2triple_str
  14. from funasr.utils.types import str_or_none
  15. from funasr.utils.types import float_or_none
  16. import argparse
  17. import logging
  18. from pathlib import Path
  19. import sys
  20. import os
  21. from typing import Optional
  22. from typing import Sequence
  23. from typing import Tuple
  24. from typing import Union
  25. from typing import Dict
  26. from typing import Any
  27. from typing import List
  28. import numpy as np
  29. import torch
  30. from torch.nn.parallel import data_parallel
  31. from typeguard import check_argument_types
  32. from funasr.tasks.lm import LMTask
  33. from funasr.datasets.preprocessor import LMPreprocessor
  34. from funasr.utils.cli_utils import get_commandline_args
  35. from funasr.fileio.datadir_writer import DatadirWriter
  36. from funasr.torch_utils.device_funcs import to_device
  37. from funasr.torch_utils.forward_adaptor import ForwardAdaptor
  38. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  39. from funasr.utils import config_argparse
  40. from funasr.utils.types import float_or_none
  41. from funasr.utils.types import str2bool
  42. from funasr.utils.types import str2triple_str
  43. from funasr.utils.types import str_or_none
  44. def inference_lm(
  45. batch_size: int,
  46. dtype: str,
  47. ngpu: int,
  48. seed: int,
  49. num_workers: int,
  50. log_level: Union[int, str],
  51. key_file: Optional[str],
  52. train_config: Optional[str],
  53. model_file: Optional[str],
  54. log_base: Optional[float] = 10,
  55. allow_variable_data_keys: bool = False,
  56. split_with_space: Optional[bool] = False,
  57. seg_dict_file: Optional[str] = None,
  58. output_dir: Optional[str] = None,
  59. param_dict: dict = None,
  60. **kwargs,
  61. ):
  62. assert check_argument_types()
  63. ncpu = kwargs.get("ncpu", 1)
  64. torch.set_num_threads(ncpu)
  65. if ngpu >= 1 and torch.cuda.is_available():
  66. device = "cuda"
  67. else:
  68. device = "cpu"
  69. # 1. Set random-seed
  70. set_all_random_seed(seed)
  71. # 2. Build Model
  72. model, train_args = LMTask.build_model_from_file(
  73. train_config, model_file, device)
  74. wrapped_model = ForwardAdaptor(model, "nll")
  75. wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
  76. logging.info(f"Model:\n{model}")
  77. preprocessor = LMPreprocessor(
  78. train=False,
  79. token_type=train_args.token_type,
  80. token_list=train_args.token_list,
  81. bpemodel=train_args.bpemodel,
  82. text_cleaner=train_args.cleaner,
  83. g2p_type=train_args.g2p,
  84. text_name="text",
  85. non_linguistic_symbols=train_args.non_linguistic_symbols,
  86. split_with_space=split_with_space,
  87. seg_dict_file=seg_dict_file
  88. )
  89. def _forward(
  90. data_path_and_name_and_type,
  91. raw_inputs: Union[List[Any], bytes, str] = None,
  92. output_dir_v2: Optional[str] = None,
  93. param_dict: dict = None,
  94. ):
  95. results = []
  96. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  97. if output_path is not None:
  98. writer = DatadirWriter(output_path)
  99. else:
  100. writer = None
  101. if raw_inputs != None:
  102. line = raw_inputs.strip()
  103. key = "lm demo"
  104. if line == "":
  105. item = {'key': key, 'value': ""}
  106. results.append(item)
  107. return results
  108. batch = {}
  109. batch['text'] = line
  110. if preprocessor != None:
  111. batch = preprocessor(key, batch)
  112. # Force data-precision
  113. for name in batch:
  114. value = batch[name]
  115. if not isinstance(value, np.ndarray):
  116. raise RuntimeError(
  117. f"All values must be converted to np.ndarray object "
  118. f'by preprocessing, but "{name}" is still {type(value)}.'
  119. )
  120. # Cast to desired type
  121. if value.dtype.kind == "f":
  122. value = value.astype("float32")
  123. elif value.dtype.kind == "i":
  124. value = value.astype("long")
  125. else:
  126. raise NotImplementedError(f"Not supported dtype: {value.dtype}")
  127. batch[name] = value
  128. batch["text_lengths"] = torch.from_numpy(
  129. np.array([len(batch["text"])], dtype='int32'))
  130. batch["text"] = np.expand_dims(batch["text"], axis=0)
  131. with torch.no_grad():
  132. batch = to_device(batch, device)
  133. if ngpu <= 1:
  134. nll, lengths = wrapped_model(**batch)
  135. else:
  136. nll, lengths = data_parallel(
  137. wrapped_model, (), range(ngpu), module_kwargs=batch
  138. )
  139. ## compute ppl
  140. ppl_out_batch = ""
  141. ids2tokens = preprocessor.token_id_converter.ids2tokens
  142. for sent_ids, sent_nll in zip(batch['text'], nll):
  143. pre_word = "<s>"
  144. cur_word = None
  145. sent_lst = ids2tokens(sent_ids) + ['</s>']
  146. ppl_out = " ".join(sent_lst) + "\n"
  147. for word, word_nll in zip(sent_lst, sent_nll):
  148. cur_word = word
  149. word_nll = -word_nll.cpu()
  150. if log_base is None:
  151. word_prob = np.exp(word_nll)
  152. else:
  153. word_prob = log_base ** (word_nll / np.log(log_base))
  154. ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
  155. cur=cur_word,
  156. pre=pre_word,
  157. prob=round(word_prob.item(), 8),
  158. word_nll=round(word_nll.item(), 8)
  159. )
  160. pre_word = cur_word
  161. sent_nll_mean = sent_nll.mean().cpu().numpy()
  162. sent_nll_sum = sent_nll.sum().cpu().numpy()
  163. if log_base is None:
  164. sent_ppl = np.exp(sent_nll_mean)
  165. else:
  166. sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
  167. ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
  168. sent_nll=round(-sent_nll_sum.item(), 4),
  169. sent_ppl=round(sent_ppl.item(), 4)
  170. )
  171. ppl_out_batch += ppl_out
  172. item = {'key': key, 'value': ppl_out}
  173. if writer is not None:
  174. writer["ppl"][key + ":\n"] = ppl_out
  175. results.append(item)
  176. return results
  177. # 3. Build data-iterator
  178. loader = LMTask.build_streaming_iterator(
  179. data_path_and_name_and_type,
  180. dtype=dtype,
  181. batch_size=batch_size,
  182. key_file=key_file,
  183. num_workers=num_workers,
  184. preprocess_fn=preprocessor,
  185. collate_fn=LMTask.build_collate_fn(train_args, False),
  186. allow_variable_data_keys=allow_variable_data_keys,
  187. inference=True,
  188. )
  189. # 4. Start for-loop
  190. total_nll = 0.0
  191. total_ntokens = 0
  192. ppl_out_all = ""
  193. for keys, batch in loader:
  194. assert isinstance(batch, dict), type(batch)
  195. assert all(isinstance(s, str) for s in keys), keys
  196. _bs = len(next(iter(batch.values())))
  197. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  198. ppl_out_batch = ""
  199. with torch.no_grad():
  200. batch = to_device(batch, device)
  201. if ngpu <= 1:
  202. # NOTE(kamo): data_parallel also should work with ngpu=1,
  203. # but for debuggability it's better to keep this block.
  204. nll, lengths = wrapped_model(**batch)
  205. else:
  206. nll, lengths = data_parallel(
  207. wrapped_model, (), range(ngpu), module_kwargs=batch
  208. )
  209. ## print ppl
  210. ids2tokens = preprocessor.token_id_converter.ids2tokens
  211. for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
  212. pre_word = "<s>"
  213. cur_word = None
  214. sent_lst = ids2tokens(sent_ids) + ['</s>']
  215. ppl_out = " ".join(sent_lst) + "\n"
  216. for word, word_nll in zip(sent_lst, sent_nll):
  217. cur_word = word
  218. word_nll = -word_nll.cpu()
  219. if log_base is None:
  220. word_prob = np.exp(word_nll)
  221. else:
  222. word_prob = log_base ** (word_nll / np.log(log_base))
  223. ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
  224. cur=cur_word,
  225. pre=pre_word,
  226. prob=round(word_prob.item(), 8),
  227. word_nll=round(word_nll.item(), 8)
  228. )
  229. pre_word = cur_word
  230. sent_nll_mean = sent_nll.mean().cpu().numpy()
  231. sent_nll_sum = sent_nll.sum().cpu().numpy()
  232. if log_base is None:
  233. sent_ppl = np.exp(sent_nll_mean)
  234. else:
  235. sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
  236. ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
  237. sent_nll=round(-sent_nll_sum.item(), 4),
  238. sent_ppl=round(sent_ppl.item(), 4)
  239. )
  240. ppl_out_batch += ppl_out
  241. utt2nll = round(-sent_nll_sum.item(), 5)
  242. item = {'key': key, 'value': ppl_out}
  243. if writer is not None:
  244. writer["ppl"][key + ":\n"] = ppl_out
  245. writer["utt2nll"][key] = str(utt2nll)
  246. results.append(item)
  247. ppl_out_all += ppl_out_batch
  248. assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
  249. # nll: (B, L) -> (B,)
  250. nll = nll.detach().cpu().numpy().sum(1)
  251. # lengths: (B,)
  252. lengths = lengths.detach().cpu().numpy()
  253. total_nll += nll.sum()
  254. total_ntokens += lengths.sum()
  255. if log_base is None:
  256. ppl = np.exp(total_nll / total_ntokens)
  257. else:
  258. ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
  259. avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
  260. total_nll=round(-total_nll.item(), 4),
  261. total_ppl=round(ppl.item(), 4)
  262. )
  263. item = {'key': 'AVG PPL', 'value': avg_ppl}
  264. ppl_out_all += avg_ppl
  265. if writer is not None:
  266. writer["ppl"]["AVG PPL : "] = avg_ppl
  267. results.append(item)
  268. return results
  269. return _forward
  270. def inference_launch(mode, **kwargs):
  271. if mode == "transformer":
  272. return inference_lm(**kwargs)
  273. else:
  274. logging.info("Unknown decoding mode: {}".format(mode))
  275. return None
  276. def get_parser():
  277. parser = config_argparse.ArgumentParser(
  278. description="Calc perplexity",
  279. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  280. )
  281. parser.add_argument(
  282. "--log_level",
  283. type=lambda x: x.upper(),
  284. default="INFO",
  285. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  286. help="The verbose level of logging",
  287. )
  288. parser.add_argument("--output_dir", type=str, required=True)
  289. parser.add_argument("--gpuid_list", type=str, required=True)
  290. parser.add_argument(
  291. "--ngpu",
  292. type=int,
  293. default=0,
  294. help="The number of gpus. 0 indicates CPU mode",
  295. )
  296. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  297. parser.add_argument("--njob", type=int, default=1, help="Random seed")
  298. parser.add_argument(
  299. "--dtype",
  300. default="float32",
  301. choices=["float16", "float32", "float64"],
  302. help="Data type",
  303. )
  304. parser.add_argument(
  305. "--num_workers",
  306. type=int,
  307. default=1,
  308. help="The number of workers used for DataLoader",
  309. )
  310. parser.add_argument(
  311. "--batch_size",
  312. type=int,
  313. default=1,
  314. help="The batch size for inference",
  315. )
  316. parser.add_argument(
  317. "--log_base",
  318. type=float_or_none,
  319. default=10,
  320. help="The base of logarithm for Perplexity. "
  321. "If None, napier's constant is used.",
  322. required=False
  323. )
  324. group = parser.add_argument_group("Input data related")
  325. group.add_argument(
  326. "--data_path_and_name_and_type",
  327. type=str2triple_str,
  328. action="append",
  329. required=False
  330. )
  331. group.add_argument(
  332. "--raw_inputs",
  333. type=str,
  334. required=False
  335. )
  336. group.add_argument("--key_file", type=str_or_none)
  337. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  338. group.add_argument("--split_with_space", type=str2bool, default=False)
  339. group.add_argument("--seg_dict_file", type=str_or_none)
  340. group = parser.add_argument_group("The model configuration related")
  341. group.add_argument("--train_config", type=str)
  342. group.add_argument("--model_file", type=str)
  343. group.add_argument("--mode", type=str, default="lm")
  344. return parser
  345. def main(cmd=None):
  346. print(get_commandline_args(), file=sys.stderr)
  347. parser = get_parser()
  348. args = parser.parse_args(cmd)
  349. kwargs = vars(args)
  350. kwargs.pop("config", None)
  351. # set logging messages
  352. logging.basicConfig(
  353. level=args.log_level,
  354. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  355. )
  356. logging.info("Decoding args: {}".format(kwargs))
  357. # gpu setting
  358. if args.ngpu > 0:
  359. jobid = int(args.output_dir.split(".")[-1])
  360. gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
  361. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  362. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  363. kwargs.pop("gpuid_list", None)
  364. kwargs.pop("njob", None)
  365. results = inference_launch(**kwargs)
  366. if __name__ == "__main__":
  367. main()