# dataset.py

import os
import random
from functools import partial

import numpy
import torch
import torchaudio
import torch.distributed as dist
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset

from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe
from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
from funasr.datasets.large_datasets.utils.filter import filter
from funasr.datasets.large_datasets.utils.padding import padding
from funasr.datasets.large_datasets.utils.clipping import clipping
from funasr.datasets.large_datasets.utils.tokenize import tokenize


def read_lists(list_file):
    """Read a data list file into a list of stripped lines, one entry per shard."""
    lists = []
    with open(list_file, 'r', encoding='utf8') as fin:
        for line in fin:
            parts = line.strip()
            lists.append(parts)
    return lists
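

# Sketch of the expected data-list layout (paths illustrative, not from this
# repo): each line names one shard's files, whitespace-separated and aligned
# one-to-one with data_names/data_types, e.g.
#   /data/shard01/feats.ark /data/shard01/text
# AudioDataset.__iter__ splits each line and opens one reader per file.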
class AudioDataset(IterableDataset):
    def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, mode="train"):
        self.scp_lists = scp_lists
        self.data_names = data_names
        self.data_types = data_types
        self.frontend_conf = frontend_conf
        self.shuffle = shuffle
        self.mode = mode
        self.epoch = -1
        self.rank = 0
        self.world_size = 1
        self.worker_id = 0
        self.num_workers = 1

    def set_epoch(self, epoch):
        self.epoch = epoch

    def get_rank_data_list(self, data_index):
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
        if self.mode == "train":
            # Seeding with the epoch keeps the shuffled order identical across
            # ranks, so the strided split below stays disjoint and exhaustive.
            if self.shuffle:
                random.seed(self.epoch)
                random.shuffle(data_index)
            return data_index[self.rank::self.world_size]
        return data_index

    def get_worker_data_list(self, rank_data_index):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers
        return rank_data_index[self.worker_id::self.num_workers]
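
    # Worked example of the two-level strided sharding above (not from the
    # source): with 8 shards [0..7], world_size=2 and num_workers=2, rank 0
    # first keeps [0, 2, 4, 6] via data_index[0::2], and its worker 0 then
    # keeps [0, 4] via rank_data_index[0::2], so every (rank, worker) pair
    # reads a disjoint quarter of the shard list.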

    def close_reader(self, reader_list):
        for reader in reader_list:
            reader.close()

    def __iter__(self):
        data_index = list(range(len(self.scp_lists)))
        rank_data_index = self.get_rank_data_list(data_index)
        worker_data_index = self.get_worker_data_list(rank_data_index)
        for index in worker_data_index:
            data = dict(scp=self.scp_lists[index])
            assert 'scp' in data
            scp = data['scp']
            data_file_list = scp.strip().split()
            data_name_list = self.data_names.split(",")
            data_type_list = self.data_types.split(",")
            for file in data_file_list:
                assert os.path.exists(file), "{} does not exist".format(file)
            assert len(data_file_list) == len(data_name_list) == len(data_type_list), \
                "The numbers of items in data, data_names and data_types must be the same"
            reader_list = []
            for data_file, data_type in zip(data_file_list, data_type_list):
                if data_type == "kaldi_ark":
                    ark_reader = ReadHelper('ark:{}'.format(data_file))
                    reader_list.append(ark_reader)
                elif data_type == "text" or data_type == "sound":
                    text_reader = open(data_file, "r")
                    reader_list.append(text_reader)
                elif data_type == "none":
                    continue
                else:
                    raise TypeError("Data type {} is not supported".format(data_type))
            # Advance all readers in lockstep: one item per reader per sample.
            for items in zip(*reader_list):
                sample_dict = {}
                for item, (data_name, data_type) in zip(items, zip(data_name_list, data_type_list)):
                    if data_type == "kaldi_ark":
                        key, mat = item
                        sample_dict[data_name] = mat
                        if data_name == "speech":
                            sample_dict["key"] = key
                    elif data_type == "sound":
                        key, path = item.strip().split()
                        waveform, sampling_rate = torchaudio.load(path)
                        if self.frontend_conf is not None:
                            # Resample on the fly if the file's rate differs
                            # from the frontend's expected rate.
                            if sampling_rate != self.frontend_conf["fs"]:
                                waveform = torchaudio.transforms.Resample(
                                    orig_freq=sampling_rate,
                                    new_freq=self.frontend_conf["fs"])(waveform)
                                sampling_rate = self.frontend_conf["fs"]
                        waveform = waveform.numpy()
                        mat = waveform[0]
                        sample_dict[data_name] = mat
                        sample_dict["sampling_rate"] = sampling_rate
                        if data_name == "speech":
                            sample_dict["key"] = key
                    else:
                        text = item
                        segs = text.strip().split()
                        sample_dict[data_name] = segs[1:]
                        if "key" not in sample_dict:
                            sample_dict["key"] = segs[0]
                yield sample_dict
            self.close_reader(reader_list)
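
    # With the default "speech,text" names and "kaldi_ark,text" types, each
    # yielded sample looks like (sketch, values illustrative):
    #   {"key": "utt1", "speech": <float feature matrix>, "text": ["w1", "w2"]}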


def len_fn_example(data):
    # Every sample counts as one unit: batches are capped by example count.
    return 1


def len_fn_token(data):
    assert "speech" in data
    if "sampling_rate" in data:
        # Raw waveforms are measured in milliseconds of audio.
        return (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
    else:
        # Precomputed features (e.g. kaldi_ark) are measured in frames.
        return data["speech"].shape[0]
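
# Worked example (illustrative): a 16000-sample waveform loaded at 16 kHz gives
# len_fn_token(...) == (16000 / 16000) * 1000. == 1000.0, i.e. one second of
# audio costs 1000 "tokens" toward batch_size in token mode.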


def Dataset(data_list_file,
            dict,
            seg_dict,
            punc_dict,
            conf,
            frontend_conf,
            mode="train",
            batch_mode="padding"):
    scp_lists = read_lists(data_list_file)
    shuffle = conf.get('shuffle', True)
    data_names = conf.get("data_names", "speech,text")
    data_types = conf.get("data_types", "kaldi_ark,text")
    dataset = AudioDataset(scp_lists, data_names, data_types,
                           frontend_conf=frontend_conf, shuffle=shuffle, mode=mode)

    # Filter out unwanted samples according to filter_conf.
    filter_conf = conf.get('filter_conf', {})
    filter_fn = partial(filter, **filter_conf)
    dataset = FilterIterDataPipe(dataset, fn=filter_fn)

    # Tokenize transcripts when a text stream is present.
    if "text" in data_names:
        vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict}
        tokenize_fn = partial(tokenize, **vocab)
        dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)

    if shuffle:
        buffer_conf = conf.get('shuffle_conf', {})
        buffer_size = buffer_conf['shuffle_size']
        sort_size = buffer_conf['sort_size']
    else:
        buffer_size = 0
        sort_size = 1

    batch_conf = conf.get('batch_conf', {})
    batch_size = batch_conf['batch_size']
    batch_type = batch_conf['batch_type']
    assert batch_type in ["example", "token"]
    if batch_type == 'example':
        len_fn = len_fn_example
    else:
        len_fn = len_fn_token
    dataset = MaxTokenBucketizerIterDataPipe(dataset,
                                             batch_size=batch_size,
                                             len_fn=len_fn,
                                             buffer_size=buffer_size,
                                             sort_size=sort_size,
                                             batch_mode=batch_mode)

    # Collate each batch: pad tensors by default, or clip them when
    # batch_mode is not "padding".
    int_pad_value = conf.get("int_pad_value", -1)
    float_pad_value = conf.get("float_pad_value", 0.0)
    padding_conf = {"int_pad_value": int_pad_value, "float_pad_value": float_pad_value}
    padding_fn = partial(padding, **padding_conf)
    dataset = MapperIterDataPipe(dataset, fn=padding_fn if batch_mode == "padding" else clipping)

    return dataset
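

# Hedged usage sketch (illustrative values only; the vocabulary objects and
# the data list path below are placeholders, not defined in this file). The
# conf keys mirror what Dataset() reads above:
#
#     conf = {
#         "shuffle": True,
#         "data_names": "speech,text",
#         "data_types": "kaldi_ark,text",
#         "filter_conf": {},
#         "shuffle_conf": {"shuffle_size": 10000, "sort_size": 500},
#         "batch_conf": {"batch_type": "token", "batch_size": 6000},
#         "int_pad_value": -1,
#         "float_pad_value": 0.0,
#     }
#     dataset = Dataset("train_data.list", vocab_dict, seg_dict, punc_dict,
#                       conf, frontend_conf={"fs": 16000}, mode="train",
#                       batch_mode="padding")
#     for batch in dataset:
#         ...  # each batch is a padded dict keyed by data_names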