dataset.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442
  1. # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. from abc import ABC
  4. from abc import abstractmethod
  5. import collections
  6. import copy
  7. import functools
  8. import logging
  9. import numbers
  10. import re
  11. from typing import Any
  12. from typing import Callable
  13. from typing import Collection
  14. from typing import Dict
  15. from typing import Mapping
  16. from typing import Tuple
  17. from typing import Union
  18. import h5py
  19. import humanfriendly
  20. import kaldiio
  21. import numpy as np
  22. import torch
  23. from torch.utils.data.dataset import Dataset
  24. from funasr.fileio.npy_scp import NpyScpReader
  25. from funasr.fileio.rand_gen_dataset import FloatRandomGenerateDataset
  26. from funasr.fileio.rand_gen_dataset import IntRandomGenerateDataset
  27. from funasr.fileio.read_text import load_num_sequence_text
  28. from funasr.fileio.read_text import read_2column_text
  29. from funasr.fileio.sound_scp import SoundScpReader
  30. from funasr.utils.sized_dict import SizedDict
  31. class AdapterForSoundScpReader(collections.abc.Mapping):
  32. def __init__(self, loader, dtype=None):
  33. self.loader = loader
  34. self.dtype = dtype
  35. self.rate = None
  36. def keys(self):
  37. return self.loader.keys()
  38. def __len__(self):
  39. return len(self.loader)
  40. def __iter__(self):
  41. return iter(self.loader)
  42. def __getitem__(self, key: str) -> np.ndarray:
  43. retval = self.loader[key]
  44. if isinstance(retval, tuple):
  45. assert len(retval) == 2, len(retval)
  46. if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
  47. # sound scp case
  48. rate, array = retval
  49. elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
  50. # Extended ark format case
  51. array, rate = retval
  52. else:
  53. raise RuntimeError(
  54. f"Unexpected type: {type(retval[0])}, {type(retval[1])}"
  55. )
  56. if self.rate is not None and self.rate != rate:
  57. raise RuntimeError(
  58. f"Sampling rates are mismatched: {self.rate} != {rate}"
  59. )
  60. self.rate = rate
  61. # Multichannel wave fie
  62. # array: (NSample, Channel) or (Nsample)
  63. if self.dtype is not None:
  64. array = array.astype(self.dtype)
  65. else:
  66. # Normal ark case
  67. assert isinstance(retval, np.ndarray), type(retval)
  68. array = retval
  69. if self.dtype is not None:
  70. array = array.astype(self.dtype)
  71. assert isinstance(array, np.ndarray), type(array)
  72. return array
  73. class H5FileWrapper:
  74. def __init__(self, path: str):
  75. self.path = path
  76. self.h5_file = h5py.File(path, "r")
  77. def __repr__(self) -> str:
  78. return str(self.h5_file)
  79. def __len__(self) -> int:
  80. return len(self.h5_file)
  81. def __iter__(self):
  82. return iter(self.h5_file)
  83. def __getitem__(self, key) -> np.ndarray:
  84. value = self.h5_file[key]
  85. return value[()]
  86. def sound_loader(path, dest_sample_rate=16000, float_dtype=None):
  87. # The file is as follows:
  88. # utterance_id_A /some/where/a.wav
  89. # utterance_id_B /some/where/a.flac
  90. # NOTE(kamo): SoundScpReader doesn't support pipe-fashion
  91. # like Kaldi e.g. "cat a.wav |".
  92. # NOTE(kamo): The audio signal is normalized to [-1,1] range.
  93. loader = SoundScpReader(path, normalize=True, always_2d=False, dest_sample_rate = dest_sample_rate)
  94. # SoundScpReader.__getitem__() returns Tuple[int, ndarray],
  95. # but ndarray is desired, so Adapter class is inserted here
  96. return AdapterForSoundScpReader(loader, float_dtype)
  97. def kaldi_loader(path, float_dtype=None, max_cache_fd: int = 0):
  98. loader = kaldiio.load_scp(path, max_cache_fd=max_cache_fd)
  99. return AdapterForSoundScpReader(loader, float_dtype)
  100. def rand_int_loader(filepath, loader_type):
  101. # e.g. rand_int_3_10
  102. try:
  103. low, high = map(int, loader_type[len("rand_int_") :].split("_"))
  104. except ValueError:
  105. raise RuntimeError(f"e.g rand_int_3_10: but got {loader_type}")
  106. return IntRandomGenerateDataset(filepath, low, high)
  107. DATA_TYPES = {
  108. "sound": dict(
  109. func=sound_loader,
  110. kwargs=["dest_sample_rate","float_dtype"],
  111. help="Audio format types which supported by sndfile wav, flac, etc."
  112. "\n\n"
  113. " utterance_id_a a.wav\n"
  114. " utterance_id_b b.wav\n"
  115. " ...",
  116. ),
  117. "kaldi_ark": dict(
  118. func=kaldi_loader,
  119. kwargs=["max_cache_fd"],
  120. help="Kaldi-ark file type."
  121. "\n\n"
  122. " utterance_id_A /some/where/a.ark:123\n"
  123. " utterance_id_B /some/where/a.ark:456\n"
  124. " ...",
  125. ),
  126. "npy": dict(
  127. func=NpyScpReader,
  128. kwargs=[],
  129. help="Npy file format."
  130. "\n\n"
  131. " utterance_id_A /some/where/a.npy\n"
  132. " utterance_id_B /some/where/b.npy\n"
  133. " ...",
  134. ),
  135. "text_int": dict(
  136. func=functools.partial(load_num_sequence_text, loader_type="text_int"),
  137. kwargs=[],
  138. help="A text file in which is written a sequence of interger numbers "
  139. "separated by space."
  140. "\n\n"
  141. " utterance_id_A 12 0 1 3\n"
  142. " utterance_id_B 3 3 1\n"
  143. " ...",
  144. ),
  145. "csv_int": dict(
  146. func=functools.partial(load_num_sequence_text, loader_type="csv_int"),
  147. kwargs=[],
  148. help="A text file in which is written a sequence of interger numbers "
  149. "separated by comma."
  150. "\n\n"
  151. " utterance_id_A 100,80\n"
  152. " utterance_id_B 143,80\n"
  153. " ...",
  154. ),
  155. "text_float": dict(
  156. func=functools.partial(load_num_sequence_text, loader_type="text_float"),
  157. kwargs=[],
  158. help="A text file in which is written a sequence of float numbers "
  159. "separated by space."
  160. "\n\n"
  161. " utterance_id_A 12. 3.1 3.4 4.4\n"
  162. " utterance_id_B 3. 3.12 1.1\n"
  163. " ...",
  164. ),
  165. "csv_float": dict(
  166. func=functools.partial(load_num_sequence_text, loader_type="csv_float"),
  167. kwargs=[],
  168. help="A text file in which is written a sequence of float numbers "
  169. "separated by comma."
  170. "\n\n"
  171. " utterance_id_A 12.,3.1,3.4,4.4\n"
  172. " utterance_id_B 3.,3.12,1.1\n"
  173. " ...",
  174. ),
  175. "text": dict(
  176. func=read_2column_text,
  177. kwargs=[],
  178. help="Return text as is. The text must be converted to ndarray "
  179. "by 'preprocess'."
  180. "\n\n"
  181. " utterance_id_A hello world\n"
  182. " utterance_id_B foo bar\n"
  183. " ...",
  184. ),
  185. "hdf5": dict(
  186. func=H5FileWrapper,
  187. kwargs=[],
  188. help="A HDF5 file which contains arrays at the first level or the second level."
  189. " >>> f = h5py.File('file.h5')\n"
  190. " >>> array1 = f['utterance_id_A']\n"
  191. " >>> array2 = f['utterance_id_B']\n",
  192. ),
  193. "rand_float": dict(
  194. func=FloatRandomGenerateDataset,
  195. kwargs=[],
  196. help="Generate random float-ndarray which has the given shapes "
  197. "in the file."
  198. "\n\n"
  199. " utterance_id_A 3,4\n"
  200. " utterance_id_B 10,4\n"
  201. " ...",
  202. ),
  203. "rand_int_\\d+_\\d+": dict(
  204. func=rand_int_loader,
  205. kwargs=["loader_type"],
  206. help="e.g. 'rand_int_0_10'. Generate random int-ndarray which has the given "
  207. "shapes in the path. "
  208. "Give the lower and upper value by the file type. e.g. "
  209. "rand_int_0_10 -> Generate integers from 0 to 10."
  210. "\n\n"
  211. " utterance_id_A 3,4\n"
  212. " utterance_id_B 10,4\n"
  213. " ...",
  214. ),
  215. }
  216. class AbsDataset(Dataset, ABC):
  217. @abstractmethod
  218. def has_name(self, name) -> bool:
  219. raise NotImplementedError
  220. @abstractmethod
  221. def names(self) -> Tuple[str, ...]:
  222. raise NotImplementedError
  223. @abstractmethod
  224. def __getitem__(self, uid) -> Tuple[Any, Dict[str, np.ndarray]]:
  225. raise NotImplementedError
