# dataset.py
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import collections
import copy
import logging
import numbers
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Mapping
from typing import Tuple
from typing import Union

import kaldiio
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.fileio.npy_scp import NpyScpReader
from funasr.fileio.sound_scp import SoundScpReader
  20. class AdapterForSoundScpReader(collections.abc.Mapping):
  21. def __init__(self, loader, dtype=None):
  22. assert check_argument_types()
  23. self.loader = loader
  24. self.dtype = dtype
  25. self.rate = None
  26. def keys(self):
  27. return self.loader.keys()
  28. def __len__(self):
  29. return len(self.loader)
  30. def __iter__(self):
  31. return iter(self.loader)
  32. def __getitem__(self, key: str) -> np.ndarray:
  33. retval = self.loader[key]
  34. if isinstance(retval, tuple):
  35. assert len(retval) == 2, len(retval)
  36. if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
  37. # sound scp case
  38. rate, array = retval
  39. elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
  40. # Extended ark format case
  41. array, rate = retval
  42. else:
  43. raise RuntimeError(
  44. f"Unexpected type: {type(retval[0])}, {type(retval[1])}"
  45. )
  46. if self.rate is not None and self.rate != rate:
  47. raise RuntimeError(
  48. f"Sampling rates are mismatched: {self.rate} != {rate}"
  49. )
  50. self.rate = rate
  51. # Multichannel wave fie
  52. # array: (NSample, Channel) or (Nsample)
  53. if self.dtype is not None:
  54. array = array.astype(self.dtype)
  55. else:
  56. # Normal ark case
  57. assert isinstance(retval, np.ndarray), type(retval)
  58. array = retval
  59. if self.dtype is not None:
  60. array = array.astype(self.dtype)
  61. assert isinstance(array, np.ndarray), type(array)
  62. return array
  63. def sound_loader(path, dest_sample_rate=16000, float_dtype=None):
  64. # The file is as follows:
  65. # utterance_id_A /some/where/a.wav
  66. # utterance_id_B /some/where/a.flac
  67. # NOTE(kamo): SoundScpReader doesn't support pipe-fashion
  68. # like Kaldi e.g. "cat a.wav |".
  69. # NOTE(kamo): The audio signal is normalized to [-1,1] range.
  70. loader = SoundScpReader(path, dest_sample_rate, normalize=True, always_2d=False)
  71. # SoundScpReader.__getitem__() returns Tuple[int, ndarray],
  72. # but ndarray is desired, so Adapter class is inserted here
  73. return AdapterForSoundScpReader(loader, float_dtype)
  74. def kaldi_loader(path, float_dtype=None, max_cache_fd: int = 0):
  75. loader = kaldiio.load_scp(path, max_cache_fd=max_cache_fd)
  76. return AdapterForSoundScpReader(loader, float_dtype)
  77. class ESPnetDataset(Dataset):
  78. """
  79. Pytorch Dataset class for FunASR, modified from ESPnet
  80. """
  81. def __init__(
  82. self,
  83. path_name_type_list: Collection[Tuple[str, str, str]],
  84. preprocess: Callable[
  85. [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
  86. ] = None,
  87. float_dtype: str = "float32",
  88. int_dtype: str = "long",
  89. dest_sample_rate: int = 16000,
  90. speed_perturb: Union[list, tuple] = None,
  91. mode: str = "train",
  92. ):
  93. assert check_argument_types()
  94. if len(path_name_type_list) == 0:
  95. raise ValueError(
  96. '1 or more elements are required for "path_name_type_list"'
  97. )
  98. path_name_type_list = copy.deepcopy(path_name_type_list)
  99. self.preprocess = preprocess
  100. self.float_dtype = float_dtype
  101. self.int_dtype = int_dtype
  102. self.dest_sample_rate = dest_sample_rate
  103. self.speed_perturb = speed_perturb
  104. self.mode = mode
  105. if self.speed_perturb is not None:
  106. logging.info("Using speed_perturb: {}".format(speed_perturb))
  107. self.loader_dict = {}
  108. self.debug_info = {}
  109. for path, name, _type in path_name_type_list:
  110. if name in self.loader_dict:
  111. raise RuntimeError(f'"{name}" is duplicated for data-key')
  112. loader = self._build_loader(path, _type)
  113. self.loader_dict[name] = loader
  114. self.debug_info[name] = path, _type
  115. if len(self.loader_dict[name]) == 0:
  116. raise RuntimeError(f"{path} has no samples")
  117. def _build_loader(
  118. self, path: str, loader_type: str
  119. ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, List[int], numbers.Number]]:
  120. """Helper function to instantiate Loader.
  121. Args:
  122. path: The file path
  123. loader_type: loader_type. sound, npy, text, etc
  124. """
  125. if loader_type == "sound":
  126. speed_perturb = self.speed_perturb if self.mode == "train" else None
  127. loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False,
  128. speed_perturb=speed_perturb)
  129. return AdapterForSoundScpReader(loader, self.float_dtype)
  130. elif loader_type == "kaldi_ark":
  131. loader = kaldiio.load_scp(path)
  132. return AdapterForSoundScpReader(loader, self.float_dtype)
  133. elif loader_type == "npy":
  134. return NpyScpReader(path)
  135. elif loader_type == "text":
  136. text_loader = {}
  137. with open(path, "r", encoding="utf-8") as f:
  138. for linenum, line in enumerate(f, 1):
  139. sps = line.rstrip().split(maxsplit=1)
  140. if len(sps) == 1:
  141. k, v = sps[0], ""
  142. else:
  143. k, v = sps
  144. if k in text_loader:
  145. raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
  146. text_loader[k] = v
  147. return text_loader
  148. elif loader_type == "text_int":
  149. text_int_loader = {}
  150. with open(path, "r", encoding="utf-8") as f:
  151. for linenum, line in enumerate(f, 1):
  152. sps = line.rstrip().split(maxsplit=1)
  153. if len(sps) == 1:
  154. k, v = sps[0], ""
  155. else:
  156. k, v = sps
  157. if k in text_int_loader:
  158. raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
  159. text_int_loader[k] = [int(i) for i in v.split()]
  160. return text_int_loader
  161. else:
  162. raise RuntimeError(f"Not supported: loader_type={loader_type}")
  163. def has_name(self, name) -> bool:
  164. return name in self.loader_dict
  165. def names(self) -> Tuple[str, ...]:
  166. return tuple(self.loader_dict)
  167. def __iter__(self):
  168. return iter(next(iter(self.loader_dict.values())))
  169. def __repr__(self):
  170. _mes = self.__class__.__name__
  171. _mes += "("
  172. for name, (path, _type) in self.debug_info.items():
  173. _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
  174. _mes += f"\n preprocess: {self.preprocess})"
  175. return _mes
  176. def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
  177. assert check_argument_types()
  178. # Change integer-id to string-id
  179. if isinstance(uid, int):
  180. d = next(iter(self.loader_dict.values()))
  181. uid = list(d)[uid]
  182. data = {}
  183. # 1. Load data from each loaders
  184. for name, loader in self.loader_dict.items():
  185. try:
  186. value = loader[uid]
  187. if isinstance(value, (list, tuple)):
  188. value = np.array(value)
  189. if not isinstance(
  190. value, (np.ndarray, torch.Tensor, str, numbers.Number)
  191. ):
  192. raise TypeError(
  193. f"Must be ndarray, torch.Tensor, str or Number: {type(value)}"
  194. )
  195. except Exception:
  196. path, _type = self.debug_info[name]
  197. logging.error(
  198. f"Error happened with path={path}, type={_type}, id={uid}"
  199. )
  200. raise
  201. # torch.Tensor is converted to ndarray
  202. if isinstance(value, torch.Tensor):
  203. value = value.numpy()
  204. elif isinstance(value, numbers.Number):
  205. value = np.array([value])
  206. data[name] = value
  207. # 2. [Option] Apply preprocessing
  208. # e.g. funasr.train.preprocessor:CommonPreprocessor
  209. if self.preprocess is not None:
  210. data = self.preprocess(uid, data)
  211. # 3. Force data-precision
  212. for name in data:
  213. value = data[name]
  214. if not isinstance(value, np.ndarray):
  215. raise RuntimeError(
  216. f"All values must be converted to np.ndarray object "
  217. f'by preprocessing, but "{name}" is still {type(value)}.'
  218. )
  219. # Cast to desired type
  220. if value.dtype.kind == "f":
  221. value = value.astype(self.float_dtype)
  222. elif value.dtype.kind == "i":
  223. value = value.astype(self.int_dtype)
  224. else:
  225. raise NotImplementedError(f"Not supported dtype: {value.dtype}")
  226. data[name] = value
  227. retval = uid, data
  228. assert check_return_type(retval)
  229. return retval