# abs_tokenizer.py
  1. import json
  2. import numpy as np
  3. from abc import ABC
  4. from pathlib import Path
  5. from abc import abstractmethod
  6. from typing import Union, Iterable, List, Dict
  7. class AbsTokenizer(ABC):
  8. @abstractmethod
  9. def text2tokens(self, line: str) -> List[str]:
  10. raise NotImplementedError
  11. @abstractmethod
  12. def tokens2text(self, tokens: Iterable[str]) -> str:
  13. raise NotImplementedError
  14. class BaseTokenizer(ABC):
  15. def __init__(self, token_list: Union[Path, str, Iterable[str]] = None,
  16. unk_symbol: str = "<unk>",
  17. **kwargs,
  18. ):
  19. if token_list is not None:
  20. if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
  21. token_list = Path(token_list)
  22. self.token_list_repr = str(token_list)
  23. self.token_list: List[str] = []
  24. with token_list.open("r", encoding="utf-8") as f:
  25. for idx, line in enumerate(f):
  26. line = line.rstrip()
  27. self.token_list.append(line)
  28. elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
  29. token_list = Path(token_list)
  30. self.token_list_repr = str(token_list)
  31. self.token_list: List[str] = []
  32. with open(token_list, 'r', encoding='utf-8') as f:
  33. self.token_list = json.load(f)
  34. else:
  35. self.token_list: List[str] = list(token_list)
  36. self.token_list_repr = ""
  37. for i, t in enumerate(self.token_list):
  38. if i == 3:
  39. break
  40. self.token_list_repr += f"{t}, "
  41. self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
  42. self.token2id: Dict[str, int] = {}
  43. for i, t in enumerate(self.token_list):
  44. if t in self.token2id:
  45. raise RuntimeError(f'Symbol "{t}" is duplicated')
  46. self.token2id[t] = i
  47. self.unk_symbol = unk_symbol
  48. if self.unk_symbol not in self.token2id:
  49. raise RuntimeError(
  50. f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
  51. )
  52. self.unk_id = self.token2id[self.unk_symbol]
  53. def encode(self, text):
  54. tokens = self.text2tokens(text)
  55. text_ints = self.tokens2ids(tokens)
  56. return text_ints
  57. def decode(self, text_ints):
  58. token = self.ids2tokens(text_ints)
  59. text = self.tokens2text(token)
  60. return text
  61. def get_num_vocabulary_size(self) -> int:
  62. return len(self.token_list)
  63. def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
  64. if isinstance(integers, np.ndarray) and integers.ndim != 1:
  65. raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
  66. return [self.token_list[i] for i in integers]
  67. def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
  68. return [self.token2id.get(i, self.unk_id) for i in tokens]
  69. @abstractmethod
  70. def text2tokens(self, line: str) -> List[str]:
  71. raise NotImplementedError
  72. @abstractmethod
  73. def tokens2text(self, tokens: Iterable[str]) -> str:
  74. raise NotImplementedError