# dataset.py
  1. # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. from abc import ABC
  4. from abc import abstractmethod
  5. import collections
  6. import copy
  7. import functools
  8. import logging
  9. import numbers
  10. import re
  11. from typing import Any
  12. from typing import Callable
  13. from typing import Collection
  14. from typing import Dict
  15. from typing import Mapping
  16. from typing import Tuple
  17. from typing import Union
  18. import h5py
  19. import humanfriendly
  20. import kaldiio
  21. import numpy as np
  22. import torch
  23. from torch.utils.data.dataset import Dataset
  24. from typeguard import check_argument_types
  25. from typeguard import check_return_type
  26. from funasr.fileio.npy_scp import NpyScpReader
  27. from funasr.fileio.rand_gen_dataset import FloatRandomGenerateDataset
  28. from funasr.fileio.rand_gen_dataset import IntRandomGenerateDataset
  29. from funasr.fileio.read_text import load_num_sequence_text
  30. from funasr.fileio.read_text import read_2column_text
  31. from funasr.fileio.sound_scp import SoundScpReader
  32. from funasr.utils.sized_dict import SizedDict
  33. class AdapterForSoundScpReader(collections.abc.Mapping):
  34. def __init__(self, loader, dtype=None):
  35. assert check_argument_types()
  36. self.loader = loader
  37. self.dtype = dtype
  38. self.rate = None
  39. def keys(self):
  40. return self.loader.keys()
  41. def __len__(self):
  42. return len(self.loader)
  43. def __iter__(self):
  44. return iter(self.loader)
  45. def __getitem__(self, key: str) -> np.ndarray:
  46. retval = self.loader[key]
  47. if isinstance(retval, tuple):
  48. assert len(retval) == 2, len(retval)
  49. if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
  50. # sound scp case
  51. rate, array = retval
  52. elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
  53. # Extended ark format case
  54. array, rate = retval
  55. else:
  56. raise RuntimeError(
  57. f"Unexpected type: {type(retval[0])}, {type(retval[1])}"
  58. )
  59. if self.rate is not None and self.rate != rate:
  60. raise RuntimeError(
  61. f"Sampling rates are mismatched: {self.rate} != {rate}"
  62. )
  63. self.rate = rate
  64. # Multichannel wave fie
  65. # array: (NSample, Channel) or (Nsample)
  66. if self.dtype is not None:
  67. array = array.astype(self.dtype)
  68. else:
  69. # Normal ark case
  70. assert isinstance(retval, np.ndarray), type(retval)
  71. array = retval
  72. if self.dtype is not None:
  73. array = array.astype(self.dtype)
  74. assert isinstance(array, np.ndarray), type(array)
  75. return array
  76. class H5FileWrapper:
  77. def __init__(self, path: str):
  78. self.path = path
  79. self.h5_file = h5py.File(path, "r")
  80. def __repr__(self) -> str:
  81. return str(self.h5_file)
  82. def __len__(self) -> int:
  83. return len(self.h5_file)
  84. def __iter__(self):
  85. return iter(self.h5_file)
  86. def __getitem__(self, key) -> np.ndarray:
  87. value = self.h5_file[key]
  88. return value[()]
  89. def sound_loader(path, float_dtype=None):
  90. # The file is as follows:
  91. # utterance_id_A /some/where/a.wav
  92. # utterance_id_B /some/where/a.flac
  93. # NOTE(kamo): SoundScpReader doesn't support pipe-fashion
  94. # like Kaldi e.g. "cat a.wav |".
  95. # NOTE(kamo): The audio signal is normalized to [-1,1] range.
  96. loader = SoundScpReader(path, normalize=True, always_2d=False)
  97. # SoundScpReader.__getitem__() returns Tuple[int, ndarray],
  98. # but ndarray is desired, so Adapter class is inserted here
  99. return AdapterForSoundScpReader(loader, float_dtype)
  100. def kaldi_loader(path, float_dtype=None, max_cache_fd: int = 0):
  101. loader = kaldiio.load_scp(path, max_cache_fd=max_cache_fd)
  102. return AdapterForSoundScpReader(loader, float_dtype)
  103. def rand_int_loader(filepath, loader_type):
  104. # e.g. rand_int_3_10
  105. try:
  106. low, high = map(int, loader_type[len("rand_int_") :].split("_"))
  107. except ValueError:
  108. raise RuntimeError(f"e.g rand_int_3_10: but got {loader_type}")
  109. return IntRandomGenerateDataset(filepath, low, high)
  110. DATA_TYPES = {
  111. "sound": dict(
  112. func=sound_loader,
  113. kwargs=["float_dtype"],
  114. help="Audio format types which supported by sndfile wav, flac, etc."
  115. "\n\n"
  116. " utterance_id_a a.wav\n"
  117. " utterance_id_b b.wav\n"
  118. " ...",
  119. ),
  120. "kaldi_ark": dict(
  121. func=kaldi_loader,
  122. kwargs=["max_cache_fd"],
  123. help="Kaldi-ark file type."
  124. "\n\n"
  125. " utterance_id_A /some/where/a.ark:123\n"
  126. " utterance_id_B /some/where/a.ark:456\n"
  127. " ...",
  128. ),
  129. "npy": dict(
  130. func=NpyScpReader,
  131. kwargs=[],
  132. help="Npy file format."
  133. "\n\n"
  134. " utterance_id_A /some/where/a.npy\n"
  135. " utterance_id_B /some/where/b.npy\n"
  136. " ...",
  137. ),
  138. "text_int": dict(
  139. func=functools.partial(load_num_sequence_text, loader_type="text_int"),
  140. kwargs=[],
  141. help="A text file in which is written a sequence of interger numbers "
  142. "separated by space."
  143. "\n\n"
  144. " utterance_id_A 12 0 1 3\n"
  145. " utterance_id_B 3 3 1\n"
  146. " ...",
  147. ),
  148. "csv_int": dict(
  149. func=functools.partial(load_num_sequence_text, loader_type="csv_int"),
  150. kwargs=[],
  151. help="A text file in which is written a sequence of interger numbers "
  152. "separated by comma."
  153. "\n\n"
  154. " utterance_id_A 100,80\n"
  155. " utterance_id_B 143,80\n"
  156. " ...",
  157. ),
  158. "text_float": dict(
  159. func=functools.partial(load_num_sequence_text, loader_type="text_float"),
  160. kwargs=[],
  161. help="A text file in which is written a sequence of float numbers "
  162. "separated by space."
  163. "\n\n"
  164. " utterance_id_A 12. 3.1 3.4 4.4\n"
  165. " utterance_id_B 3. 3.12 1.1\n"
  166. " ...",
  167. ),
  168. "csv_float": dict(
  169. func=functools.partial(load_num_sequence_text, loader_type="csv_float"),
  170. kwargs=[],
  171. help="A text file in which is written a sequence of float numbers "
  172. "separated by comma."
  173. "\n\n"
  174. " utterance_id_A 12.,3.1,3.4,4.4\n"
  175. " utterance_id_B 3.,3.12,1.1\n"
  176. " ...",
  177. ),
  178. "text": dict(
  179. func=read_2column_text,
  180. kwargs=[],
  181. help="Return text as is. The text must be converted to ndarray "
  182. "by 'preprocess'."
  183. "\n\n"
  184. " utterance_id_A hello world\n"
  185. " utterance_id_B foo bar\n"
  186. " ...",
  187. ),
  188. "hdf5": dict(
  189. func=H5FileWrapper,
  190. kwargs=[],
  191. help="A HDF5 file which contains arrays at the first level or the second level."
  192. " >>> f = h5py.File('file.h5')\n"
  193. " >>> array1 = f['utterance_id_A']\n"
  194. " >>> array2 = f['utterance_id_B']\n",
  195. ),
  196. "rand_float": dict(
  197. func=FloatRandomGenerateDataset,
  198. kwargs=[],
  199. help="Generate random float-ndarray which has the given shapes "
  200. "in the file."
  201. "\n\n"
  202. " utterance_id_A 3,4\n"
  203. " utterance_id_B 10,4\n"
  204. " ...",
  205. ),
  206. "rand_int_\\d+_\\d+": dict(
  207. func=rand_int_loader,
  208. kwargs=["loader_type"],
  209. help="e.g. 'rand_int_0_10'. Generate random int-ndarray which has the given "
  210. "shapes in the path. "
  211. "Give the lower and upper value by the file type. e.g. "
  212. "rand_int_0_10 -> Generate integers from 0 to 10."
  213. "\n\n"
  214. " utterance_id_A 3,4\n"
  215. " utterance_id_B 10,4\n"
  216. " ...",
  217. ),
  218. }
  219. class AbsDataset(Dataset, ABC):
  220. @abstractmethod
  221. def has_name(self, name) -> bool:
  222. raise NotImplementedError
  223. @abstractmethod
  224. def names(self) -> Tuple[str, ...]:
  225. raise NotImplementedError
  226. @abstractmethod
  227. def __getitem__(self, uid) -> Tuple[Any, Dict[str, np.ndarray]]:
  228. raise NotImplementedError
