# build_dataloader.py

import logging
from pathlib import Path
from typing import Iterable, List, Union

import sentencepiece as spm
from torch.utils.data import DataLoader

from funasr.datasets.large_datasets.dataset import Dataset
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.tokenizer.abs_tokenizer import AbsTokenizer


def read_symbol_table(symbol_table_file):
    """Build a token -> index map from a vocabulary file (one token per line)
    or from an in-memory list of tokens."""
    if isinstance(symbol_table_file, str):
        symbol_table = {}
        with open(symbol_table_file, "r", encoding="utf8") as fin:
            for i, line in enumerate(fin):
                char = line.strip()
                symbol_table[char] = i
    else:
        assert isinstance(symbol_table_file, list)
        symbol_table = {}
        for i, char in enumerate(symbol_table_file):
            symbol_table[char] = i
    return symbol_table


def load_seg_dict(seg_dict_file):
    """Load a segmentation dictionary: each line is `word piece1 piece2 ...`,
    mapping a word to its space-joined sub-word pieces."""
    seg_dict = {}
    assert isinstance(seg_dict_file, str)
    with open(seg_dict_file, "r", encoding="utf8") as f:
        for line in f:
            s = line.strip().split()
            key = s[0]
            value = s[1:]
            seg_dict[key] = " ".join(value)
    return seg_dict


class SentencepiecesTokenizer(AbsTokenizer):
    def __init__(self, model: Union[Path, str]):
        self.model = str(model)
        # The processor is built lazily (see _build_sentence_piece_processor)
        # so the tokenizer stays picklable, e.g. when it is handed to
        # DataLoader worker processes.
        self.sp = None

    def __repr__(self):
        return f'{self.__class__.__name__}(model="{self.model}")'

    def _build_sentence_piece_processor(self):
        # Load the SentencePiece model on first use only.
        if self.sp is None:
            self.sp = spm.SentencePieceProcessor()
            self.sp.load(self.model)

    def text2tokens(self, line: str) -> List[str]:
        self._build_sentence_piece_processor()
        return self.sp.EncodeAsPieces(line)

    def tokens2text(self, tokens: Iterable[str]) -> str:
        self._build_sentence_piece_processor()
        return self.sp.DecodePieces(list(tokens))
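

# A minimal usage sketch for the tokenizer above; the model path is
# hypothetical and the pieces shown are illustrative:
#
#   tokenizer = SentencepiecesTokenizer("exp/bpe/bpe.model")
#   pieces = tokenizer.text2tokens("hello world")  # e.g. ["▁hello", "▁world"]
#   text = tokenizer.tokens2text(pieces)           # -> "hello world"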


class LargeDataLoader(AbsIterFactory):
    def __init__(self, args, mode="train"):
        # Resolve the optional vocabularies/tokenizers declared on `args`.
        symbol_table, seg_dict, punc_dict, bpe_tokenizer = None, None, None, None
        if getattr(args, "token_list", None) is not None:
            symbol_table = read_symbol_table(args.token_list)
        if getattr(args, "seg_dict_file", None) is not None:
            seg_dict = load_seg_dict(args.seg_dict_file)
        if getattr(args, "punc_list", None) is not None:
            punc_dict = read_symbol_table(args.punc_list)
        if getattr(args, "bpemodel", None) is not None:
            bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel)

        self.dataset_conf = args.dataset_conf
        self.frontend_conf = getattr(args, "frontend_conf", None)
        self.speed_perturb = getattr(args, "speed_perturb", None)
        logging.info("dataloader config: {}".format(self.dataset_conf))

        batch_mode = self.dataset_conf.get("batch_mode", "padding")
        data_list = args.train_data_file if mode == "train" else args.valid_data_file
        self.dataset = Dataset(
            data_list,
            symbol_table,
            seg_dict,
            punc_dict,
            bpe_tokenizer,
            self.dataset_conf,
            self.frontend_conf,
            # Speed perturbation is a training-time augmentation only.
            speed_perturb=self.speed_perturb if mode == "train" else None,
            mode=mode,
            batch_mode=batch_mode,
        )

    def build_iter(self, epoch, shuffle=True):
        # `shuffle` is unused here; epoch-dependent behavior such as shuffling
        # is delegated to the Dataset via set_epoch.
        self.dataset.set_epoch(epoch)
        # batch_size=None disables the DataLoader's automatic batching; the
        # Dataset yields ready-made batches itself (see batch_mode above).
        data_loader = DataLoader(
            self.dataset,
            batch_size=None,
            pin_memory=True,
            num_workers=self.dataset_conf.get("num_workers", 8),
        )
        return data_loader
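

# A minimal usage sketch, assuming `args` is an argparse-style namespace
# carrying the attributes read above (token_list, dataset_conf,
# train_data_file, ...); the names and loop bounds are illustrative:
#
#   factory = LargeDataLoader(args, mode="train")
#   for epoch in range(1, max_epoch + 1):
#       for batch in factory.build_iter(epoch):
#           ...  # each item is a pre-collated batch (batch_size=None above)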