# abs_tokenizer.py
  1. from abc import ABC
  2. from abc import abstractmethod
  3. from typing import Iterable
  4. from typing import List
  5. from pathlib import Path
  6. from typing import Dict
  7. from typing import Iterable
  8. from typing import List
  9. from typing import Union
  10. import numpy as np
class AbsTokenizer(ABC):
    """Minimal tokenizer interface: convert between text and token strings."""

    @abstractmethod
    def text2tokens(self, line: str) -> List[str]:
        """Split *line* into a list of token strings."""
        raise NotImplementedError

    @abstractmethod
    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join *tokens* back into a single text string."""
        raise NotImplementedError
  18. class BaseTokenizer(ABC):
  19. def __init__(self, token_list: Union[Path, str, Iterable[str]]=None,
  20. unk_symbol: str = "<unk>",
  21. **kwargs,
  22. ):
  23. if token_list is not None:
  24. if isinstance(token_list, (Path, str)):
  25. token_list = Path(token_list)
  26. self.token_list_repr = str(token_list)
  27. self.token_list: List[str] = []
  28. with token_list.open("r", encoding="utf-8") as f:
  29. for idx, line in enumerate(f):
  30. line = line.rstrip()
  31. self.token_list.append(line)
  32. else:
  33. self.token_list: List[str] = list(token_list)
  34. self.token_list_repr = ""
  35. for i, t in enumerate(self.token_list):
  36. if i == 3:
  37. break
  38. self.token_list_repr += f"{t}, "
  39. self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
  40. self.token2id: Dict[str, int] = {}
  41. for i, t in enumerate(self.token_list):
  42. if t in self.token2id:
  43. raise RuntimeError(f'Symbol "{t}" is duplicated')
  44. self.token2id[t] = i
  45. self.unk_symbol = unk_symbol
  46. if self.unk_symbol not in self.token2id:
  47. raise RuntimeError(
  48. f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
  49. )
  50. self.unk_id = self.token2id[self.unk_symbol]
  51. def encode(self, text):
  52. tokens = self.text2tokens(text)
  53. text_ints = self.tokens2ids(tokens)
  54. return text_ints
  55. def decode(self, text_ints):
  56. return self.ids2tokens(text_ints)
  57. def get_num_vocabulary_size(self) -> int:
  58. return len(self.token_list)
  59. def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
  60. if isinstance(integers, np.ndarray) and integers.ndim != 1:
  61. raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
  62. return [self.token_list[i] for i in integers]
  63. def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
  64. return [self.token2id.get(i, self.unk_id) for i in tokens]
  65. @abstractmethod
  66. def text2tokens(self, line: str) -> List[str]:
  67. raise NotImplementedError
  68. @abstractmethod
  69. def tokens2text(self, tokens: Iterable[str]) -> str:
  70. raise NotImplementedError