# abs_tokenizer.py
  1. from abc import ABC
  2. from abc import abstractmethod
  3. from typing import Iterable
  4. from typing import List
  5. from pathlib import Path
  6. from typing import Dict
  7. from typing import Iterable
  8. from typing import List
  9. from typing import Union
  10. import json
  11. import numpy as np
  12. class AbsTokenizer(ABC):
  13. @abstractmethod
  14. def text2tokens(self, line: str) -> List[str]:
  15. raise NotImplementedError
  16. @abstractmethod
  17. def tokens2text(self, tokens: Iterable[str]) -> str:
  18. raise NotImplementedError
  19. class BaseTokenizer(ABC):
  20. def __init__(self, token_list: Union[Path, str, Iterable[str]]=None,
  21. unk_symbol: str = "<unk>",
  22. **kwargs,
  23. ):
  24. if token_list is not None:
  25. if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
  26. token_list = Path(token_list)
  27. self.token_list_repr = str(token_list)
  28. self.token_list: List[str] = []
  29. with token_list.open("r", encoding="utf-8") as f:
  30. for idx, line in enumerate(f):
  31. line = line.rstrip()
  32. self.token_list.append(line)
  33. elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
  34. token_list = Path(token_list)
  35. self.token_list_repr = str(token_list)
  36. self.token_list: List[str] = []
  37. with open(token_list, 'r', encoding='utf-8') as f:
  38. self.token_list = json.load(f)
  39. else:
  40. self.token_list: List[str] = list(token_list)
  41. self.token_list_repr = ""
  42. for i, t in enumerate(self.token_list):
  43. if i == 3:
  44. break
  45. self.token_list_repr += f"{t}, "
  46. self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
  47. self.token2id: Dict[str, int] = {}
  48. for i, t in enumerate(self.token_list):
  49. if t in self.token2id:
  50. raise RuntimeError(f'Symbol "{t}" is duplicated')
  51. self.token2id[t] = i
  52. self.unk_symbol = unk_symbol
  53. if self.unk_symbol not in self.token2id:
  54. raise RuntimeError(
  55. f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
  56. )
  57. self.unk_id = self.token2id[self.unk_symbol]
  58. def encode(self, text):
  59. tokens = self.text2tokens(text)
  60. text_ints = self.tokens2ids(tokens)
  61. return text_ints
  62. def decode(self, text_ints):
  63. token = self.ids2tokens(text_ints)
  64. text = self.tokens2text(token)
  65. return text
  66. def get_num_vocabulary_size(self) -> int:
  67. return len(self.token_list)
  68. def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
  69. if isinstance(integers, np.ndarray) and integers.ndim != 1:
  70. raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
  71. return [self.token_list[i] for i in integers]
  72. def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
  73. return [self.token2id.get(i, self.unk_id) for i in tokens]
  74. @abstractmethod
  75. def text2tokens(self, line: str) -> List[str]:
  76. raise NotImplementedError
  77. @abstractmethod
  78. def tokens2text(self, tokens: Iterable[str]) -> str:
  79. raise NotImplementedError