| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960 |
- from pathlib import Path
- from typing import Dict
- from typing import Iterable
- from typing import List
- from typing import Union
- import numpy as np
- from typeguard import check_argument_types
class TokenIDConverter:
    """Bidirectional mapping between token strings and integer IDs.

    The vocabulary comes either from a UTF-8 text file with one token per
    line or from an in-memory iterable of tokens; a token's ID is its
    position in that list. Tokens missing from the vocabulary are mapped to
    the ID of ``unk_symbol`` by :meth:`tokens2ids`.
    """

    def __init__(
        self,
        token_list: Union[Path, str, Iterable[str]],
        unk_symbol: str = "<unk>",
    ):
        """Build the token <-> ID tables.

        Args:
            token_list: Either a path (``Path`` or ``str``) to a UTF-8 text
                file containing one token per line, or an iterable of token
                strings. The position of each token is its ID.
            unk_symbol: Token used for out-of-vocabulary input. It must be
                present in ``token_list``.

        Raises:
            RuntimeError: If a token occurs more than once, or if
                ``unk_symbol`` is not part of the vocabulary.
        """
        # typeguard runtime check of the annotated argument types; being an
        # ``assert`` it is skipped entirely under ``python -O``.
        assert check_argument_types()

        if isinstance(token_list, (Path, str)):
            token_path = Path(token_list)
            tokens = self._read_token_file(token_path)
            token_list_repr = str(token_path)
        else:
            tokens = list(token_list)
            # No file name to show, so summarize the contents instead.
            token_list_repr = self._summarize(tokens)

        self.token_list: List[str] = tokens
        self.token_list_repr = token_list_repr
        self.token2id: Dict[str, int] = self._build_index(self.token_list)

        self.unk_symbol = unk_symbol
        if self.unk_symbol not in self.token2id:
            raise RuntimeError(
                f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
            )
        self.unk_id = self.token2id[self.unk_symbol]

    @staticmethod
    def _read_token_file(path: Path) -> List[str]:
        """Return the tokens in *path*, one per line, trailing whitespace removed."""
        # "utf-8-sig" drops a leading byte-order mark if one is present; with
        # plain "utf-8" the BOM would become part of the first token.
        with path.open("r", encoding="utf-8-sig") as f:
            return [line.rstrip() for line in f]

    @staticmethod
    def _summarize(tokens: List[str], n_head: int = 3) -> str:
        """Return a short description: the first *n_head* tokens and the size."""
        head = "".join(f"{t}, " for t in tokens[:n_head])
        return f"{head}... (NVocab={len(tokens)})"

    @staticmethod
    def _build_index(tokens: List[str]) -> Dict[str, int]:
        """Map each token to its position, raising on duplicate tokens."""
        token2id: Dict[str, int] = {}
        for i, t in enumerate(tokens):
            if t in token2id:
                raise RuntimeError(f'Symbol "{t}" is duplicated')
            token2id[t] = i
        return token2id

    def get_num_vocabulary_size(self) -> int:
        """Return the number of tokens in the vocabulary."""
        return len(self.token_list)

    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        """Convert a sequence of token IDs to their token strings.

        Args:
            integers: A 1-D ``np.ndarray`` or any iterable of integer IDs.

        Returns:
            The token string for each ID, in order.

        Raises:
            ValueError: If ``integers`` is an ndarray that is not 1-D.
            IndexError: If an ID is outside the vocabulary range.

        NOTE: this is plain list indexing, so negative IDs index from the end
        of the vocabulary instead of raising.
        """
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        """Convert token strings to IDs, mapping unknown tokens to ``unk_id``."""
        return [self.token2id.get(i, self.unk_id) for i in tokens]
|