iterable_dataset_modelscope.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. # Part of the implementation is borrowed from espnet/espnet.
  3. """Iterable dataset module."""
import copy
from io import StringIO
from pathlib import Path
from typing import Any, Callable, Collection, Dict, Iterator, Tuple, Union

import kaldiio
import numpy as np
import soundfile
import torch
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types

from funasr.datasets.dataset import ESPnetDataset
from funasr.utils import wav_utils
  16. def load_kaldi(input):
  17. retval = kaldiio.load_mat(input)
  18. if isinstance(retval, tuple):
  19. assert len(retval) == 2, len(retval)
  20. if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
  21. # sound scp case
  22. rate, array = retval
  23. elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
  24. # Extended ark format case
  25. array, rate = retval
  26. else:
  27. raise RuntimeError(
  28. f'Unexpected type: {type(retval[0])}, {type(retval[1])}')
  29. # Multichannel wave fie
  30. # array: (NSample, Channel) or (Nsample)
  31. else:
  32. # Normal ark case
  33. assert isinstance(retval, np.ndarray), type(retval)
  34. array = retval
  35. return array
  36. DATA_TYPES = {
  37. 'sound':
  38. lambda x: soundfile.read(x)[0],
  39. 'kaldi_ark':
  40. load_kaldi,
  41. 'npy':
  42. np.load,
  43. 'text_int':
  44. lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=' '),
  45. 'csv_int':
  46. lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=','),
  47. 'text_float':
  48. lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=' '
  49. ),
  50. 'csv_float':
  51. lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=','
  52. ),
  53. 'text':
  54. lambda x: x,
  55. }
  56. class IterableESPnetDatasetModelScope(IterableDataset):
  57. """Pytorch Dataset class for ESPNet.
  58. Examples:
  59. >>> dataset = IterableESPnetDataset([('wav.scp', 'input', 'sound'),
  60. ... ('token_int', 'output', 'text_int')],
  61. ... )
  62. >>> for uid, data in dataset:
  63. ... data
  64. {'input': per_utt_array, 'output': per_utt_array}
  65. """
  66. def __init__(self,
  67. path_name_type_list: Collection[Tuple[any, str, str]],
  68. preprocess: Callable[[str, Dict[str, np.ndarray]],
  69. Dict[str, np.ndarray]] = None,
  70. float_dtype: str = 'float32',
  71. int_dtype: str = 'long',
  72. key_file: str = None,
  73. sample_rate: Union[dict, int] = 16000):
  74. assert check_argument_types()
  75. if len(path_name_type_list) == 0:
  76. raise ValueError(
  77. '1 or more elements are required for "path_name_type_list"')
  78. self.preprocess = preprocess
  79. self.float_dtype = float_dtype
  80. self.int_dtype = int_dtype
  81. self.key_file = key_file
  82. self.sample_rate = sample_rate
  83. self.debug_info = {}
  84. non_iterable_list = []
  85. self.path_name_type_list = []
  86. path_list = path_name_type_list[0]
  87. name = path_name_type_list[1]
  88. _type = path_name_type_list[2]
  89. if name in self.debug_info:
  90. raise RuntimeError(f'"{name}" is duplicated for data-key')
  91. self.debug_info[name] = path_list, _type
  92. # for path, name, _type in path_name_type_list:
  93. for path in path_list:
  94. self.path_name_type_list.append((path, name, _type))
  95. if len(non_iterable_list) != 0:
  96. # Some types doesn't support iterable mode
  97. self.non_iterable_dataset = ESPnetDataset(
  98. path_name_type_list=non_iterable_list,
  99. preprocess=preprocess,
  100. float_dtype=float_dtype,
  101. int_dtype=int_dtype,
  102. )
  103. else:
  104. self.non_iterable_dataset = None
  105. self.apply_utt2category = False
  106. def has_name(self, name) -> bool:
  107. return name in self.debug_info
  108. def names(self) -> Tuple[str, ...]:
  109. return tuple(self.debug_info)
  110. def __repr__(self):
  111. _mes = self.__class__.__name__
  112. _mes += '('
  113. for name, (path, _type) in self.debug_info.items():
  114. _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
  115. _mes += f'\n preprocess: {self.preprocess})'
  116. return _mes
  117. def __iter__(
  118. self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
  119. torch.set_printoptions(profile='default')
  120. count = len(self.path_name_type_list)
  121. for idx in range(count):
  122. # 2. Load the entry from each line and create a dict
  123. data = {}
  124. # 2.a. Load data streamingly
  125. # value: /home/fsc/code/MaaS/MaaS-lib-nls-asr/data/test/audios/asr_example.wav
  126. value = self.path_name_type_list[idx][0]['file']
  127. uid = self.path_name_type_list[idx][0]['key']
  128. # name: speech
  129. name = self.path_name_type_list[idx][1]
  130. _type = self.path_name_type_list[idx][2]
  131. func = DATA_TYPES[_type]
  132. array = func(value)
  133. # 2.b. audio resample
  134. if _type == 'sound':
  135. audio_sr: int = 16000
  136. model_sr: int = 16000
  137. if isinstance(self.sample_rate, int):
  138. model_sr = self.sample_rate
  139. else:
  140. if 'audio_sr' in self.sample_rate:
  141. audio_sr = self.sample_rate['audio_sr']
  142. if 'model_sr' in self.sample_rate:
  143. model_sr = self.sample_rate['model_sr']
  144. array = wav_utils.torch_resample(array, audio_sr, model_sr)
  145. # array: [ 1.25122070e-03 ... ]
  146. data[name] = array
  147. # 3. [Option] Apply preprocessing
  148. # e.g. espnet2.train.preprocessor:CommonPreprocessor
  149. if self.preprocess is not None:
  150. data = self.preprocess(uid, data)
  151. # data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
  152. # 4. Force data-precision
  153. for name in data:
  154. # value is np.ndarray data
  155. value = data[name]
  156. if not isinstance(value, np.ndarray):
  157. raise RuntimeError(
  158. f'All values must be converted to np.ndarray object '
  159. f'by preprocessing, but "{name}" is still {type(value)}.'
  160. )
  161. # Cast to desired type
  162. if value.dtype.kind == 'f':
  163. value = value.astype(self.float_dtype)
  164. elif value.dtype.kind == 'i':
  165. value = value.astype(self.int_dtype)
  166. else:
  167. raise NotImplementedError(
  168. f'Not supported dtype: {value.dtype}')
  169. data[name] = value
  170. yield uid, data
  171. if count == 0:
  172. raise RuntimeError('No iteration')
  173. class IterableESPnetBytesModelScope(IterableDataset):
  174. """Pytorch audio bytes class for ESPNet.
  175. Examples:
  176. >>> dataset = IterableESPnetBytes([('audio bytes', 'input', 'sound'),
  177. ... ('token_int', 'output', 'text_int')],
  178. ... )
  179. >>> for uid, data in dataset:
  180. ... data
  181. {'input': per_utt_array, 'output': per_utt_array}
  182. """
  183. def __init__(self,
  184. path_name_type_list: Collection[Tuple[any, str, str]],
  185. preprocess: Callable[[str, Dict[str, np.ndarray]],
  186. Dict[str, np.ndarray]] = None,
  187. float_dtype: str = 'float32',
  188. int_dtype: str = 'long',
  189. key_file: str = None,
  190. sample_rate: Union[dict, int] = 16000):
  191. assert check_argument_types()
  192. if len(path_name_type_list) == 0:
  193. raise ValueError(
  194. '1 or more elements are required for "path_name_type_list"')
  195. self.preprocess = preprocess
  196. self.float_dtype = float_dtype
  197. self.int_dtype = int_dtype
  198. self.key_file = key_file
  199. self.sample_rate = sample_rate
  200. self.debug_info = {}
  201. non_iterable_list = []
  202. self.path_name_type_list = []
  203. audio_data = path_name_type_list[0]
  204. name = path_name_type_list[1]
  205. _type = path_name_type_list[2]
  206. if name in self.debug_info:
  207. raise RuntimeError(f'"{name}" is duplicated for data-key')
  208. self.debug_info[name] = audio_data, _type
  209. self.path_name_type_list.append((audio_data, name, _type))
  210. if len(non_iterable_list) != 0:
  211. # Some types doesn't support iterable mode
  212. self.non_iterable_dataset = ESPnetDataset(
  213. path_name_type_list=non_iterable_list,
  214. preprocess=preprocess,
  215. float_dtype=float_dtype,
  216. int_dtype=int_dtype,
  217. )
  218. else:
  219. self.non_iterable_dataset = None
  220. self.apply_utt2category = False
  221. if float_dtype == 'float32':
  222. self.np_dtype = np.float32
  223. def has_name(self, name) -> bool:
  224. return name in self.debug_info
  225. def names(self) -> Tuple[str, ...]:
  226. return tuple(self.debug_info)
  227. def __repr__(self):
  228. _mes = self.__class__.__name__
  229. _mes += '('
  230. for name, (path, _type) in self.debug_info.items():
  231. _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
  232. _mes += f'\n preprocess: {self.preprocess})'
  233. return _mes
  234. def __iter__(
  235. self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
  236. torch.set_printoptions(profile='default')
  237. # 2. Load the entry from each line and create a dict
  238. data = {}
  239. # 2.a. Load data streamingly
  240. value = self.path_name_type_list[0][0]
  241. uid = 'pcm_data'
  242. # name: speech
  243. name = self.path_name_type_list[0][1]
  244. _type = self.path_name_type_list[0][2]
  245. func = DATA_TYPES[_type]
  246. # array: [ 1.25122070e-03 ... ]
  247. # data[name] = np.frombuffer(value, dtype=self.np_dtype)
  248. # 2.b. byte(PCM16) to float32
  249. middle_data = np.frombuffer(value, dtype=np.int16)
  250. middle_data = np.asarray(middle_data)
  251. if middle_data.dtype.kind not in 'iu':
  252. raise TypeError("'middle_data' must be an array of integers")
  253. dtype = np.dtype('float32')
  254. if dtype.kind != 'f':
  255. raise TypeError("'dtype' must be a floating point type")
  256. i = np.iinfo(middle_data.dtype)
  257. abs_max = 2**(i.bits - 1)
  258. offset = i.min + abs_max
  259. array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max,
  260. dtype=self.np_dtype)
  261. # 2.c. audio resample
  262. if _type == 'sound':
  263. audio_sr: int = 16000
  264. model_sr: int = 16000
  265. if isinstance(self.sample_rate, int):
  266. model_sr = self.sample_rate
  267. else:
  268. if 'audio_sr' in self.sample_rate:
  269. audio_sr = self.sample_rate['audio_sr']
  270. if 'model_sr' in self.sample_rate:
  271. model_sr = self.sample_rate['model_sr']
  272. array = wav_utils.torch_resample(array, audio_sr, model_sr)
  273. data[name] = array
  274. # 3. [Option] Apply preprocessing
  275. # e.g. espnet2.train.preprocessor:CommonPreprocessor
  276. if self.preprocess is not None:
  277. data = self.preprocess(uid, data)
  278. # data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
  279. # 4. Force data-precision
  280. for name in data:
  281. # value is np.ndarray data
  282. value = data[name]
  283. if not isinstance(value, np.ndarray):
  284. raise RuntimeError(
  285. f'All values must be converted to np.ndarray object '
  286. f'by preprocessing, but "{name}" is still {type(value)}.')
  287. # Cast to desired type
  288. if value.dtype.kind == 'f':
  289. value = value.astype(self.float_dtype)
  290. elif value.dtype.kind == 'i':
  291. value = value.astype(self.int_dtype)
  292. else:
  293. raise NotImplementedError(
  294. f'Not supported dtype: {value.dtype}')
  295. data[name] = value
  296. yield uid, data