class ESPnetDataset(AbsDataset):
    """Pytorch Dataset class for ESPNet.

    Maps an utterance id to a dict of per-utterance arrays, one entry per
    registered (path, name, type) triple.

    Examples:
        >>> dataset = ESPnetDataset([('wav.scp', 'input', 'sound'),
        ...                          ('token_int', 'output', 'text_int')],
        ...                         )
        ... uttid, data = dataset['uttid']
        {'input': per_utt_array, 'output': per_utt_array}
    """

    def __init__(
        self,
        path_name_type_list: Collection[Tuple[str, str, str]],
        preprocess: Callable[
            [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
        ] = None,
        float_dtype: str = "float32",
        int_dtype: str = "long",
        max_cache_size: Union[float, int, str] = 0.0,
        max_cache_fd: int = 0,
        dest_sample_rate: int = 16000,
    ):
        """Build one loader per (path, name, type) triple.

        Args:
            path_name_type_list: Triples of (file path, data-key name,
                loader type). Loader types are the keys of DATA_TYPES.
            preprocess: Optional callable mapping (uid, data-dict) to a
                data-dict of ndarrays.
            float_dtype: dtype that float arrays are cast to in __getitem__.
            int_dtype: dtype that integer arrays are cast to in __getitem__.
            max_cache_size: Cache budget; a string such as "1GB" is parsed
                with humanfriendly. 0 disables caching.
            max_cache_fd: Passed to the kaldi_ark loader.
            dest_sample_rate: Passed to the sound loader.

        Raises:
            ValueError: If path_name_type_list is empty.
            RuntimeError: On a duplicated data-key or an empty data file.
        """
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"'
            )
        # Defensive copy: the caller's list is not mutated or shared.
        path_name_type_list = copy.deepcopy(path_name_type_list)
        self.preprocess = preprocess
        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.max_cache_fd = max_cache_fd
        self.dest_sample_rate = dest_sample_rate
        # name -> loader mapping, and name -> (path, type) for error messages.
        self.loader_dict = {}
        self.debug_info = {}
        for path, name, _type in path_name_type_list:
            if name in self.loader_dict:
                raise RuntimeError(f'"{name}" is duplicated for data-key')
            loader = self._build_loader(path, _type)
            self.loader_dict[name] = loader
            self.debug_info[name] = path, _type
            if len(self.loader_dict[name]) == 0:
                raise RuntimeError(f"{path} has no samples")
        # TODO(kamo): Should check consistency of each utt-keys?
        if isinstance(max_cache_size, str):
            max_cache_size = humanfriendly.parse_size(max_cache_size)
        self.max_cache_size = max_cache_size
        if max_cache_size > 0:
            # Shared so that DataLoader worker processes see one cache.
            self.cache = SizedDict(shared=True)
        else:
            self.cache = None

    def _build_loader(
        self, path: str, loader_type: str
    ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
        """Helper function to instantiate Loader.

        Args:
            path: The file path
            loader_type: loader_type. sound, npy, text_int, text_float, etc

        Raises:
            RuntimeError: If no DATA_TYPES pattern matches loader_type, or
                a registered kwarg name is unknown.
        """
        for key, dic in DATA_TYPES.items():
            # e.g. loader_type="sound"
            # -> return DATA_TYPES["sound"]["func"](path)
            # NOTE: keys are regex patterns (e.g. "rand_int_\\d+_\\d+").
            if re.match(key, loader_type):
                # Collect only the kwargs this loader factory declares.
                kwargs = {}
                for key2 in dic["kwargs"]:
                    if key2 == "loader_type":
                        kwargs["loader_type"] = loader_type
                    elif key2 == "dest_sample_rate" and loader_type == "sound":
                        kwargs["dest_sample_rate"] = self.dest_sample_rate
                    elif key2 == "float_dtype":
                        kwargs["float_dtype"] = self.float_dtype
                    elif key2 == "int_dtype":
                        kwargs["int_dtype"] = self.int_dtype
                    elif key2 == "max_cache_fd":
                        kwargs["max_cache_fd"] = self.max_cache_fd
                    else:
                        raise RuntimeError(f"Not implemented keyword argument: {key2}")
                func = dic["func"]
                try:
                    return func(path, **kwargs)
                except Exception:
                    # Log which factory failed before re-raising unchanged.
                    if hasattr(func, "__name__"):
                        name = func.__name__
                    else:
                        name = str(func)
                    logging.error(f"An error happened with {name}({path})")
                    raise
        else:
            # for-else: no pattern matched loader_type.
            raise RuntimeError(f"Not supported: loader_type={loader_type}")

    def has_name(self, name) -> bool:
        """Return whether a loader is registered under ``name``."""
        return name in self.loader_dict

    def names(self) -> Tuple[str, ...]:
        """Return all registered data-key names."""
        return tuple(self.loader_dict)

    def __iter__(self):
        # Iterate the utterance ids of the first registered loader.
        return iter(next(iter(self.loader_dict.values())))

    def __repr__(self):
        _mes = self.__class__.__name__
        _mes += "("
        for name, (path, _type) in self.debug_info.items():
            _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
        _mes += f"\n preprocess: {self.preprocess})"
        return _mes

    def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
        """Return ``(uid, {name: ndarray})`` for one utterance.

        An int uid is treated as a positional index into the first
        loader's key order. Results may be served from / stored into the
        shared cache when caching is enabled.
        """
        # Change integer-id to string-id
        if isinstance(uid, int):
            d = next(iter(self.loader_dict.values()))
            uid = list(d)[uid]
        if self.cache is not None and uid in self.cache:
            data = self.cache[uid]
            return uid, data
        data = {}
        # 1. Load data from each loader
        for name, loader in self.loader_dict.items():
            try:
                value = loader[uid]
                if isinstance(value, (list, tuple)):
                    value = np.array(value)
                if not isinstance(
                    value, (np.ndarray, torch.Tensor, str, numbers.Number)
                ):
                    raise TypeError(
                        f"Must be ndarray, torch.Tensor, str or Number: {type(value)}"
                    )
            except Exception:
                # Attach file/type context to the error before re-raising.
                path, _type = self.debug_info[name]
                logging.error(
                    f"Error happened with path={path}, type={_type}, id={uid}"
                )
                raise
            # torch.Tensor is converted to ndarray
            if isinstance(value, torch.Tensor):
                value = value.numpy()
            elif isinstance(value, numbers.Number):
                value = np.array([value])
            data[name] = value
        # 2. [Option] Apply preprocessing
        # e.g. funasr.train.preprocessor:CommonPreprocessor
        if self.preprocess is not None:
            data = self.preprocess(uid, data)
        # 3. Force data-precision
        for name in data:
            value = data[name]
            if not isinstance(value, np.ndarray):
                raise RuntimeError(
                    f"All values must be converted to np.ndarray object "
                    f'by preprocessing, but "{name}" is still {type(value)}.'
                )
            # Cast to desired type
            if value.dtype.kind == "f":
                value = value.astype(self.float_dtype)
            elif value.dtype.kind == "i":
                value = value.astype(self.int_dtype)
            else:
                raise NotImplementedError(f"Not supported dtype: {value.dtype}")
            data[name] = value
        # Insert into the cache only while under the size budget.
        if self.cache is not None and self.cache.size < self.max_cache_size:
            self.cache[uid] = data
        retval = uid, data
        return retval