|
|
@@ -21,42 +21,43 @@ class AbsTokenizer(ABC):
|
|
|
|
|
|
|
|
|
class BaseTokenizer(ABC):
|
|
|
- def __init__(self, token_list: Union[Path, str, Iterable[str]],
|
|
|
+ def __init__(self, token_list: Union[Path, str, Iterable[str]]=None,
|
|
|
unk_symbol: str = "<unk>",
|
|
|
**kwargs,
|
|
|
):
|
|
|
|
|
|
- if isinstance(token_list, (Path, str)):
|
|
|
- token_list = Path(token_list)
|
|
|
- self.token_list_repr = str(token_list)
|
|
|
- self.token_list: List[str] = []
|
|
|
+ if token_list is not None:
|
|
|
+ if isinstance(token_list, (Path, str)):
|
|
|
+ token_list = Path(token_list)
|
|
|
+ self.token_list_repr = str(token_list)
|
|
|
+ self.token_list: List[str] = []
|
|
|
+
|
|
|
+ with token_list.open("r", encoding="utf-8") as f:
|
|
|
+ for idx, line in enumerate(f):
|
|
|
+ line = line.rstrip()
|
|
|
+ self.token_list.append(line)
|
|
|
|
|
|
- with token_list.open("r", encoding="utf-8") as f:
|
|
|
- for idx, line in enumerate(f):
|
|
|
- line = line.rstrip()
|
|
|
- self.token_list.append(line)
|
|
|
-
|
|
|
- else:
|
|
|
- self.token_list: List[str] = list(token_list)
|
|
|
- self.token_list_repr = ""
|
|
|
+ else:
|
|
|
+ self.token_list: List[str] = list(token_list)
|
|
|
+ self.token_list_repr = ""
|
|
|
+ for i, t in enumerate(self.token_list):
|
|
|
+ if i == 3:
|
|
|
+ break
|
|
|
+ self.token_list_repr += f"{t}, "
|
|
|
+ self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
|
|
|
+
|
|
|
+ self.token2id: Dict[str, int] = {}
|
|
|
for i, t in enumerate(self.token_list):
|
|
|
- if i == 3:
|
|
|
- break
|
|
|
- self.token_list_repr += f"{t}, "
|
|
|
- self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
|
|
|
-
|
|
|
- self.token2id: Dict[str, int] = {}
|
|
|
- for i, t in enumerate(self.token_list):
|
|
|
- if t in self.token2id:
|
|
|
- raise RuntimeError(f'Symbol "{t}" is duplicated')
|
|
|
- self.token2id[t] = i
|
|
|
-
|
|
|
- self.unk_symbol = unk_symbol
|
|
|
- if self.unk_symbol not in self.token2id:
|
|
|
- raise RuntimeError(
|
|
|
- f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
|
|
|
- )
|
|
|
- self.unk_id = self.token2id[self.unk_symbol]
|
|
|
+ if t in self.token2id:
|
|
|
+ raise RuntimeError(f'Symbol "{t}" is duplicated')
|
|
|
+ self.token2id[t] = i
|
|
|
+
|
|
|
+ self.unk_symbol = unk_symbol
|
|
|
+ if self.unk_symbol not in self.token2id:
|
|
|
+ raise RuntimeError(
|
|
|
+ f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
|
|
|
+ )
|
|
|
+ self.unk_id = self.token2id[self.unk_symbol]
|
|
|
|
|
|
def encode(self, text):
|
|
|
tokens = self.text2tokens(text)
|