# tokenize_text.py
#!/usr/bin/env python3
import argparse
import logging
import sys
from collections import Counter
from contextlib import ExitStack
from pathlib import Path
from typing import List
from typing import Optional

from typeguard import check_argument_types

from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
  16. def field2slice(field: Optional[str]) -> slice:
  17. """Convert field string to slice
  18. Note that field string accepts 1-based integer.
  19. Examples:
  20. >>> field2slice("1-")
  21. slice(0, None, None)
  22. >>> field2slice("1-3")
  23. slice(0, 3, None)
  24. >>> field2slice("-3")
  25. slice(None, 3, None)
  26. """
  27. field = field.strip()
  28. try:
  29. if "-" in field:
  30. # e.g. "2-" or "2-5" or "-7"
  31. s1, s2 = field.split("-", maxsplit=1)
  32. if s1.strip() == "":
  33. s1 = None
  34. else:
  35. s1 = int(s1)
  36. if s1 == 0:
  37. raise ValueError("1-based string")
  38. if s2.strip() == "":
  39. s2 = None
  40. else:
  41. s2 = int(s2)
  42. else:
  43. # e.g. "2"
  44. s1 = int(field)
  45. s2 = s1 + 1
  46. if s1 == 0:
  47. raise ValueError("must be 1 or more value")
  48. except ValueError:
  49. raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}")
  50. if s1 is None:
  51. slic = slice(None, s2)
  52. else:
  53. # -1 because of 1-based integer following "cut" command
  54. # e.g "1-3" -> slice(0, 3)
  55. slic = slice(s1 - 1, s2)
  56. return slic
  57. def tokenize(
  58. input: str,
  59. output: str,
  60. field: Optional[str],
  61. delimiter: Optional[str],
  62. token_type: str,
  63. space_symbol: str,
  64. non_linguistic_symbols: Optional[str],
  65. bpemodel: Optional[str],
  66. log_level: str,
  67. write_vocabulary: bool,
  68. vocabulary_size: int,
  69. remove_non_linguistic_symbols: bool,
  70. cutoff: int,
  71. add_symbol: List[str],
  72. cleaner: Optional[str],
  73. g2p: Optional[str],
  74. ):
  75. assert check_argument_types()
  76. logging.basicConfig(
  77. level=log_level,
  78. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  79. )
  80. if input == "-":
  81. fin = sys.stdin
  82. else:
  83. fin = Path(input).open("r", encoding="utf-8")
  84. if output == "-":
  85. fout = sys.stdout
  86. else:
  87. p = Path(output)
  88. p.parent.mkdir(parents=True, exist_ok=True)
  89. fout = p.open("w", encoding="utf-8")
  90. cleaner = TextCleaner(cleaner)
  91. tokenizer = build_tokenizer(
  92. token_type=token_type,
  93. bpemodel=bpemodel,
  94. delimiter=delimiter,
  95. space_symbol=space_symbol,
  96. non_linguistic_symbols=non_linguistic_symbols,
  97. remove_non_linguistic_symbols=remove_non_linguistic_symbols,
  98. g2p_type=g2p,
  99. )
  100. counter = Counter()
  101. if field is not None:
  102. field = field2slice(field)
  103. for line in fin:
  104. line = line.rstrip()
  105. if field is not None:
  106. # e.g. field="2-"
  107. # uttidA hello world!! -> hello world!!
  108. tokens = line.split(delimiter)
  109. tokens = tokens[field]
  110. if delimiter is None:
  111. line = " ".join(tokens)
  112. else:
  113. line = delimiter.join(tokens)
  114. line = cleaner(line)
  115. tokens = tokenizer.text2tokens(line)
  116. if not write_vocabulary:
  117. fout.write(" ".join(tokens) + "\n")
  118. else:
  119. for t in tokens:
  120. counter[t] += 1
  121. if not write_vocabulary:
  122. return
  123. ## FIXME
  124. ## del duplicate add_symbols in counter
  125. for symbol_and_id in add_symbol:
  126. # e.g symbol="<blank>:0"
  127. try:
  128. symbol, idx = symbol_and_id.split(":")
  129. except ValueError:
  130. raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
  131. symbol = symbol.strip()
  132. if symbol in counter:
  133. del counter[symbol]
  134. # ======= write_vocabulary mode from here =======
  135. # Sort by the number of occurrences in descending order
  136. # and filter lower frequency words than cutoff value
  137. words_and_counts = list(
  138. filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1]))
  139. )
  140. # Restrict the vocabulary size
  141. if vocabulary_size > 0:
  142. if vocabulary_size < len(add_symbol):
  143. raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}")
  144. words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)]
  145. # Parse the values of --add_symbol
  146. for symbol_and_id in add_symbol:
  147. # e.g symbol="<blank>:0"
  148. try:
  149. symbol, idx = symbol_and_id.split(":")
  150. idx = int(idx)
  151. except ValueError:
  152. raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
  153. symbol = symbol.strip()
  154. # e.g. idx=0 -> append as the first symbol
  155. # e.g. idx=-1 -> append as the last symbol
  156. if idx < 0:
  157. idx = len(words_and_counts) + 1 + idx
  158. words_and_counts.insert(idx, (symbol, None))
  159. # Write words
  160. for w, c in words_and_counts:
  161. fout.write(w + "\n")
  162. # Logging
  163. total_count = sum(counter.values())
  164. invocab_count = sum(c for w, c in words_and_counts if c is not None)
  165. logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %")
  166. def get_parser() -> argparse.ArgumentParser:
  167. parser = argparse.ArgumentParser(
  168. description="Tokenize texts",
  169. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  170. )
  171. parser.add_argument(
  172. "--log_level",
  173. type=lambda x: x.upper(),
  174. default="INFO",
  175. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  176. help="The verbose level of logging",
  177. )
  178. parser.add_argument(
  179. "--input", "-i", required=True, help="Input text. - indicates sys.stdin"
  180. )
  181. parser.add_argument(
  182. "--output", "-o", required=True, help="Output text. - indicates sys.stdout"
  183. )
  184. parser.add_argument(
  185. "--field",
  186. "-f",
  187. help="The target columns of the input text as 1-based integer. e.g 2-",
  188. )
  189. parser.add_argument(
  190. "--token_type",
  191. "-t",
  192. default="char",
  193. choices=["char", "bpe", "word", "phn"],
  194. help="Token type",
  195. )
  196. parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
  197. parser.add_argument("--space_symbol", default="<space>", help="The space symbol")
  198. parser.add_argument("--bpemodel", default=None, help="The bpemodel file path")
  199. parser.add_argument(
  200. "--non_linguistic_symbols",
  201. type=str_or_none,
  202. help="non_linguistic_symbols file path",
  203. )
  204. parser.add_argument(
  205. "--remove_non_linguistic_symbols",
  206. type=str2bool,
  207. default=False,
  208. help="Remove non-language-symbols from tokens",
  209. )
  210. parser.add_argument(
  211. "--cleaner",
  212. type=str_or_none,
  213. choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"],
  214. default=None,
  215. help="Apply text cleaning",
  216. )
  217. parser.add_argument(
  218. "--g2p",
  219. type=str_or_none,
  220. choices=g2p_choices,
  221. default=None,
  222. help="Specify g2p method if --token_type=phn",
  223. )
  224. group = parser.add_argument_group("write_vocabulary mode related")
  225. group.add_argument(
  226. "--write_vocabulary",
  227. type=str2bool,
  228. default=False,
  229. help="Write tokens list instead of tokenized text per line",
  230. )
  231. group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size")
  232. group.add_argument(
  233. "--cutoff",
  234. default=0,
  235. type=int,
  236. help="cut-off frequency used for write-vocabulary mode",
  237. )
  238. group.add_argument(
  239. "--add_symbol",
  240. type=str,
  241. default=[],
  242. action="append",
  243. help="Append symbol e.g. --add_symbol '<blank>:0' --add_symbol '<unk>:1'",
  244. )
  245. return parser
  246. def main(cmd=None):
  247. print(get_commandline_args(), file=sys.stderr)
  248. parser = get_parser()
  249. args = parser.parse_args(cmd)
  250. kwargs = vars(args)
  251. tokenize(**kwargs)
  252. if __name__ == "__main__":
  253. main()