# build_dataloader.py
import logging

import yaml
from torch.utils.data import DataLoader

from funasr.datasets.large_datasets.dataset import Dataset
from funasr.iterators.abs_iter_factory import AbsIterFactory
  6. def read_symbol_table(symbol_table_file):
  7. if isinstance(symbol_table_file, str):
  8. symbol_table = {}
  9. with open(symbol_table_file, "r", encoding="utf8") as fin:
  10. for i, line in enumerate(fin):
  11. char = line.strip()
  12. symbol_table[char] = i
  13. else:
  14. assert isinstance(symbol_table_file, list)
  15. symbol_table = {}
  16. for i, char in enumerate(symbol_table_file):
  17. symbol_table[char] = i
  18. return symbol_table
  19. def load_seg_dict(seg_dict_file):
  20. seg_dict = {}
  21. assert isinstance(seg_dict_file, str)
  22. with open(seg_dict_file, "r", encoding="utf8") as f:
  23. lines = f.readlines()
  24. for line in lines:
  25. s = line.strip().split()
  26. key = s[0]
  27. value = s[1:]
  28. seg_dict[key] = " ".join(value)
  29. return seg_dict
  30. class ArkDataLoader(AbsIterFactory):
  31. def __init__(self, data_list, dict_file, dataset_conf, frontend_conf=None, seg_dict_file=None, punc_dict_file=None, mode="train"):
  32. symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
  33. if seg_dict_file is not None:
  34. seg_dict = load_seg_dict(seg_dict_file)
  35. else:
  36. seg_dict = None
  37. if punc_dict_file is not None:
  38. punc_dict = read_symbol_table(punc_dict_file)
  39. else:
  40. punc_dict = None
  41. self.dataset_conf = dataset_conf
  42. self.frontend_conf = frontend_conf
  43. logging.info("dataloader config: {}".format(self.dataset_conf))
  44. batch_mode = self.dataset_conf.get("batch_mode", "padding")
  45. self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict,
  46. self.dataset_conf, self.frontend_conf, mode=mode, batch_mode=batch_mode)
  47. def build_iter(self, epoch, shuffle=True):
  48. self.dataset.set_epoch(epoch)
  49. data_loader = DataLoader(self.dataset,
  50. batch_size=None,
  51. pin_memory=True,
  52. num_workers=self.dataset_conf.get("num_workers", 8))
  53. return data_loader