sentencepiece_tokenizer.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. from pathlib import Path
  2. from typing import Iterable
  3. from typing import List
  4. from typing import Union
  5. import sentencepiece as spm
  6. from funasr.tokenizer.abs_tokenizer import BaseTokenizer
  7. from funasr.register import tables
  8. @tables.register("tokenizer_classes", "SentencepiecesTokenizer")
  9. class SentencepiecesTokenizer(BaseTokenizer):
  10. def __init__(self, bpemodel: Union[Path, str],
  11. **kwargs
  12. ):
  13. super().__init__(**kwargs)
  14. self.bpemodel = str(bpemodel)
  15. # NOTE(kamo):
  16. # Don't build SentencePieceProcessor in __init__()
  17. # because it's not picklable and it may cause following error,
  18. # "TypeError: can't pickle SwigPyObject objects",
  19. # when giving it as argument of "multiprocessing.Process()".
  20. self.sp = None
  21. def __repr__(self):
  22. return f'{self.__class__.__name__}(model="{self.bpemodel}")'
  23. def _build_sentence_piece_processor(self):
  24. # Build SentencePieceProcessor lazily.
  25. if self.sp is None:
  26. self.sp = spm.SentencePieceProcessor()
  27. self.sp.load(self.bpemodel)
  28. def text2tokens(self, line: str) -> List[str]:
  29. self._build_sentence_piece_processor()
  30. return self.sp.EncodeAsPieces(line)
  31. def tokens2text(self, tokens: Iterable[str]) -> str:
  32. self._build_sentence_piece_processor()
  33. return self.sp.DecodePieces(list(tokens))
  34. def encode(self, line: str) -> List[int]:
  35. self._build_sentence_piece_processor()
  36. return self.sp.EncodeAsIds(line)
  37. def decode(self, line: List[int]):
  38. self._build_sentence_piece_processor()
  39. return self.sp.DecodeIds(line)