# tokenizer.py
import os
from dataclasses import dataclass
from functools import cached_property, lru_cache
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

try:
    from transformers import GPT2TokenizerFast
except ImportError:
    raise ImportError(
        "transformers was not installed. Please install transformers first."
    )
  13. LANGUAGES = {
  14. "en": "english",
  15. "zh": "chinese",
  16. "de": "german",
  17. "es": "spanish",
  18. "ru": "russian",
  19. "ko": "korean",
  20. "fr": "french",
  21. "ja": "japanese",
  22. "pt": "portuguese",
  23. "tr": "turkish",
  24. "pl": "polish",
  25. "ca": "catalan",
  26. "nl": "dutch",
  27. "ar": "arabic",
  28. "sv": "swedish",
  29. "it": "italian",
  30. "id": "indonesian",
  31. "hi": "hindi",
  32. "fi": "finnish",
  33. "vi": "vietnamese",
  34. "he": "hebrew",
  35. "uk": "ukrainian",
  36. "el": "greek",
  37. "ms": "malay",
  38. "cs": "czech",
  39. "ro": "romanian",
  40. "da": "danish",
  41. "hu": "hungarian",
  42. "ta": "tamil",
  43. "no": "norwegian",
  44. "th": "thai",
  45. "ur": "urdu",
  46. "hr": "croatian",
  47. "bg": "bulgarian",
  48. "lt": "lithuanian",
  49. "la": "latin",
  50. "mi": "maori",
  51. "ml": "malayalam",
  52. "cy": "welsh",
  53. "sk": "slovak",
  54. "te": "telugu",
  55. "fa": "persian",
  56. "lv": "latvian",
  57. "bn": "bengali",
  58. "sr": "serbian",
  59. "az": "azerbaijani",
  60. "sl": "slovenian",
  61. "kn": "kannada",
  62. "et": "estonian",
  63. "mk": "macedonian",
  64. "br": "breton",
  65. "eu": "basque",
  66. "is": "icelandic",
  67. "hy": "armenian",
  68. "ne": "nepali",
  69. "mn": "mongolian",
  70. "bs": "bosnian",
  71. "kk": "kazakh",
  72. "sq": "albanian",
  73. "sw": "swahili",
  74. "gl": "galician",
  75. "mr": "marathi",
  76. "pa": "punjabi",
  77. "si": "sinhala",
  78. "km": "khmer",
  79. "sn": "shona",
  80. "yo": "yoruba",
  81. "so": "somali",
  82. "af": "afrikaans",
  83. "oc": "occitan",
  84. "ka": "georgian",
  85. "be": "belarusian",
  86. "tg": "tajik",
  87. "sd": "sindhi",
  88. "gu": "gujarati",
  89. "am": "amharic",
  90. "yi": "yiddish",
  91. "lo": "lao",
  92. "uz": "uzbek",
  93. "fo": "faroese",
  94. "ht": "haitian creole",
  95. "ps": "pashto",
  96. "tk": "turkmen",
  97. "nn": "nynorsk",
  98. "mt": "maltese",
  99. "sa": "sanskrit",
  100. "lb": "luxembourgish",
  101. "my": "myanmar",
  102. "bo": "tibetan",
  103. "tl": "tagalog",
  104. "mg": "malagasy",
  105. "as": "assamese",
  106. "tt": "tatar",
  107. "haw": "hawaiian",
  108. "ln": "lingala",
  109. "ha": "hausa",
  110. "ba": "bashkir",
  111. "jw": "javanese",
  112. "su": "sundanese",
  113. }
  114. # language code lookup by name, with a few language aliases
  115. TO_LANGUAGE_CODE = {
  116. **{language: code for code, language in LANGUAGES.items()},
  117. "burmese": "my",
  118. "valencian": "ca",
  119. "flemish": "nl",
  120. "haitian": "ht",
  121. "letzeburgesch": "lb",
  122. "pushto": "ps",
  123. "panjabi": "pa",
  124. "moldavian": "ro",
  125. "moldovan": "ro",
  126. "sinhalese": "si",
  127. "castilian": "es",
  128. }
  129. @dataclass(frozen=True)
  130. class Tokenizer:
  131. """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
  132. tokenizer: "GPT2TokenizerFast"
  133. language: Optional[str]
  134. sot_sequence: Tuple[int]
  135. def encode(self, text, **kwargs):
  136. return self.tokenizer.encode(text, **kwargs)
  137. def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
  138. return self.tokenizer.decode(token_ids, **kwargs)
  139. def decode_with_timestamps(self, tokens) -> str:
  140. """
  141. Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
  142. This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
  143. """
  144. outputs = [[]]
  145. for token in tokens:
  146. if token >= self.timestamp_begin:
  147. timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
  148. outputs.append(timestamp)
  149. outputs.append([])
  150. else:
  151. outputs[-1].append(token)
  152. outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
  153. return "".join(outputs)
  154. @property
  155. @lru_cache()
  156. def eot(self) -> int:
  157. return self.tokenizer.eos_token_id
  158. @property
  159. @lru_cache()
  160. def sot(self) -> int:
  161. return self._get_single_token_id("<|startoftranscript|>")
  162. @property
  163. @lru_cache()
  164. def sot_lm(self) -> int:
  165. return self._get_single_token_id("<|startoflm|>")
  166. @property
  167. @lru_cache()
  168. def sot_prev(self) -> int:
  169. return self._get_single_token_id("<|startofprev|>")
  170. @property
  171. @lru_cache()
  172. def no_speech(self) -> int:
  173. return self._get_single_token_id("<|nospeech|>")
  174. @property
  175. @lru_cache()
  176. def no_timestamps(self) -> int:
  177. return self._get_single_token_id("<|notimestamps|>")
  178. @property
  179. @lru_cache()
  180. def timestamp_begin(self) -> int:
  181. return self.tokenizer.all_special_ids[-1] + 1
  182. @property
  183. @lru_cache()
  184. def language_token(self) -> int:
  185. """Returns the token id corresponding to the value of the `language` field"""
  186. if self.language is None:
  187. raise ValueError(f"This tokenizer does not have language token configured")
  188. additional_tokens = dict(
  189. zip(
  190. self.tokenizer.additional_special_tokens,
  191. self.tokenizer.additional_special_tokens_ids,
  192. )
  193. )
  194. candidate = f"<|{self.language}|>"
  195. if candidate in additional_tokens:
  196. return additional_tokens[candidate]
  197. raise KeyError(f"Language {self.language} not found in tokenizer.")
  198. @property
  199. @lru_cache()
  200. def all_language_tokens(self) -> Tuple[int]:
  201. result = []
  202. for token, token_id in zip(
  203. self.tokenizer.additional_special_tokens,
  204. self.tokenizer.additional_special_tokens_ids,
  205. ):
  206. if token.strip("<|>") in LANGUAGES:
  207. result.append(token_id)
  208. return tuple(result)
  209. @property
  210. @lru_cache()
  211. def all_language_codes(self) -> Tuple[str]:
  212. return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
  213. @property
  214. @lru_cache()
  215. def sot_sequence_including_notimestamps(self) -> Tuple[int]:
  216. return tuple(list(self.sot_sequence) + [self.no_timestamps])
  217. @property
  218. @lru_cache()
  219. def non_speech_tokens(self) -> Tuple[int]:
  220. """
  221. Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
  222. annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
  223. - ♪♪♪
  224. - ( SPEAKING FOREIGN LANGUAGE )
  225. - [DAVID] Hey there,
  226. keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
  227. """
  228. symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
  229. symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
  230. # symbols that may be a single token or multiple tokens depending on the tokenizer.
  231. # In case they're multiple tokens, suppress the first token, which is safe because:
  232. # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
  233. # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
  234. miscellaneous = set("♩♪♫♬♭♮♯")
  235. assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
  236. # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
  237. result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
  238. for symbol in symbols + list(miscellaneous):
  239. for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
  240. if len(tokens) == 1 or symbol in miscellaneous:
  241. result.add(tokens[0])
  242. return tuple(sorted(result))
  243. def _get_single_token_id(self, text) -> int:
  244. tokens = self.tokenizer.encode(text)
  245. assert len(tokens) == 1, f"{text} is not encoded as a single token"
  246. return tokens[0]
  247. @lru_cache(maxsize=None)
  248. def build_tokenizer(name: str = "gpt2", resource_path: str = None):
  249. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  250. if resource_path is not None:
  251. path = os.path.join(resource_path, name)
  252. else:
  253. path = os.path.join(os.path.dirname(__file__), "assets", name)
  254. tokenizer = GPT2TokenizerFast.from_pretrained(path)
  255. specials = [
  256. "<|startoftranscript|>",
  257. *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
  258. "<|translate|>",
  259. "<|transcribe|>",
  260. "<|startoflm|>",
  261. "<|startofprev|>",
  262. "<|nospeech|>",
  263. "<|notimestamps|>",
  264. ]
  265. tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
  266. return tokenizer
  267. @lru_cache(maxsize=None)
  268. def get_tokenizer(
  269. multilingual: bool,
  270. *,
  271. task: Optional[str] = None, # Literal["transcribe", "translate", None]
  272. language: Optional[str] = None,
  273. ) -> Tokenizer:
  274. if language is not None:
  275. language = language.lower()
  276. if language not in LANGUAGES:
  277. if language in TO_LANGUAGE_CODE:
  278. language = TO_LANGUAGE_CODE[language]
  279. else:
  280. raise ValueError(f"Unsupported language: {language}")
  281. if multilingual:
  282. tokenizer_name = "multilingual"
  283. task = task or "transcribe"
  284. language = language or "en"
  285. else:
  286. tokenizer_name = "gpt2"
  287. task = None
  288. language = None
  289. tokenizer = build_tokenizer(name=tokenizer_name)
  290. all_special_ids: List[int] = tokenizer.all_special_ids
  291. sot: int = all_special_ids[1]
  292. translate: int = all_special_ids[-6]
  293. transcribe: int = all_special_ids[-5]
  294. langs = tuple(LANGUAGES.keys())
  295. sot_sequence = [sot]
  296. if language is not None:
  297. sot_sequence.append(sot + 1 + langs.index(language))
  298. if task is not None:
  299. sot_sequence.append(transcribe if task == "transcribe" else translate)
  300. return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))