dataset.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. import os
  2. import random
  3. import soundfile
  4. from functools import partial
  5. import torch
  6. import torch.distributed as dist
  7. from kaldiio import ReadHelper
  8. from torch.utils.data import IterableDataset
  9. from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe
  10. from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
  11. from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
  12. from funasr.datasets.large_datasets.utils.filter import filter
  13. from funasr.datasets.large_datasets.utils.padding import padding
  14. from funasr.datasets.large_datasets.utils.clipping import clipping
  15. from funasr.datasets.large_datasets.utils.tokenize import tokenize
  def read_lists(list_file):
      """Read a data-list file into a list of entries, one per non-blank line.

      Each entry is the line with surrounding whitespace removed.
      ``AudioDataset.__iter__`` later splits an entry on whitespace into one
      file path per configured data name. Blank lines are skipped: an empty
      entry would yield zero file paths and fail that method's item-count
      assertion.

      Args:
          list_file: Path to a UTF-8 encoded text file.

      Returns:
          list[str]: The stripped, non-empty lines, in file order.
      """
      with open(list_file, 'r', encoding='utf8') as fin:
          stripped = (line.strip() for line in fin)
          return [entry for entry in stripped if entry]
  class AudioDataset(IterableDataset):
      """Iterable dataset that streams samples from groups of parallel data files.

      Each entry of ``scp_lists`` is a whitespace-separated string of file
      paths, one per name in the comma-separated ``data_names`` and
      ``data_types``. The files of one entry are read line by line in lockstep,
      and every group of lines becomes one sample dict.

      Supported data types:
          * ``kaldi_ark``: read with ``ReadHelper('ark:<path>')``. Each item is
            unpacked as ``(key, mat)``.
          * ``sound``: text lines ``<key> <audio path>``. The audio is loaded
            with ``soundfile.read``, and ``sampling_rate`` is stored alongside.
          * ``text``: text lines ``<key> <tok> <tok> ...``. The tokens after the
            key are kept as a list of strings.

      Entries are divided first across distributed ranks (in train mode only)
      and then across DataLoader workers.
      """

      def __init__(self, scp_lists, data_names, data_types, shuffle=True, mode="train"):
          self.scp_lists = scp_lists
          self.data_names = data_names  # comma-separated, e.g. "speech,text"
          self.data_types = data_types  # comma-separated, parallel to data_names
          self.shuffle = shuffle
          self.mode = mode
          self.epoch = -1
          # Refreshed on every iteration by get_rank_data_list / get_worker_data_list.
          self.rank = 0
          self.world_size = 1
          self.worker_id = 0
          self.num_workers = 1

      def set_epoch(self, epoch):
          """Set the epoch used to seed the train-mode shuffle of entries."""
          self.epoch = epoch

      def get_rank_data_list(self, data_index):
          """Return the entry indices assigned to this distributed rank.

          In train mode the indices are optionally shuffled with the epoch as
          seed, then strided by rank. Assuming every rank uses the same epoch,
          the per-rank slices are disjoint. In any other mode every rank
          receives all indices. ``data_index`` may be shuffled in place.
          """
          assert dist.is_available()
          if dist.is_initialized():
              self.rank = dist.get_rank()
              self.world_size = dist.get_world_size()
          else:
              self.rank = 0
              self.world_size = 1
          if self.mode == "train":
              if self.shuffle:
                  # NOTE: this seeds the *global* ``random`` module, so later
                  # ``random`` calls in this process also become deterministic
                  # per epoch.
                  random.seed(self.epoch)
                  random.shuffle(data_index)
              return data_index[self.rank::self.world_size]
          return data_index

      def get_worker_data_list(self, rank_data_index):
          """Return this DataLoader worker's strided share of ``rank_data_index``.

          Outside a worker process (``get_worker_info()`` is None), the whole
          list is returned.
          """
          worker_info = torch.utils.data.get_worker_info()
          if worker_info is None:
              self.worker_id = 0
              self.num_workers = 1
          else:
              self.worker_id = worker_info.id
              self.num_workers = worker_info.num_workers
          return rank_data_index[self.worker_id::self.num_workers]

      def close_reader(self, reader_list):
          """Close every reader (``ReadHelper`` or file object) in ``reader_list``."""
          for reader in reader_list:
              reader.close()

      @staticmethod
      def _make_sample(items, data_name_list, data_type_list):
          """Build one sample dict from a group of lockstep reader items."""
          sample_dict = {}
          for item, data_name, data_type in zip(items, data_name_list, data_type_list):
              if data_type == "kaldi_ark":
                  key, mat = item
                  sample_dict[data_name] = mat
                  if data_name == "speech":
                      sample_dict["key"] = key
              elif data_type == "sound":
                  # maxsplit=1 so audio paths containing spaces are accepted.
                  key, path = item.strip().split(maxsplit=1)
                  mat, sampling_rate = soundfile.read(path)
                  sample_dict[data_name] = mat
                  sample_dict["sampling_rate"] = sampling_rate
                  if data_name == "speech":
                      sample_dict["key"] = key
              else:
                  # "text": drop the leading key, keep the remaining tokens.
                  sample_dict[data_name] = item.strip().split()[1:]
          return sample_dict

      def __iter__(self):
          """Yield one sample dict per line group of this worker's entries.

          Raises:
              AssertionError: if a listed file does not exist, or an entry's
                  file count differs from the number of data names/types.
              TypeError: for an unsupported data type.
          """
          data_name_list = self.data_names.split(",")
          data_type_list = self.data_types.split(",")
          data_index = list(range(len(self.scp_lists)))
          rank_data_index = self.get_rank_data_list(data_index)
          worker_data_index = self.get_worker_data_list(rank_data_index)
          for index in worker_data_index:
              data_file_list = self.scp_lists[index].strip().split()
              for file in data_file_list:
                  assert os.path.exists(file), "{} not exists".format(file)
              assert len(data_file_list) == len(data_name_list) == len(data_type_list), \
                  "The item number of data, data_names, data_types must be the same "
              reader_list = []
              # try/finally ensures the readers are closed in three cases: the
              # consumer stops early (generator.close() raises GeneratorExit at
              # the yield), opening a later reader fails, or parsing raises.
              try:
                  for data_file, data_type in zip(data_file_list, data_type_list):
                      if data_type == "kaldi_ark":
                          reader_list.append(ReadHelper('ark:{}'.format(data_file)))
                      elif data_type == "text" or data_type == "sound":
                          reader_list.append(open(data_file, "r", encoding="utf8"))
                      else:
                          raise TypeError("Data type {} is not supported".format(data_type))
                  # zip stops at the shortest reader; the files are assumed to be
                  # line-aligned with one another.
                  for items in zip(*reader_list):
                      yield self._make_sample(items, data_name_list, data_type_list)
              finally:
                  self.close_reader(reader_list)
  def len_fn_example(data):
      """Length of a sample for example-count batching: simply ``len(data)``."""
      size = len(data)
      return size
  def len_fn_token(data):
      """Length of a sample for token-based batching.

      If the sample carries ``sampling_rate``, returns the duration of
      ``data["speech"]`` in milliseconds. Otherwise returns the size of its
      first axis (for example, the row count of a ``kaldi_ark`` matrix).
      """
      assert "speech" in data
      num_rows = data["speech"].shape[0]
      if "sampling_rate" not in data:
          return num_rows
      return num_rows / data["sampling_rate"] * 1000.
  def Dataset(data_list_file,
              dict,
              seg_dict,
              conf,
              mode="train",
              batch_mode="padding"):
      """Assemble the streaming data pipeline: read -> filter -> tokenize -> bucket -> collate.

      Args:
          data_list_file: Path to the list of entries, read with ``read_lists``.
          dict: Vocabulary forwarded to ``tokenize`` as ``vocab``. The name
              shadows the builtin but is kept for keyword compatibility.
          seg_dict: Forwarded to ``tokenize`` as ``seg_dict``.
          conf: Mapping with these keys:
              ``shuffle`` (default True);
              ``data_names`` (default "speech,text");
              ``data_types`` (default "kaldi_ark,text");
              ``filter_conf`` (kwargs for ``filter``);
              ``shuffle_conf``, which needs ``shuffle_size`` and ``sort_size``
              when shuffling;
              ``batch_conf``, which needs ``batch_size`` and ``batch_type``
              in {"example", "token"}.
          mode: Forwarded to ``AudioDataset``.
          batch_mode: Forwarded to the bucketizer. "padding" selects
              ``padding`` as the final map function; anything else selects
              ``clipping``.

      Returns:
          The final ``MapperIterDataPipe`` of the pipeline.
      """
      scp_lists = read_lists(data_list_file)
      shuffle = conf.get('shuffle', True)
      pipeline = AudioDataset(scp_lists,
                              conf.get("data_names", "speech,text"),
                              conf.get("data_types", "kaldi_ark,text"),
                              shuffle=shuffle,
                              mode=mode)
      pipeline = FilterIterDataPipe(pipeline,
                                    fn=partial(filter, **conf.get('filter_conf', {})))
      pipeline = MapperIterDataPipe(pipeline,
                                    fn=partial(tokenize, vocab=dict, seg_dict=seg_dict))

      # NOTE(review): buffer shuffling is enabled here regardless of ``mode``,
      # whereas AudioDataset only shuffles entries in train mode. Confirm this
      # is intended.
      if not shuffle:
          buffer_size, sort_size = 0, 1
      else:
          shuffle_conf = conf.get('shuffle_conf', {})
          buffer_size = shuffle_conf['shuffle_size']
          sort_size = shuffle_conf['sort_size']

      batch_conf = conf.get('batch_conf', {})
      batch_size = batch_conf['batch_size']
      batch_type = batch_conf['batch_type']
      assert batch_type in ["example", "token"]
      len_fn = len_fn_example if batch_type == 'example' else len_fn_token

      pipeline = MaxTokenBucketizerIterDataPipe(pipeline,
                                                batch_size=batch_size,
                                                len_fn=len_fn,
                                                buffer_size=buffer_size,
                                                sort_size=sort_size,
                                                batch_mode=batch_mode)
      collate_fn = padding if batch_mode == "padding" else clipping
      return MapperIterDataPipe(pipeline, fn=collate_fn)