|
|
@@ -24,21 +24,11 @@ class TokenIDConverter():
|
|
|
):
|
|
|
check_argument_types()
|
|
|
|
|
|
- # self.token_list = self.load_token(token_path)
|
|
|
self.token_list = token_list
|
|
|
self.unk_symbol = token_list[-1]
|
|
|
+ self.token2id = {v: i for i, v in enumerate(self.token_list)}
|
|
|
+ self.unk_id = self.token2id[self.unk_symbol]
|
|
|
|
|
|
- # @staticmethod
|
|
|
- # def load_token(file_path: Union[Path, str]) -> List:
|
|
|
- # if not Path(file_path).exists():
|
|
|
- # raise TokenIDConverterError(f'The {file_path} does not exist.')
|
|
|
- #
|
|
|
- # with open(str(file_path), 'rb') as f:
|
|
|
- # token_list = pickle.load(f)
|
|
|
- #
|
|
|
- # if len(token_list) != len(set(token_list)):
|
|
|
- # raise TokenIDConverterError('The Token exists duplicated symbol.')
|
|
|
- # return token_list
|
|
|
|
|
|
def get_num_vocabulary_size(self) -> int:
|
|
|
return len(self.token_list)
|
|
|
@@ -51,13 +41,8 @@ class TokenIDConverter():
|
|
|
return [self.token_list[i] for i in integers]
|
|
|
|
|
|
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
|
|
|
- token2id = {v: i for i, v in enumerate(self.token_list)}
|
|
|
- if self.unk_symbol not in token2id:
|
|
|
- raise TokenIDConverterError(
|
|
|
- f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list"
|
|
|
- )
|
|
|
- unk_id = token2id[self.unk_symbol]
|
|
|
- return [token2id.get(i, unk_id) for i in tokens]
|
|
|
+
|
|
|
+ return [self.token2id.get(i, self.unk_id) for i in tokens]
|
|
|
|
|
|
|
|
|
class CharTokenizer():
|