# build_dataloader.py
import logging
from pathlib import Path
from typing import Iterable
from typing import List
from typing import Union

import sentencepiece as spm
from torch.utils.data import DataLoader
from typeguard import check_argument_types

from funasr.datasets.large_datasets.dataset import Dataset
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.text.abs_tokenizer import AbsTokenizer
  12. def read_symbol_table(symbol_table_file):
  13. if isinstance(symbol_table_file, str):
  14. symbol_table = {}
  15. with open(symbol_table_file, "r", encoding="utf8") as fin:
  16. for i, line in enumerate(fin):
  17. char = line.strip()
  18. symbol_table[char] = i
  19. else:
  20. assert isinstance(symbol_table_file, list)
  21. symbol_table = {}
  22. for i, char in enumerate(symbol_table_file):
  23. symbol_table[char] = i
  24. return symbol_table
  25. def load_seg_dict(seg_dict_file):
  26. seg_dict = {}
  27. assert isinstance(seg_dict_file, str)
  28. with open(seg_dict_file, "r", encoding="utf8") as f:
  29. lines = f.readlines()
  30. for line in lines:
  31. s = line.strip().split()
  32. key = s[0]
  33. value = s[1:]
  34. seg_dict[key] = " ".join(value)
  35. return seg_dict
  36. class SentencepiecesTokenizer(AbsTokenizer):
  37. def __init__(self, model: Union[Path, str]):
  38. assert check_argument_types()
  39. self.model = str(model)
  40. self.sp = None
  41. def __repr__(self):
  42. return f'{self.__class__.__name__}(model="{self.model}")'
  43. def _build_sentence_piece_processor(self):
  44. if self.sp is None:
  45. self.sp = spm.SentencePieceProcessor()
  46. self.sp.load(self.model)
  47. def text2tokens(self, line: str) -> List[str]:
  48. self._build_sentence_piece_processor()
  49. return self.sp.EncodeAsPieces(line)
  50. def tokens2text(self, tokens: Iterable[str]) -> str:
  51. self._build_sentence_piece_processor()
  52. return self.sp.DecodePieces(list(tokens))
  53. class LargeDataLoader(AbsIterFactory):
  54. def __init__(self, args, mode="train"):
  55. symbol_table = read_symbol_table(args.token_list) if args.token_list is not None else None
  56. seg_dict = load_seg_dict(args.seg_dict_file) if args.seg_dict_file is not None else None
  57. punc_dict = load_seg_dict(args.punc_dict_file) if args.punc_dict_file is not None else None
  58. bpe_tokenizer = load_seg_dict(args.bpemodel_file) if args.bpemodel_file is not None else None
  59. self.dataset_conf = args.dataset_conf
  60. self.frontend_conf = args.frontend_conf
  61. logging.info("dataloader config: {}".format(self.dataset_conf))
  62. batch_mode = self.dataset_conf.get("batch_mode", "padding")
  63. self.dataset = Dataset(args.data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
  64. self.dataset_conf, self.frontend_conf, speed_perturb=args.speed_perturb,
  65. mode=mode, batch_mode=batch_mode)
  66. def build_iter(self, epoch, shuffle=True):
  67. self.dataset.set_epoch(epoch)
  68. data_loader = DataLoader(self.dataset,
  69. batch_size=None,
  70. pin_memory=True,
  71. num_workers=self.dataset_conf.get("num_workers", 8))
  72. return data_loader