# token_id_converter.py
  1. from pathlib import Path
  2. from typing import Dict
  3. from typing import Iterable
  4. from typing import List
  5. from typing import Union
  6. import numpy as np
  7. from typeguard import check_argument_types
  8. class TokenIDConverter:
  9. def __init__(
  10. self,
  11. token_list: Union[Path, str, Iterable[str]],
  12. unk_symbol: str = "<unk>",
  13. ):
  14. assert check_argument_types()
  15. if isinstance(token_list, (Path, str)):
  16. token_list = Path(token_list)
  17. self.token_list_repr = str(token_list)
  18. self.token_list: List[str] = []
  19. with token_list.open("r", encoding="utf-8") as f:
  20. for idx, line in enumerate(f):
  21. line = line.rstrip()
  22. self.token_list.append(line)
  23. else:
  24. self.token_list: List[str] = list(token_list)
  25. self.token_list_repr = ""
  26. for i, t in enumerate(self.token_list):
  27. if i == 3:
  28. break
  29. self.token_list_repr += f"{t}, "
  30. self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
  31. self.token2id: Dict[str, int] = {}
  32. for i, t in enumerate(self.token_list):
  33. if t in self.token2id:
  34. raise RuntimeError(f'Symbol "{t}" is duplicated')
  35. self.token2id[t] = i
  36. self.unk_symbol = unk_symbol
  37. if self.unk_symbol not in self.token2id:
  38. raise RuntimeError(
  39. f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
  40. )
  41. self.unk_id = self.token2id[self.unk_symbol]
  42. def get_num_vocabulary_size(self) -> int:
  43. return len(self.token_list)
  44. def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
  45. if isinstance(integers, np.ndarray) and integers.ndim != 1:
  46. raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
  47. return [self.token_list[i] for i in integers]
  48. def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
  49. return [self.token2id.get(i, self.unk_id) for i in tokens]