tokenize_text.py

#!/usr/bin/env python3
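"""Tokenize text and optionally build a vocabulary (token list).

The script reads text from a file or stdin, tokenizes each line with a
char/bpe/word/phn tokenizer, and writes either the tokenized lines or, in
--write_vocabulary mode, the list of tokens sorted by descending frequency.

Illustrative invocations (the input/output paths are examples only):

    python tokenize_text.py -i data/train/text -o tokens.txt -f 2- -t char
    python tokenize_text.py -i data/train/text -o vocab.txt -f 2- -t char \
        --write_vocabulary true --add_symbol '<blank>:0' --add_symbol '<unk>:1'
"""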

import argparse
import logging
import sys
from collections import Counter
from pathlib import Path
from typing import List
from typing import Optional

from funasr.utils.cli_utils import get_commandline_args
from funasr.tokenizer.build_tokenizer import build_tokenizer
from funasr.tokenizer.cleaner import TextCleaner
from funasr.tokenizer.phoneme_tokenizer import g2p_classes
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none


def field2slice(field: Optional[str]) -> slice:
    """Convert a field string to a slice.

    Note that the field string accepts 1-based integers.

    Examples:
        >>> field2slice("1-")
        slice(0, None, None)
        >>> field2slice("1-3")
        slice(0, 3, None)
        >>> field2slice("-3")
        slice(None, 3, None)
    """
    field = field.strip()
    try:
        if "-" in field:
            # e.g. "2-" or "2-5" or "-7"
            s1, s2 = field.split("-", maxsplit=1)
            if s1.strip() == "":
                s1 = None
            else:
                s1 = int(s1)
                if s1 == 0:
                    raise ValueError("1-based string")
            if s2.strip() == "":
                s2 = None
            else:
                s2 = int(s2)
        else:
            # e.g. "2"
            s1 = int(field)
            s2 = s1 + 1
            if s1 == 0:
                raise ValueError("must be 1 or more value")
    except ValueError:
        raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}")

    if s1 is None:
        slic = slice(None, s2)
    else:
        # -1 because the field uses 1-based integers, following the "cut" command
        # e.g. "1-3" -> slice(0, 3)
        slic = slice(s1 - 1, s2)
    return slic
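
# Illustrative use of field2slice on a Kaldi-style "uttid text..." line
# (the same pattern is used inside tokenize() below):
#   >>> "uttidA hello world".split(None)[field2slice("2-")]
#   ['hello', 'world']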


def tokenize(
    input: str,
    output: str,
    field: Optional[str],
    delimiter: Optional[str],
    token_type: str,
    space_symbol: str,
    non_linguistic_symbols: Optional[str],
    bpemodel: Optional[str],
    log_level: str,
    write_vocabulary: bool,
    vocabulary_size: int,
    remove_non_linguistic_symbols: bool,
    cutoff: int,
    add_symbol: List[str],
    cleaner: Optional[str],
    g2p: Optional[str],
):
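    """Tokenize the input text line by line.

    Reads from input ("-" means stdin) and writes to output ("-" means
    stdout). If field is given (1-based, e.g. "2-"), only those columns of
    each line are tokenized. When write_vocabulary is True, a list of tokens
    sorted by descending frequency is written instead, one token per line.
    """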
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if input == "-":
        fin = sys.stdin
    else:
        fin = Path(input).open("r", encoding="utf-8")
    if output == "-":
        fout = sys.stdout
    else:
        p = Path(output)
        p.parent.mkdir(parents=True, exist_ok=True)
        fout = p.open("w", encoding="utf-8")

    cleaner = TextCleaner(cleaner)
    tokenizer = build_tokenizer(
        token_type=token_type,
        bpemodel=bpemodel,
        delimiter=delimiter,
        space_symbol=space_symbol,
        non_linguistic_symbols=non_linguistic_symbols,
        remove_non_linguistic_symbols=remove_non_linguistic_symbols,
        g2p_type=g2p,
    )

    counter = Counter()
    if field is not None:
        field = field2slice(field)
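
    # Tokenize line by line; in write_vocabulary mode, count token
    # frequencies instead of writing the tokenized lines.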
    for line in fin:
        line = line.rstrip()
        if field is not None:
            # e.g. field="2-"
            # uttidA hello world!! -> hello world!!
            tokens = line.split(delimiter)
            tokens = tokens[field]
            if delimiter is None:
                line = " ".join(tokens)
            else:
                line = delimiter.join(tokens)

        line = cleaner(line)
        tokens = tokenizer.text2tokens(line)
        if not write_vocabulary:
            fout.write(" ".join(tokens) + "\n")
        else:
            for t in tokens:
                counter[t] += 1

    if not write_vocabulary:
        return

    # FIXME:
    # Remove symbols given via --add_symbol from the counter so that they are
    # not duplicated in the output vocabulary.
    for symbol_and_id in add_symbol:
        # e.g. symbol="<blank>:0"
        try:
            symbol, idx = symbol_and_id.split(":")
        except ValueError:
            raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
        symbol = symbol.strip()
        if symbol in counter:
            del counter[symbol]

    # ======= write_vocabulary mode from here =======
    # Sort by the number of occurrences in descending order
    # and filter out words whose frequency is not above the cutoff value
    words_and_counts = list(
        filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1]))
    )

    # Restrict the vocabulary size
    if vocabulary_size > 0:
        if vocabulary_size < len(add_symbol):
            raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}")
        words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)]

    # Parse the values of --add_symbol
    for symbol_and_id in add_symbol:
        # e.g. symbol="<blank>:0"
        try:
            symbol, idx = symbol_and_id.split(":")
            idx = int(idx)
        except ValueError:
            raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
        symbol = symbol.strip()

        # e.g. idx=0 -> append as the first symbol
        # e.g. idx=-1 -> append as the last symbol
        if idx < 0:
            # +1 because the symbol itself becomes part of the final list,
            # so idx=-1 maps to the last position after insertion
            idx = len(words_and_counts) + 1 + idx
        words_and_counts.insert(idx, (symbol, None))

    # Write words
    for w, c in words_and_counts:
        fout.write(w + "\n")

    # Logging
    total_count = sum(counter.values())
    invocab_count = sum(c for w, c in words_and_counts if c is not None)
    logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %")


def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Tokenize texts",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument(
        "--input", "-i", required=True, help="Input text. - indicates sys.stdin"
    )
    parser.add_argument(
        "--output", "-o", required=True, help="Output text. - indicates sys.stdout"
    )
    parser.add_argument(
        "--field",
        "-f",
        help="The target columns of the input text as 1-based integer. e.g. 2-",
    )
    parser.add_argument(
        "--token_type",
        "-t",
        default="char",
        choices=["char", "bpe", "word", "phn"],
        help="Token type",
    )
    parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
    parser.add_argument("--space_symbol", default="<space>", help="The space symbol")
    parser.add_argument("--bpemodel", default=None, help="The bpemodel file path")
    parser.add_argument(
        "--non_linguistic_symbols",
        type=str_or_none,
        help="non_linguistic_symbols file path",
    )
    parser.add_argument(
        "--remove_non_linguistic_symbols",
        type=str2bool,
        default=False,
        help="Remove non-language-symbols from tokens",
    )
    parser.add_argument(
        "--cleaner",
        type=str_or_none,
        choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"],
        default=None,
        help="Apply text cleaning",
    )
    parser.add_argument(
        "--g2p",
        type=str_or_none,
        choices=g2p_classes,
        default=None,
        help="Specify g2p method if --token_type=phn",
    )

    group = parser.add_argument_group("write_vocabulary mode related")
    group.add_argument(
        "--write_vocabulary",
        type=str2bool,
        default=False,
        help="Write tokens list instead of tokenized text per line",
    )
    group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size")
    group.add_argument(
        "--cutoff",
        default=0,
        type=int,
        help="cut-off frequency used for write-vocabulary mode",
    )
    group.add_argument(
        "--add_symbol",
        type=str,
        default=[],
        action="append",
        help="Append symbol e.g. --add_symbol '<blank>:0' --add_symbol '<unk>:1'",
    )
    return parser


def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    tokenize(**kwargs)


if __name__ == "__main__":
    main()