dataset.py

import logging
import os
import random
from functools import partial

import torch
import torch.distributed as dist
import torchaudio
import numpy as np
import soundfile
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset

from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe
from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
from funasr.datasets.large_datasets.utils.clipping import clipping
from funasr.datasets.large_datasets.utils.filter import filter
from funasr.datasets.large_datasets.utils.padding import padding
from funasr.datasets.large_datasets.utils.tokenize import tokenize

def read_lists(list_file):
    lists = []
    with open(list_file, 'r', encoding='utf8') as fin:
        for line in fin:
            lists.append(line.strip())
    return lists
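
# Each line of the data list file describes one shard as a whitespace-separated
# group of data files, aligned one-to-one with data_names/data_types. A
# hypothetical line for the default "speech,text" / "kaldi_ark,text" layout:
#
#     /data/train/shard_000/feats.ark /data/train/shard_000/text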

class AudioDataset(IterableDataset):
    def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, speed_perturb=None,
                 mode="train"):
        self.scp_lists = scp_lists
        self.data_names = data_names
        self.data_types = data_types
        self.frontend_conf = frontend_conf
        self.shuffle = shuffle
        self.mode = mode
        self.epoch = -1
        self.rank = 0
        self.world_size = 1
        self.worker_id = 0
        self.num_workers = 1
        self.speed_perturb = speed_perturb
        if self.speed_perturb is not None:
            logging.info("Using speed_perturb: {}".format(speed_perturb))

    def set_epoch(self, epoch):
        self.epoch = epoch

    def get_rank_data_list(self, data_index):
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
        if self.mode == "train":
            if self.shuffle:
                # Seed with the epoch so all ranks shuffle identically and the
                # strided slices below stay disjoint.
                random.seed(self.epoch)
                random.shuffle(data_index)
            return data_index[self.rank::self.world_size]
        return data_index

    def get_worker_data_list(self, rank_data_index):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers
        return rank_data_index[self.worker_id::self.num_workers]
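
    # Illustration of the two-level strided sharding above (hypothetical
    # indices): with data_index = [3, 0, 2, 1] after the epoch-seeded shuffle
    # and world_size = 2, rank 0 gets [3, 2] and rank 1 gets [0, 1]; each
    # rank's list is then sliced again per DataLoader worker, so with
    # num_workers = 2, rank 0 / worker 0 sees [3] and rank 0 / worker 1 sees [2].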

    def close_reader(self, reader_list):
        for reader in reader_list:
            reader.close()

    def __iter__(self):
        data_index = list(range(len(self.scp_lists)))
        rank_data_index = self.get_rank_data_list(data_index)
        worker_data_index = self.get_worker_data_list(rank_data_index)
        for index in worker_data_index:
            scp = self.scp_lists[index]
            data_file_list = scp.strip().split()
            data_name_list = self.data_names.split(",")
            data_type_list = self.data_types.split(",")
            for file in data_file_list:
                assert os.path.exists(file), "{} does not exist".format(file)
            assert len(data_file_list) == len(data_name_list) == len(data_type_list), \
                "data, data_names and data_types must have the same number of items"
            reader_list = []
            for data_file, data_type in zip(data_file_list, data_type_list):
                if data_type == "kaldi_ark":
                    ark_reader = ReadHelper('ark:{}'.format(data_file))
                    reader_list.append(ark_reader)
                elif data_type in ("text", "sound", "text_hotword", "text_nospace"):
                    # "text_nospace" transcripts are read as plain text lines as well.
                    text_reader = open(data_file, "r")
                    reader_list.append(text_reader)
                elif data_type == "none":
                    # A "none" column opens no reader, so it must be the last
                    # column for the zip over readers below to stay aligned.
                    continue
                else:
                    raise TypeError("Data type {} is not supported".format(data_type))
            for items in zip(*reader_list):
                sample_dict = {}
                for item, (data_name, data_type) in zip(items, zip(data_name_list, data_type_list)):
                    if data_type == "kaldi_ark":
                        key, mat = item
                        sample_dict[data_name] = mat
                        if data_name == "speech":
                            sample_dict["key"] = key
                    elif data_type == "sound":
                        key, path = item.strip().split()
                        try:
                            waveform, sampling_rate = torchaudio.load(path)
                        except Exception:
                            # Fall back to soundfile for formats torchaudio cannot
                            # decode; keep the first channel only and restore the
                            # (channel, time) layout that torchaudio.load returns.
                            waveform, sampling_rate = soundfile.read(path, dtype='float32')
                            if waveform.ndim == 2:
                                waveform = waveform[:, 0]
                            waveform = np.expand_dims(waveform, axis=0)
                            waveform = torch.tensor(waveform)
                        if self.frontend_conf is not None:
                            if sampling_rate != self.frontend_conf["fs"]:
                                waveform = torchaudio.transforms.Resample(
                                    orig_freq=sampling_rate,
                                    new_freq=self.frontend_conf["fs"])(waveform)
                                sampling_rate = self.frontend_conf["fs"]
                        waveform = waveform.numpy()
                        mat = waveform[0]
                        if self.speed_perturb is not None:
                            speed = random.choice(self.speed_perturb)
                            if speed != 1.0:
                                # sox "speed" changes the rate header, so an explicit
                                # "rate" effect resamples back to the original rate.
                                mat, _ = torchaudio.sox_effects.apply_effects_tensor(
                                    torch.tensor(mat).view(1, -1), sampling_rate,
                                    [['speed', str(speed)], ['rate', str(sampling_rate)]])
                                mat = mat.view(-1).numpy()
                        sample_dict[data_name] = mat
                        sample_dict["sampling_rate"] = sampling_rate
                        if data_name == "speech":
                            sample_dict["key"] = key
                    elif data_type == "text_hotword":
                        segs = item.strip().split()
                        sample_dict[data_name] = segs[1:]
                        if "key" not in sample_dict:
                            sample_dict["key"] = segs[0]
                        sample_dict['hw_tag'] = 1
                    elif data_type == "text_nospace":
                        segs = item.strip().split(maxsplit=1)
                        # Split the transcript into single characters.
                        sample_dict[data_name] = [x for x in segs[1]]
                        if "key" not in sample_dict:
                            sample_dict["key"] = segs[0]
                    else:
                        segs = item.strip().split()
                        sample_dict[data_name] = segs[1:]
                        if "key" not in sample_dict:
                            sample_dict["key"] = segs[0]
                yield sample_dict
            self.close_reader(reader_list)

def len_fn_example(data):
    # "example" batching: every sample contributes a length of 1.
    return 1


def len_fn_token(data):
    assert "speech" in data
    if "sampling_rate" in data:
        # Raw waveform: return the duration in milliseconds.
        return (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
    else:
        # Pre-extracted features: return the number of frames.
        return data["speech"].shape[0]

def Dataset(data_list_file,
            dict,  # vocabulary mapping; the name shadows the builtin but is kept for API compatibility
            seg_dict,
            punc_dict,
            bpe_tokenizer,
            conf,
            frontend_conf,
            speed_perturb=None,
            mode="train",
            batch_mode="padding"):
    scp_lists = read_lists(data_list_file)
    shuffle = conf.get('shuffle', True)

    data_names = conf.get("data_names", "speech,text")
    data_types = conf.get("data_types", "kaldi_ark,text")

    pre_hwfile = conf.get("pre_hwlist", None)
    pre_prob = conf.get("pre_prob", 0)  # not used yet
    hw_config = {"sample_rate": conf.get("sample_rate", 0.6),
                 "double_rate": conf.get("double_rate", 0.1),
                 "hotword_min_length": conf.get("hotword_min_length", 2),
                 "hotword_max_length": conf.get("hotword_max_length", 8),
                 "pre_prob": conf.get("pre_prob", 0.0)}
    # pre_hwlist is loaded here but not consumed in this function.
    if pre_hwfile is not None:
        pre_hwlist = []
        with open(pre_hwfile, 'r') as fin:
            for line in fin:
                pre_hwlist.append(line.strip())
    else:
        pre_hwlist = None

    dataset = AudioDataset(scp_lists,
                           data_names,
                           data_types,
                           frontend_conf=frontend_conf,
                           shuffle=shuffle,
                           speed_perturb=speed_perturb,
                           mode=mode)

    # Pipeline: filter -> tokenize (if text) -> bucketize into batches -> pad/clip.
    filter_conf = conf.get('filter_conf', {})
    filter_fn = partial(filter, **filter_conf)
    dataset = FilterIterDataPipe(dataset, fn=filter_fn)

    if "text" in data_names:
        vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict,
                 'bpe_tokenizer': bpe_tokenizer, 'hw_config': hw_config}
        tokenize_fn = partial(tokenize, **vocab)
        dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)

    if shuffle:
        buffer_conf = conf.get('shuffle_conf', {})
        buffer_size = buffer_conf['shuffle_size']
        sort_size = buffer_conf['sort_size']
    else:
        buffer_size = 0
        sort_size = 1

    batch_conf = conf.get('batch_conf', {})
    batch_size = batch_conf['batch_size']
    batch_type = batch_conf['batch_type']
    assert batch_type in ["example", "token"]
    len_fn = len_fn_example if batch_type == 'example' else len_fn_token
    dataset = MaxTokenBucketizerIterDataPipe(dataset,
                                             batch_size=batch_size,
                                             len_fn=len_fn,
                                             buffer_size=buffer_size,
                                             sort_size=sort_size,
                                             batch_mode=batch_mode)

    int_pad_value = conf.get("int_pad_value", -1)
    float_pad_value = conf.get("float_pad_value", 0.0)
    padding_conf = {"int_pad_value": int_pad_value, "float_pad_value": float_pad_value}
    padding_fn = partial(padding, **padding_conf)
    dataset = MapperIterDataPipe(dataset, fn=padding_fn if batch_mode == "padding" else clipping)

    return dataset
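
# Illustrative usage (a minimal sketch, not part of the training recipe): the
# data list path and every config value below are hypothetical, and the
# dictionaries are toy stand-ins for the vocabularies a real recipe would pass.
if __name__ == "__main__":
    example_conf = {
        "shuffle": True,
        "shuffle_conf": {"shuffle_size": 2048, "sort_size": 512},
        "batch_conf": {"batch_type": "token", "batch_size": 6000},
        "data_names": "speech,text",
        "data_types": "kaldi_ark,text",
    }
    pipe = Dataset("data/train/data.list",  # hypothetical shard list
                   dict={"<blank>": 0, "<s>": 1, "</s>": 2},  # toy vocabulary
                   seg_dict=None,
                   punc_dict=None,
                   bpe_tokenizer=None,
                   conf=example_conf,
                   frontend_conf={"fs": 16000},
                   mode="train")
    for batch in pipe:  # yields padded mini-batches
        print(batch)
        break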