token_id_converter.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. from pathlib import Path
  2. from typing import Dict
  3. from typing import Iterable
  4. from typing import List
  5. from typing import Union
  6. import numpy as np
  7. class TokenIDConverter:
  8. def __init__(
  9. self,
  10. token_list: Union[Path, str, Iterable[str]],
  11. unk_symbol: str = "<unk>",
  12. ):
  13. if isinstance(token_list, (Path, str)):
  14. token_list = Path(token_list)
  15. self.token_list_repr = str(token_list)
  16. self.token_list: List[str] = []
  17. with token_list.open("r", encoding="utf-8") as f:
  18. for idx, line in enumerate(f):
  19. line = line.rstrip()
  20. self.token_list.append(line)
  21. else:
  22. self.token_list: List[str] = list(token_list)
  23. self.token_list_repr = ""
  24. for i, t in enumerate(self.token_list):
  25. if i == 3:
  26. break
  27. self.token_list_repr += f"{t}, "
  28. self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
  29. self.token2id: Dict[str, int] = {}
  30. for i, t in enumerate(self.token_list):
  31. if t in self.token2id:
  32. raise RuntimeError(f'Symbol "{t}" is duplicated')
  33. self.token2id[t] = i
  34. self.unk_symbol = unk_symbol
  35. if self.unk_symbol not in self.token2id:
  36. raise RuntimeError(
  37. f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
  38. )
  39. self.unk_id = self.token2id[self.unk_symbol]
  40. def get_num_vocabulary_size(self) -> int:
  41. return len(self.token_list)
  42. def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
  43. if isinstance(integers, np.ndarray) and integers.ndim != 1:
  44. raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
  45. return [self.token_list[i] for i in integers]
  46. def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
  47. return [self.token2id.get(i, self.unk_id) for i in tokens]