# build_dataloader.py
  1. import logging
  2. import yaml
  3. from torch.utils.data import DataLoader
  4. from funasr.datasets.large_datasets.dataset import Dataset
  5. from funasr.iterators.abs_iter_factory import AbsIterFactory
  6. def read_symbol_table(symbol_table_file):
  7. if isinstance(symbol_table_file, str):
  8. symbol_table = {}
  9. with open(symbol_table_file, "r", encoding="utf8") as fin:
  10. for i, line in enumerate(fin):
  11. char = line.strip()
  12. symbol_table[char] = i
  13. else:
  14. assert isinstance(symbol_table_file, list)
  15. symbol_table = {}
  16. for i, char in enumerate(symbol_table_file):
  17. symbol_table[char] = i
  18. return symbol_table
  19. def load_seg_dict(seg_dict_file):
  20. seg_dict = {}
  21. assert isinstance(seg_dict_file, str)
  22. with open(seg_dict_file, "r", encoding="utf8") as f:
  23. lines = f.readlines()
  24. for line in lines:
  25. s = line.strip().split()
  26. key = s[0]
  27. value = s[1:]
  28. seg_dict[key] = " ".join(value)
  29. return seg_dict
  30. class ArkDataLoader(AbsIterFactory):
  31. def __init__(self, data_list, dict_file, dataset_conf, seg_dict_file=None, mode="train"):
  32. symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
  33. if seg_dict_file is not None:
  34. seg_dict = load_seg_dict(seg_dict_file)
  35. else:
  36. seg_dict = None
  37. self.dataset_conf = dataset_conf
  38. logging.info("dataloader config: {}".format(self.dataset_conf))
  39. batch_mode = self.dataset_conf.get("batch_mode", "padding")
  40. self.dataset = Dataset(data_list, symbol_table, seg_dict,
  41. self.dataset_conf, mode=mode, batch_mode=batch_mode)
  42. def build_iter(self, epoch, shuffle=True):
  43. self.dataset.set_epoch(epoch)
  44. data_loader = DataLoader(self.dataset,
  45. batch_size=None,
  46. pin_memory=True,
  47. num_workers=self.dataset_conf.get("num_workers", 8))
  48. return data_loader