iterable_dataset.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. """Iterable dataset module."""
  2. import copy
  3. from io import StringIO
  4. from pathlib import Path
  5. from typing import Callable
  6. from typing import Collection
  7. from typing import Dict
  8. from typing import Iterator
  9. from typing import Tuple
  10. from typing import Union
  11. import kaldiio
  12. import numpy as np
  13. import torch
  14. import torchaudio
  15. from torch.utils.data.dataset import IterableDataset
  16. from typeguard import check_argument_types
  17. import os.path
  18. from funasr.datasets.dataset import ESPnetDataset
  19. SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm']
  20. def load_kaldi(input):
  21. retval = kaldiio.load_mat(input)
  22. if isinstance(retval, tuple):
  23. assert len(retval) == 2, len(retval)
  24. if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
  25. # sound scp case
  26. rate, array = retval
  27. elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
  28. # Extended ark format case
  29. array, rate = retval
  30. else:
  31. raise RuntimeError(f"Unexpected type: {type(retval[0])}, {type(retval[1])}")
  32. # Multichannel wave fie
  33. # array: (NSample, Channel) or (Nsample)
  34. else:
  35. # Normal ark case
  36. assert isinstance(retval, np.ndarray), type(retval)
  37. array = retval
  38. return array
  39. def load_bytes(input):
  40. middle_data = np.frombuffer(input, dtype=np.int16)
  41. middle_data = np.asarray(middle_data)
  42. if middle_data.dtype.kind not in 'iu':
  43. raise TypeError("'middle_data' must be an array of integers")
  44. dtype = np.dtype('float32')
  45. if dtype.kind != 'f':
  46. raise TypeError("'dtype' must be a floating point type")
  47. i = np.iinfo(middle_data.dtype)
  48. abs_max = 2 ** (i.bits - 1)
  49. offset = i.min + abs_max
  50. array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
  51. return array
  52. def load_pcm(input):
  53. with open(input,"rb") as f:
  54. bytes = f.read()
  55. return load_bytes(bytes)
  56. DATA_TYPES = {
  57. "sound": lambda x: torchaudio.load(x)[0].numpy(),
  58. "pcm": load_pcm,
  59. "kaldi_ark": load_kaldi,
  60. "bytes": load_bytes,
  61. "waveform": lambda x: x,
  62. "npy": np.load,
  63. "text_int": lambda x: np.loadtxt(
  64. StringIO(x), ndmin=1, dtype=np.long, delimiter=" "
  65. ),
  66. "csv_int": lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=","),
  67. "text_float": lambda x: np.loadtxt(
  68. StringIO(x), ndmin=1, dtype=np.float32, delimiter=" "
  69. ),
  70. "csv_float": lambda x: np.loadtxt(
  71. StringIO(x), ndmin=1, dtype=np.float32, delimiter=","
  72. ),
  73. "text": lambda x: x,
  74. }
  75. class IterableESPnetDataset(IterableDataset):
  76. """Pytorch Dataset class for ESPNet.
  77. Examples:
  78. >>> dataset = IterableESPnetDataset([('wav.scp', 'input', 'sound'),
  79. ... ('token_int', 'output', 'text_int')],
  80. ... )
  81. >>> for uid, data in dataset:
  82. ... data
  83. {'input': per_utt_array, 'output': per_utt_array}
  84. """
  85. def __init__(
  86. self,
  87. path_name_type_list: Collection[Tuple[any, str, str]],
  88. preprocess: Callable[
  89. [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
  90. ] = None,
  91. float_dtype: str = "float32",
  92. fs: dict = None,
  93. mc: bool = False,
  94. int_dtype: str = "long",
  95. key_file: str = None,
  96. ):
  97. assert check_argument_types()
  98. if len(path_name_type_list) == 0:
  99. raise ValueError(
  100. '1 or more elements are required for "path_name_type_list"'
  101. )
  102. path_name_type_list = copy.deepcopy(path_name_type_list)
  103. self.preprocess = preprocess
  104. self.float_dtype = float_dtype
  105. self.int_dtype = int_dtype
  106. self.key_file = key_file
  107. self.fs = fs
  108. self.mc = mc
  109. self.debug_info = {}
  110. non_iterable_list = []
  111. self.path_name_type_list = []
  112. if not isinstance(path_name_type_list[0], Tuple):
  113. path = path_name_type_list[0]
  114. name = path_name_type_list[1]
  115. _type = path_name_type_list[2]
  116. self.debug_info[name] = path, _type
  117. if _type not in DATA_TYPES:
  118. non_iterable_list.append((path, name, _type))
  119. else:
  120. self.path_name_type_list.append((path, name, _type))
  121. else:
  122. for path, name, _type in path_name_type_list:
  123. self.debug_info[name] = path, _type
  124. if _type not in DATA_TYPES:
  125. non_iterable_list.append((path, name, _type))
  126. else:
  127. self.path_name_type_list.append((path, name, _type))
  128. if len(non_iterable_list) != 0:
  129. # Some types doesn't support iterable mode
  130. self.non_iterable_dataset = ESPnetDataset(
  131. path_name_type_list=non_iterable_list,
  132. preprocess=preprocess,
  133. float_dtype=float_dtype,
  134. int_dtype=int_dtype,
  135. )
  136. else:
  137. self.non_iterable_dataset = None
  138. self.apply_utt2category = False
  139. def has_name(self, name) -> bool:
  140. return name in self.debug_info
  141. def names(self) -> Tuple[str, ...]:
  142. return tuple(self.debug_info)
  143. def __repr__(self):
  144. _mes = self.__class__.__name__
  145. _mes += "("
  146. for name, (path, _type) in self.debug_info.items():
  147. _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
  148. _mes += f"\n preprocess: {self.preprocess})"
  149. return _mes
  150. def __iter__(self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
  151. count = 0
  152. if len(self.path_name_type_list) != 0 and (self.path_name_type_list[0][2] == "bytes" or self.path_name_type_list[0][2] == "waveform"):
  153. linenum = len(self.path_name_type_list)
  154. data = {}
  155. for i in range(linenum):
  156. value = self.path_name_type_list[i][0]
  157. uid = 'utt_id'
  158. name = self.path_name_type_list[i][1]
  159. _type = self.path_name_type_list[i][2]
  160. func = DATA_TYPES[_type]
  161. array = func(value)
  162. if self.fs is not None and (name == "speech" or name == "ref_speech"):
  163. audio_fs = self.fs["audio_fs"]
  164. model_fs = self.fs["model_fs"]
  165. if audio_fs is not None and model_fs is not None:
  166. array = torch.from_numpy(array)
  167. array = array.unsqueeze(0)
  168. array = torchaudio.transforms.Resample(orig_freq=audio_fs,
  169. new_freq=model_fs)(array)
  170. array = array.squeeze(0).numpy()
  171. data[name] = array
  172. if self.preprocess is not None:
  173. data = self.preprocess(uid, data)
  174. for name in data:
  175. count += 1
  176. value = data[name]
  177. if not isinstance(value, np.ndarray):
  178. raise RuntimeError(
  179. f'All values must be converted to np.ndarray object '
  180. f'by preprocessing, but "{name}" is still {type(value)}.')
  181. # Cast to desired type
  182. if value.dtype.kind == 'f':
  183. value = value.astype(self.float_dtype)
  184. elif value.dtype.kind == 'i':
  185. value = value.astype(self.int_dtype)
  186. else:
  187. raise NotImplementedError(
  188. f'Not supported dtype: {value.dtype}')
  189. data[name] = value
  190. yield uid, data
  191. elif len(self.path_name_type_list) != 0 and self.path_name_type_list[0][2] == "sound" and not self.path_name_type_list[0][0].lower().endswith(".scp"):
  192. linenum = len(self.path_name_type_list)
  193. data = {}
  194. for i in range(linenum):
  195. value = self.path_name_type_list[i][0]
  196. uid = os.path.basename(self.path_name_type_list[i][0]).split(".")[0]
  197. name = self.path_name_type_list[i][1]
  198. _type = self.path_name_type_list[i][2]
  199. if _type == "sound":
  200. audio_type = os.path.basename(value).split(".")[-1].lower()
  201. if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
  202. raise NotImplementedError(
  203. f'Not supported audio type: {audio_type}')
  204. if audio_type == "pcm":
  205. _type = "pcm"
  206. func = DATA_TYPES[_type]
  207. array = func(value)
  208. if self.fs is not None and (name == "speech" or name == "ref_speech"):
  209. audio_fs = self.fs["audio_fs"]
  210. model_fs = self.fs["model_fs"]
  211. if audio_fs is not None and model_fs is not None:
  212. array = torch.from_numpy(array)
  213. array = torchaudio.transforms.Resample(orig_freq=audio_fs,
  214. new_freq=model_fs)(array)
  215. array = array.numpy()
  216. if _type == "sound":
  217. if self.mc:
  218. data[name] = array.transpose((1, 0))
  219. else:
  220. data[name] = array[0]
  221. else:
  222. data[name] = array
  223. if self.preprocess is not None:
  224. data = self.preprocess(uid, data)
  225. for name in data:
  226. count += 1
  227. value = data[name]
  228. if not isinstance(value, np.ndarray):
  229. raise RuntimeError(
  230. f'All values must be converted to np.ndarray object '
  231. f'by preprocessing, but "{name}" is still {type(value)}.')
  232. # Cast to desired type
  233. if value.dtype.kind == 'f':
  234. value = value.astype(self.float_dtype)
  235. elif value.dtype.kind == 'i':
  236. value = value.astype(self.int_dtype)
  237. else:
  238. raise NotImplementedError(
  239. f'Not supported dtype: {value.dtype}')
  240. data[name] = value
  241. yield uid, data
  242. else:
  243. if self.key_file is not None:
  244. uid_iter = (
  245. line.rstrip().split(maxsplit=1)[0]
  246. for line in open(self.key_file, encoding="utf-8")
  247. )
  248. elif len(self.path_name_type_list) != 0:
  249. uid_iter = (
  250. line.rstrip().split(maxsplit=1)[0]
  251. for line in open(self.path_name_type_list[0][0], encoding="utf-8")
  252. )
  253. else:
  254. uid_iter = iter(self.non_iterable_dataset)
  255. files = [open(lis[0], encoding="utf-8") for lis in self.path_name_type_list]
  256. worker_info = torch.utils.data.get_worker_info()
  257. linenum = 0
  258. for count, uid in enumerate(uid_iter, 1):
  259. # If num_workers>=1, split keys
  260. if worker_info is not None:
  261. if (count - 1) % worker_info.num_workers != worker_info.id:
  262. continue
  263. # 1. Read a line from each file
  264. while True:
  265. keys = []
  266. values = []
  267. for f in files:
  268. linenum += 1
  269. try:
  270. line = next(f)
  271. except StopIteration:
  272. raise RuntimeError(f"{uid} is not found in the files")
  273. sps = line.rstrip().split(maxsplit=1)
  274. if len(sps) != 2:
  275. raise RuntimeError(
  276. f"This line doesn't include a space:"
  277. f" {f}:L{linenum}: {line})"
  278. )
  279. key, value = sps
  280. keys.append(key)
  281. values.append(value)
  282. for k_idx, k in enumerate(keys):
  283. if k != keys[0]:
  284. raise RuntimeError(
  285. f"Keys are mismatched. Text files (idx={k_idx}) is "
  286. f"not sorted or not having same keys at L{linenum}"
  287. )
  288. # If the key is matched, break the loop
  289. if len(keys) == 0 or keys[0] == uid:
  290. break
  291. # 2. Load the entry from each line and create a dict
  292. data = {}
  293. # 2.a. Load data streamingly
  294. for value, (path, name, _type) in zip(values, self.path_name_type_list):
  295. if _type == "sound":
  296. audio_type = os.path.basename(value).split(".")[-1].lower()
  297. if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
  298. raise NotImplementedError(
  299. f'Not supported audio type: {audio_type}')
  300. if audio_type == "pcm":
  301. _type = "pcm"
  302. func = DATA_TYPES[_type]
  303. # Load entry
  304. array = func(value)
  305. if self.fs is not None and name == "speech":
  306. audio_fs = self.fs["audio_fs"]
  307. model_fs = self.fs["model_fs"]
  308. if audio_fs is not None and model_fs is not None:
  309. array = torch.from_numpy(array)
  310. array = torchaudio.transforms.Resample(orig_freq=audio_fs,
  311. new_freq=model_fs)(array)
  312. array = array.numpy()
  313. if _type == "sound":
  314. if self.mc:
  315. data[name] = array.transpose((1, 0))
  316. else:
  317. data[name] = array[0]
  318. else:
  319. data[name] = array
  320. if self.non_iterable_dataset is not None:
  321. # 2.b. Load data from non-iterable dataset
  322. _, from_non_iterable = self.non_iterable_dataset[uid]
  323. data.update(from_non_iterable)
  324. # 3. [Option] Apply preprocessing
  325. # e.g. funasr.train.preprocessor:CommonPreprocessor
  326. if self.preprocess is not None:
  327. data = self.preprocess(uid, data)
  328. # 4. Force data-precision
  329. for name in data:
  330. value = data[name]
  331. if not isinstance(value, np.ndarray):
  332. raise RuntimeError(
  333. f"All values must be converted to np.ndarray object "
  334. f'by preprocessing, but "{name}" is still {type(value)}.'
  335. )
  336. # Cast to desired type
  337. if value.dtype.kind == "f":
  338. value = value.astype(self.float_dtype)
  339. elif value.dtype.kind == "i":
  340. value = value.astype(self.int_dtype)
  341. else:
  342. raise NotImplementedError(f"Not supported dtype: {value.dtype}")
  343. data[name] = value
  344. yield uid, data
  345. if count == 0:
  346. raise RuntimeError("No iteration")