class ESPnetDataset(AbsDataset):
    """Pytorch Dataset class for ESPNet.

    Each entry of ``path_name_type_list`` binds a data file (``path``) to a
    batch key (``name``) through a loader selected by ``_type`` (a key of
    ``DATA_TYPES``).

    Examples:
        >>> dataset = ESPnetDataset([('wav.scp', 'input', 'sound'),
        ...                          ('token_int', 'output', 'text_int')],
        ...                         )
        ... uttid, data = dataset['uttid']
        {'input': per_utt_array, 'output': per_utt_array}
    """

    def __init__(
        self,
        path_name_type_list: Collection[Tuple[str, str, str]],
        preprocess: Callable[
            [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
        ] = None,
        float_dtype: str = "float32",
        int_dtype: str = "long",
        max_cache_size: Union[float, int, str] = 0.0,
        max_cache_fd: int = 0,
    ):
        """Build one loader per (path, name, type) triple and set up caching.

        Args:
            path_name_type_list: Triples of (file path, data-key name, loader type).
            preprocess: Optional callable mapping (uid, data-dict) -> data-dict.
            float_dtype: Numpy dtype every float array is cast to in __getitem__.
            int_dtype: Numpy dtype every int array is cast to in __getitem__.
            max_cache_size: Cache byte budget; a string like "1GB" is parsed
                by humanfriendly. 0 disables caching.
            max_cache_fd: Passed through to the kaldi loader.

        Raises:
            ValueError: if path_name_type_list is empty.
            RuntimeError: on a duplicated data-key name or an empty data file.
        """
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"'
            )

        # Deep-copy so later mutation by the caller cannot affect this dataset.
        path_name_type_list = copy.deepcopy(path_name_type_list)
        self.preprocess = preprocess

        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.max_cache_fd = max_cache_fd

        self.loader_dict = {}
        self.debug_info = {}
        for path, name, _type in path_name_type_list:
            if name in self.loader_dict:
                raise RuntimeError(f'"{name}" is duplicated for data-key')

            loader = self._build_loader(path, _type)
            self.loader_dict[name] = loader
            self.debug_info[name] = path, _type
            if len(self.loader_dict[name]) == 0:
                raise RuntimeError(f"{path} has no samples")

            # TODO(kamo): Should check consistency of each utt-keys?

        # e.g. max_cache_size="1GB" -> number of bytes as int
        if isinstance(max_cache_size, str):
            max_cache_size = humanfriendly.parse_size(max_cache_size)
        self.max_cache_size = max_cache_size
        if max_cache_size > 0:
            # shared=True: presumably a cross-process cache for DataLoader
            # workers — confirm against SizedDict's implementation.
            self.cache = SizedDict(shared=True)
        else:
            self.cache = None

    def _build_loader(
        self, path: str, loader_type: str
    ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
        """Helper function to instantiate Loader.

        Args:
            path: The file path
            loader_type: loader_type. sound, npy, text_int, text_float, etc

        Raises:
            RuntimeError: if loader_type matches no DATA_TYPES key, or a
                registered kwargs name is unknown.
        """
        # DATA_TYPES keys are treated as regexes (re.match), so e.g.
        # "rand_int_3_10" matches the "rand_int_\d+_\d+" entry.
        for key, dic in DATA_TYPES.items():
            # e.g. loader_type="sound"
            # -> return DATA_TYPES["sound"]["func"](path)
            if re.match(key, loader_type):
                # Assemble only the keyword arguments this loader declares.
                kwargs = {}
                for key2 in dic["kwargs"]:
                    if key2 == "loader_type":
                        kwargs["loader_type"] = loader_type
                    elif key2 == "float_dtype":
                        kwargs["float_dtype"] = self.float_dtype
                    elif key2 == "int_dtype":
                        kwargs["int_dtype"] = self.int_dtype
                    elif key2 == "max_cache_fd":
                        kwargs["max_cache_fd"] = self.max_cache_fd
                    else:
                        raise RuntimeError(f"Not implemented keyword argument: {key2}")

                func = dic["func"]
                try:
                    return func(path, **kwargs)
                except Exception:
                    # Log which factory failed before re-raising the original error.
                    if hasattr(func, "__name__"):
                        name = func.__name__
                    else:
                        name = str(func)
                    logging.error(f"An error happened with {name}({path})")
                    raise
        else:
            # for/else: reached only when no DATA_TYPES key matched.
            raise RuntimeError(f"Not supported: loader_type={loader_type}")

    def has_name(self, name) -> bool:
        """Return True if ``name`` is one of this dataset's data-keys."""
        return name in self.loader_dict

    def names(self) -> Tuple[str, ...]:
        """Return all data-key names, in registration order."""
        return tuple(self.loader_dict)

    def __iter__(self):
        # Iterate the utterance-ids of the first registered loader.
        return iter(next(iter(self.loader_dict.values())))

    def __repr__(self):
        _mes = self.__class__.__name__
        _mes += "("
        for name, (path, _type) in self.debug_info.items():
            _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
        _mes += f"\n preprocess: {self.preprocess})"
        return _mes

    def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
        """Load, preprocess and dtype-cast all arrays for one utterance.

        Args:
            uid: Utterance id, or an integer index into the first loader's keys.

        Returns:
            Tuple of (string uid, {data-key name: np.ndarray}).
        """
        assert check_argument_types()

        # Change integer-id to string-id
        if isinstance(uid, int):
            d = next(iter(self.loader_dict.values()))
            uid = list(d)[uid]

        # Serve a cache hit without touching the loaders.
        if self.cache is not None and uid in self.cache:
            data = self.cache[uid]
            return uid, data

        data = {}
        # 1. Load data from each loaders
        for name, loader in self.loader_dict.items():
            try:
                value = loader[uid]
                if isinstance(value, (list, tuple)):
                    value = np.array(value)
                if not isinstance(
                    value, (np.ndarray, torch.Tensor, str, numbers.Number)
                ):
                    raise TypeError(
                        f"Must be ndarray, torch.Tensor, str or Number: {type(value)}"
                    )
            except Exception:
                # Log which file/type/utterance failed before re-raising.
                path, _type = self.debug_info[name]
                logging.error(
                    f"Error happened with path={path}, type={_type}, id={uid}"
                )
                raise

            # torch.Tensor is converted to ndarray
            if isinstance(value, torch.Tensor):
                value = value.numpy()
            elif isinstance(value, numbers.Number):
                # Scalars become 1-element ndarrays.
                value = np.array([value])
            data[name] = value

        # 2. [Option] Apply preprocessing
        # e.g. funasr.train.preprocessor:CommonPreprocessor
        if self.preprocess is not None:
            data = self.preprocess(uid, data)

        # 3. Force data-precision
        for name in data:
            value = data[name]
            if not isinstance(value, np.ndarray):
                raise RuntimeError(
                    f"All values must be converted to np.ndarray object "
                    f'by preprocessing, but "{name}" is still {type(value)}.'
                )

            # Cast to desired type
            if value.dtype.kind == "f":
                value = value.astype(self.float_dtype)
            elif value.dtype.kind == "i":
                value = value.astype(self.int_dtype)
            else:
                raise NotImplementedError(f"Not supported dtype: {value.dtype}")
            data[name] = value

        # Cache the fully-processed dict while the byte budget allows it.
        if self.cache is not None and self.cache.size < self.max_cache_size:
            self.cache[uid] = data

        retval = uid, data
        assert check_return_type(retval)
        return retval