# iterable_dataset.py
  1. """Iterable dataset module."""
import copy
import os.path
from io import StringIO
from pathlib import Path
from typing import Any
from typing import Callable
from typing import Collection
from typing import Dict
from typing import Iterator
from typing import List
from typing import Tuple
from typing import Union

import kaldiio
import numpy as np
import soundfile
import torch
import torchaudio
from torch.utils.data.dataset import IterableDataset

from funasr.datasets.dataset import ESPnetDataset
  20. SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm']
  21. def load_kaldi(input):
  22. retval = kaldiio.load_mat(input)
  23. if isinstance(retval, tuple):
  24. assert len(retval) == 2, len(retval)
  25. if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
  26. # sound scp case
  27. rate, array = retval
  28. elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
  29. # Extended ark format case
  30. array, rate = retval
  31. else:
  32. raise RuntimeError(f"Unexpected type: {type(retval[0])}, {type(retval[1])}")
  33. # Multichannel wave fie
  34. # array: (NSample, Channel) or (Nsample)
  35. else:
  36. # Normal ark case
  37. assert isinstance(retval, np.ndarray), type(retval)
  38. array = retval
  39. return array
  40. def load_bytes(input):
  41. middle_data = np.frombuffer(input, dtype=np.int16)
  42. middle_data = np.asarray(middle_data)
  43. if middle_data.dtype.kind not in 'iu':
  44. raise TypeError("'middle_data' must be an array of integers")
  45. dtype = np.dtype('float32')
  46. if dtype.kind != 'f':
  47. raise TypeError("'dtype' must be a floating point type")
  48. i = np.iinfo(middle_data.dtype)
  49. abs_max = 2 ** (i.bits - 1)
  50. offset = i.min + abs_max
  51. array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
  52. return array
  53. def load_pcm(input):
  54. with open(input,"rb") as f:
  55. bytes = f.read()
  56. return load_bytes(bytes)
  57. def load_wav(input):
  58. try:
  59. return torchaudio.load(input)[0].numpy()
  60. except:
  61. waveform, _ = soundfile.read(input, dtype='float32')
  62. if waveform.ndim == 2:
  63. waveform = waveform[:, 0]
  64. return np.expand_dims(waveform, axis=0)
  65. DATA_TYPES = {
  66. "sound": load_wav,
  67. "pcm": load_pcm,
  68. "kaldi_ark": load_kaldi,
  69. "bytes": load_bytes,
  70. "waveform": lambda x: x,
  71. "npy": np.load,
  72. "text_int": lambda x: np.loadtxt(
  73. StringIO(x), ndmin=1, dtype=np.long, delimiter=" "
  74. ),
  75. "csv_int": lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=","),
  76. "text_float": lambda x: np.loadtxt(
  77. StringIO(x), ndmin=1, dtype=np.float32, delimiter=" "
  78. ),
  79. "csv_float": lambda x: np.loadtxt(
  80. StringIO(x), ndmin=1, dtype=np.float32, delimiter=","
  81. ),
  82. "text": lambda x: x,
  83. }
  84. class IterableESPnetDataset(IterableDataset):
  85. """Pytorch Dataset class for ESPNet.
  86. Examples:
  87. >>> dataset = IterableESPnetDataset([('wav.scp', 'input', 'sound'),
  88. ... ('token_int', 'output', 'text_int')],
  89. ... )
  90. >>> for uid, data in dataset:
  91. ... data
  92. {'input': per_utt_array, 'output': per_utt_array}
  93. """
  94. def __init__(
  95. self,
  96. path_name_type_list: Collection[Tuple[any, str, str]],
  97. preprocess: Callable[
  98. [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
  99. ] = None,
  100. float_dtype: str = "float32",
  101. fs: dict = None,
  102. mc: bool = False,
  103. int_dtype: str = "long",
  104. key_file: str = None,
  105. ):
  106. if len(path_name_type_list) == 0:
  107. raise ValueError(
  108. '1 or more elements are required for "path_name_type_list"'
  109. )
  110. path_name_type_list = copy.deepcopy(path_name_type_list)
  111. self.preprocess = preprocess
  112. self.float_dtype = float_dtype
  113. self.int_dtype = int_dtype
  114. self.key_file = key_file
  115. self.fs = fs
  116. self.mc = mc
  117. self.debug_info = {}
  118. non_iterable_list = []
  119. self.path_name_type_list = []
  120. if not isinstance(path_name_type_list[0], (Tuple, List)):
  121. path = path_name_type_list[0]
  122. name = path_name_type_list[1]
  123. _type = path_name_type_list[2]
  124. self.debug_info[name] = path, _type
  125. if _type not in DATA_TYPES:
  126. non_iterable_list.append((path, name, _type))
  127. else:
  128. self.path_name_type_list.append((path, name, _type))
  129. else:
  130. for path, name, _type in path_name_type_list:
  131. self.debug_info[name] = path, _type
  132. if _type not in DATA_TYPES:
  133. non_iterable_list.append((path, name, _type))
  134. else:
  135. self.path_name_type_list.append((path, name, _type))
  136. if len(non_iterable_list) != 0:
  137. # Some types doesn't support iterable mode
  138. self.non_iterable_dataset = ESPnetDataset(
  139. path_name_type_list=non_iterable_list,
  140. preprocess=preprocess,
  141. float_dtype=float_dtype,
  142. int_dtype=int_dtype,
  143. )
  144. else:
  145. self.non_iterable_dataset = None
  146. self.apply_utt2category = False
  147. def has_name(self, name) -> bool:
  148. return name in self.debug_info
  149. def names(self) -> Tuple[str, ...]:
  150. return tuple(self.debug_info)
  151. def __repr__(self):
  152. _mes = self.__class__.__name__
  153. _mes += "("
  154. for name, (path, _type) in self.debug_info.items():
  155. _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
  156. _mes += f"\n preprocess: {self.preprocess})"
  157. return _mes
  158. def __iter__(self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
  159. count = 0
  160. if len(self.path_name_type_list) != 0 and (self.path_name_type_list[0][2] == "bytes" or self.path_name_type_list[0][2] == "waveform"):
  161. linenum = len(self.path_name_type_list)
  162. data = {}
  163. for i in range(linenum):
  164. value = self.path_name_type_list[i][0]
  165. uid = 'utt_id'
  166. name = self.path_name_type_list[i][1]
  167. _type = self.path_name_type_list[i][2]
  168. func = DATA_TYPES[_type]
  169. array = func(value)
  170. if self.fs is not None and (name == "speech" or name == "ref_speech"):
  171. audio_fs = self.fs["audio_fs"]
  172. model_fs = self.fs["model_fs"]
  173. if audio_fs is not None and model_fs is not None:
  174. array = torch.from_numpy(array)
  175. array = array.unsqueeze(0)
  176. array = torchaudio.transforms.Resample(orig_freq=audio_fs,
  177. new_freq=model_fs)(array)
  178. array = array.squeeze(0).numpy()
  179. data[name] = array
  180. if self.preprocess is not None:
  181. data = self.preprocess(uid, data)
  182. for name in data:
  183. count += 1
  184. value = data[name]
  185. if not isinstance(value, np.ndarray):
  186. raise RuntimeError(
  187. f'All values must be converted to np.ndarray object '
  188. f'by preprocessing, but "{name}" is still {type(value)}.')
  189. # Cast to desired type
  190. if value.dtype.kind == 'f':
  191. value = value.astype(self.float_dtype)
  192. elif value.dtype.kind == 'i':
  193. value = value.astype(self.int_dtype)
  194. else:
  195. raise NotImplementedError(
  196. f'Not supported dtype: {value.dtype}')
  197. data[name] = value
  198. yield uid, data
  199. elif len(self.path_name_type_list) != 0 and self.path_name_type_list[0][2] == "sound" and not self.path_name_type_list[0][0].lower().endswith(".scp"):
  200. linenum = len(self.path_name_type_list)
  201. data = {}
  202. for i in range(linenum):
  203. value = self.path_name_type_list[i][0]
  204. uid = os.path.basename(self.path_name_type_list[i][0]).split(".")[0]
  205. name = self.path_name_type_list[i][1]
  206. _type = self.path_name_type_list[i][2]
  207. if _type == "sound":
  208. audio_type = os.path.basename(value).lower()
  209. if audio_type.rfind(".pcm") >= 0:
  210. _type = "pcm"
  211. func = DATA_TYPES[_type]
  212. array = func(value)
  213. if self.fs is not None and (name == "speech" or name == "ref_speech"):
  214. audio_fs = self.fs["audio_fs"]
  215. model_fs = self.fs["model_fs"]
  216. if audio_fs is not None and model_fs is not None:
  217. array = torch.from_numpy(array)
  218. array = torchaudio.transforms.Resample(orig_freq=audio_fs,
  219. new_freq=model_fs)(array)
  220. array = array.numpy()
  221. if _type == "sound":
  222. if self.mc:
  223. data[name] = array.transpose((1, 0))
  224. else:
  225. data[name] = array[0]
  226. else:
  227. data[name] = array
  228. if self.preprocess is not None:
  229. data = self.preprocess(uid, data)
  230. for name in data:
  231. count += 1
  232. value = data[name]
  233. if not isinstance(value, np.ndarray):
  234. raise RuntimeError(
  235. f'All values must be converted to np.ndarray object '
  236. f'by preprocessing, but "{name}" is still {type(value)}.')
  237. # Cast to desired type
  238. if value.dtype.kind == 'f':
  239. value = value.astype(self.float_dtype)
  240. elif value.dtype.kind == 'i':
  241. value = value.astype(self.int_dtype)
  242. else:
  243. raise NotImplementedError(
  244. f'Not supported dtype: {value.dtype}')
  245. data[name] = value
  246. yield uid, data
  247. else:
  248. if self.key_file is not None:
  249. uid_iter = (
  250. line.rstrip().split(maxsplit=1)[0]
  251. for line in open(self.key_file, encoding="utf-8")
  252. )
  253. elif len(self.path_name_type_list) != 0:
  254. uid_iter = (
  255. line.rstrip().split(maxsplit=1)[0]
  256. for line in open(self.path_name_type_list[0][0], encoding="utf-8")
  257. )
  258. else:
  259. uid_iter = iter(self.non_iterable_dataset)
  260. files = [open(lis[0], encoding="utf-8") for lis in self.path_name_type_list]
  261. worker_info = torch.utils.data.get_worker_info()
  262. linenum = 0
  263. for count, uid in enumerate(uid_iter, 1):
  264. # If num_workers>=1, split keys
  265. if worker_info is not None:
  266. if (count - 1) % worker_info.num_workers != worker_info.id:
  267. continue
  268. # 1. Read a line from each file
  269. while True:
  270. keys = []
  271. values = []
  272. for f in files:
  273. linenum += 1
  274. try:
  275. line = next(f)
  276. except StopIteration:
  277. raise RuntimeError(f"{uid} is not found in the files")
  278. sps = line.rstrip().split(maxsplit=1)
  279. if len(sps) != 2:
  280. raise RuntimeError(
  281. f"This line doesn't include a space:"
  282. f" {f}:L{linenum}: {line})"
  283. )
  284. key, value = sps
  285. keys.append(key)
  286. values.append(value)
  287. for k_idx, k in enumerate(keys):
  288. if k != keys[0]:
  289. raise RuntimeError(
  290. f"Keys are mismatched. Text files (idx={k_idx}) is "
  291. f"not sorted or not having same keys at L{linenum}"
  292. )
  293. # If the key is matched, break the loop
  294. if len(keys) == 0 or keys[0] == uid:
  295. break
  296. # 2. Load the entry from each line and create a dict
  297. data = {}
  298. # 2.a. Load data streamingly
  299. for value, (path, name, _type) in zip(values, self.path_name_type_list):
  300. if _type == "sound":
  301. audio_type = os.path.basename(value).lower()
  302. if audio_type.rfind(".pcm") >= 0:
  303. _type = "pcm"
  304. func = DATA_TYPES[_type]
  305. # Load entry
  306. array = func(value)
  307. if self.fs is not None and name == "speech":
  308. audio_fs = self.fs["audio_fs"]
  309. model_fs = self.fs["model_fs"]
  310. if audio_fs is not None and model_fs is not None:
  311. array = torch.from_numpy(array)
  312. array = torchaudio.transforms.Resample(orig_freq=audio_fs,
  313. new_freq=model_fs)(array)
  314. array = array.numpy()
  315. if _type == "sound":
  316. if self.mc:
  317. data[name] = array.transpose((1, 0))
  318. else:
  319. data[name] = array[0]
  320. else:
  321. data[name] = array
  322. if self.non_iterable_dataset is not None:
  323. # 2.b. Load data from non-iterable dataset
  324. _, from_non_iterable = self.non_iterable_dataset[uid]
  325. data.update(from_non_iterable)
  326. # 3. [Option] Apply preprocessing
  327. # e.g. funasr.train.preprocessor:CommonPreprocessor
  328. if self.preprocess is not None:
  329. data = self.preprocess(uid, data)
  330. # 4. Force data-precision
  331. for name in data:
  332. value = data[name]
  333. if not isinstance(value, np.ndarray):
  334. raise RuntimeError(
  335. f"All values must be converted to np.ndarray object "
  336. f'by preprocessing, but "{name}" is still {type(value)}.'
  337. )
  338. # Cast to desired type
  339. if value.dtype.kind == "f":
  340. value = value.astype(self.float_dtype)
  341. elif value.dtype.kind == "i":
  342. value = value.astype(self.int_dtype)
  343. else:
  344. raise NotImplementedError(f"Not supported dtype: {value.dtype}")
  345. data[name] = value
  346. yield uid, data
  347. if count == 0:
  348. raise RuntimeError("No iteration")