# lm_calc_perplexity.py
  1. #!/usr/bin/env python3
  2. import argparse
  3. import logging
  4. from pathlib import Path
  5. import sys
  6. from typing import Optional
  7. from typing import Sequence
  8. from typing import Tuple
  9. from typing import Union
  10. import numpy as np
  11. import torch
  12. from torch.nn.parallel import data_parallel
  13. from typeguard import check_argument_types
  14. from funasr.utils.cli_utils import get_commandline_args
  15. from funasr.fileio.datadir_writer import DatadirWriter
  16. from funasr.tasks.lm import LMTask
  17. from funasr.torch_utils.device_funcs import to_device
  18. from funasr.torch_utils.forward_adaptor import ForwardAdaptor
  19. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  20. from funasr.utils import config_argparse
  21. from funasr.utils.types import float_or_none
  22. from funasr.utils.types import str2bool
  23. from funasr.utils.types import str2triple_str
  24. from funasr.utils.types import str_or_none
def calc_perplexity(
    output_dir: str,
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    log_base: Optional[float],
    allow_variable_data_keys: bool,
):
    """Compute per-utterance and corpus-level perplexity of a trained LM.

    Loads a language model from ``train_config``/``model_file``, streams the
    data given by ``data_path_and_name_and_type`` through ``model.nll()``, and
    writes results under ``output_dir``:

    - ``utt2nll`` / ``utt2ppl`` / ``utt2ntokens``: per-utterance values
      (NOTE(review): ``utt2nll`` stores the *negated* summed NLL, i.e. the
      log-likelihood — confirm this sign convention against consumers).
    - ``ppl``: corpus-level perplexity over all tokens.
    - ``base``: the logarithm base used (``e`` when ``log_base`` is None).

    Args:
        output_dir: Directory for the DatadirWriter output files.
        batch_size: Batch size for inference.
        dtype: Model/compute dtype name ("float16"/"float32"/"float64").
        ngpu: Number of GPUs; 0 means CPU, >1 uses torch data_parallel.
        seed: Random seed applied before model construction.
        num_workers: DataLoader worker count.
        log_level: Logging verbosity (name or numeric level).
        data_path_and_name_and_type: Triples describing the input data files.
        key_file: Optional key file restricting/ordering the utterances.
        train_config: Path to the LM training config.
        model_file: Path to the trained model parameters.
        log_base: Base for the perplexity logarithm; None uses e.
        allow_variable_data_keys: Passed through to the streaming iterator.
    """
    assert check_argument_types()
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build LM
    model, train_args = LMTask.build_model_from_file(config_file=train_config, model_file=model_file, device=device)
    # Wrap model so that model.nll() (not forward) is what data_parallel calls
    wrapped_model = ForwardAdaptor(model, "nll")
    wrapped_model.to(dtype=getattr(torch, dtype)).eval()
    logging.info(f"Model:\n{model}")
    # 3. Build data-iterator
    loader = LMTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=LMTask.build_preprocess_fn(train_args, False),
        collate_fn=LMTask.build_collate_fn(train_args, False),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )
    # 4. Start for-loop
    with DatadirWriter(output_dir) as writer:
        total_nll = 0.0
        total_ntokens = 0
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            # Batch size inferred from the first tensor in the batch dict
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            with torch.no_grad():
                batch = to_device(batch, device)
                if ngpu <= 1:
                    # NOTE(kamo): data_parallel also should work with ngpu=1,
                    # but for debuggability it's better to keep this block.
                    nll, lengths = wrapped_model(**batch)
                else:
                    nll, lengths = data_parallel(
                        wrapped_model, (), range(ngpu), module_kwargs=batch
                    )
            assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
            # nll: (B, L) -> (B,): sum token-level NLL over the sequence axis
            nll = nll.detach().cpu().numpy().sum(1)
            # lengths: (B,)
            lengths = lengths.detach().cpu().numpy()
            total_nll += nll.sum()
            total_ntokens += lengths.sum()
            for key, _nll, ntoken in zip(keys, nll, lengths):
                if log_base is None:
                    # Natural base: ppl = exp(mean NLL per token)
                    utt_ppl = np.exp(_nll / ntoken)
                else:
                    # Change of base: base**(mean NLL / ln(base)) == exp(mean NLL)
                    utt_ppl = log_base ** (_nll / ntoken / np.log(log_base))
                # Write PPL of each utts for debugging or analysis
                writer["utt2nll"][key] = str(-_nll)
                writer["utt2ppl"][key] = str(utt_ppl)
                writer["utt2ntokens"][key] = str(ntoken)
        if log_base is None:
            ppl = np.exp(total_nll / total_ntokens)
        else:
            ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
        with (Path(output_dir) / "ppl").open("w", encoding="utf-8") as f:
            f.write(f"{ppl}\n")
        with (Path(output_dir) / "base").open("w", encoding="utf-8") as f:
            if log_base is None:
                _log_base = np.e
            else:
                _log_base = log_base
            f.write(f"{_log_base}\n")
        logging.info(f"PPL={ppl}")
  117. def get_parser():
  118. parser = config_argparse.ArgumentParser(
  119. description="Calc perplexity",
  120. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  121. )
  122. # Note(kamo): Use '_' instead of '-' as separator.
  123. # '-' is confusing if written in yaml.
  124. parser.add_argument(
  125. "--log_level",
  126. type=lambda x: x.upper(),
  127. default="INFO",
  128. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  129. help="The verbose level of logging",
  130. )
  131. parser.add_argument("--output_dir", type=str, required=True)
  132. parser.add_argument(
  133. "--ngpu",
  134. type=int,
  135. default=0,
  136. help="The number of gpus. 0 indicates CPU mode",
  137. )
  138. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  139. parser.add_argument(
  140. "--dtype",
  141. default="float32",
  142. choices=["float16", "float32", "float64"],
  143. help="Data type",
  144. )
  145. parser.add_argument(
  146. "--num_workers",
  147. type=int,
  148. default=1,
  149. help="The number of workers used for DataLoader",
  150. )
  151. parser.add_argument(
  152. "--batch_size",
  153. type=int,
  154. default=1,
  155. help="The batch size for inference",
  156. )
  157. parser.add_argument(
  158. "--log_base",
  159. type=float_or_none,
  160. default=None,
  161. help="The base of logarithm for Perplexity. "
  162. "If None, napier's constant is used.",
  163. )
  164. group = parser.add_argument_group("Input data related")
  165. group.add_argument(
  166. "--data_path_and_name_and_type",
  167. type=str2triple_str,
  168. required=True,
  169. action="append",
  170. )
  171. group.add_argument("--key_file", type=str_or_none)
  172. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  173. group = parser.add_argument_group("The model configuration related")
  174. group.add_argument("--train_config", type=str)
  175. group.add_argument("--model_file", type=str)
  176. return parser
  177. def main(cmd=None):
  178. print(get_commandline_args(), file=sys.stderr)
  179. parser = get_parser()
  180. args = parser.parse_args(cmd)
  181. kwargs = vars(args)
  182. kwargs.pop("config", None)
  183. calc_perplexity(**kwargs)
  184. if __name__ == "__main__":
  185. main()