# iterable_dataset.py
  1. """Iterable dataset module."""
  2. import copy
  3. from io import StringIO
  4. from pathlib import Path
  5. from typing import Callable
  6. from typing import Collection
  7. from typing import Dict
  8. from typing import Iterator
  9. from typing import Tuple
  10. from typing import Union
  11. import kaldiio
  12. import numpy as np
  13. import soundfile
  14. import torch
  15. from torch.utils.data.dataset import IterableDataset
  16. from typeguard import check_argument_types
  17. from funasr.datasets.dataset import ESPnetDataset
  18. def load_kaldi(input):
  19. retval = kaldiio.load_mat(input)
  20. if isinstance(retval, tuple):
  21. assert len(retval) == 2, len(retval)
  22. if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
  23. # sound scp case
  24. rate, array = retval
  25. elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
  26. # Extended ark format case
  27. array, rate = retval
  28. else:
  29. raise RuntimeError(f"Unexpected type: {type(retval[0])}, {type(retval[1])}")
  30. # Multichannel wave fie
  31. # array: (NSample, Channel) or (Nsample)
  32. else:
  33. # Normal ark case
  34. assert isinstance(retval, np.ndarray), type(retval)
  35. array = retval
  36. return array
  37. DATA_TYPES = {
  38. "sound": lambda x: soundfile.read(x)[0],
  39. "kaldi_ark": load_kaldi,
  40. "npy": np.load,
  41. "text_int": lambda x: np.loadtxt(
  42. StringIO(x), ndmin=1, dtype=np.long, delimiter=" "
  43. ),
  44. "csv_int": lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=","),
  45. "text_float": lambda x: np.loadtxt(
  46. StringIO(x), ndmin=1, dtype=np.float32, delimiter=" "
  47. ),
  48. "csv_float": lambda x: np.loadtxt(
  49. StringIO(x), ndmin=1, dtype=np.float32, delimiter=","
  50. ),
  51. "text": lambda x: x,
  52. }
  53. class IterableESPnetDataset(IterableDataset):
  54. """Pytorch Dataset class for ESPNet.
  55. Examples:
  56. >>> dataset = IterableESPnetDataset([('wav.scp', 'input', 'sound'),
  57. ... ('token_int', 'output', 'text_int')],
  58. ... )
  59. >>> for uid, data in dataset:
  60. ... data
  61. {'input': per_utt_array, 'output': per_utt_array}
  62. """
  63. def __init__(
  64. self,
  65. path_name_type_list: Collection[Tuple[str, str, str]],
  66. preprocess: Callable[
  67. [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
  68. ] = None,
  69. float_dtype: str = "float32",
  70. int_dtype: str = "long",
  71. key_file: str = None,
  72. ):
  73. assert check_argument_types()
  74. if len(path_name_type_list) == 0:
  75. raise ValueError(
  76. '1 or more elements are required for "path_name_type_list"'
  77. )
  78. path_name_type_list = copy.deepcopy(path_name_type_list)
  79. self.preprocess = preprocess
  80. self.float_dtype = float_dtype
  81. self.int_dtype = int_dtype
  82. self.key_file = key_file
  83. self.debug_info = {}
  84. non_iterable_list = []
  85. self.path_name_type_list = []
  86. for path, name, _type in path_name_type_list:
  87. if name in self.debug_info:
  88. raise RuntimeError(f'"{name}" is duplicated for data-key')
  89. self.debug_info[name] = path, _type
  90. if _type not in DATA_TYPES:
  91. non_iterable_list.append((path, name, _type))
  92. else:
  93. self.path_name_type_list.append((path, name, _type))
  94. if len(non_iterable_list) != 0:
  95. # Some types doesn't support iterable mode
  96. self.non_iterable_dataset = ESPnetDataset(
  97. path_name_type_list=non_iterable_list,
  98. preprocess=preprocess,
  99. float_dtype=float_dtype,
  100. int_dtype=int_dtype,
  101. )
  102. else:
  103. self.non_iterable_dataset = None
  104. if Path(Path(path_name_type_list[0][0]).parent, "utt2category").exists():
  105. self.apply_utt2category = True
  106. else:
  107. self.apply_utt2category = False
  108. def has_name(self, name) -> bool:
  109. return name in self.debug_info
  110. def names(self) -> Tuple[str, ...]:
  111. return tuple(self.debug_info)
  112. def __repr__(self):
  113. _mes = self.__class__.__name__
  114. _mes += "("
  115. for name, (path, _type) in self.debug_info.items():
  116. _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
  117. _mes += f"\n preprocess: {self.preprocess})"
  118. return _mes
  119. def __iter__(self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
  120. if self.key_file is not None:
  121. uid_iter = (
  122. line.rstrip().split(maxsplit=1)[0]
  123. for line in open(self.key_file, encoding="utf-8")
  124. )
  125. elif len(self.path_name_type_list) != 0:
  126. uid_iter = (
  127. line.rstrip().split(maxsplit=1)[0]
  128. for line in open(self.path_name_type_list[0][0], encoding="utf-8")
  129. )
  130. else:
  131. uid_iter = iter(self.non_iterable_dataset)
  132. files = [open(lis[0], encoding="utf-8") for lis in self.path_name_type_list]
  133. worker_info = torch.utils.data.get_worker_info()
  134. linenum = 0
  135. count = 0
  136. for count, uid in enumerate(uid_iter, 1):
  137. # If num_workers>=1, split keys
  138. if worker_info is not None:
  139. if (count - 1) % worker_info.num_workers != worker_info.id:
  140. continue
  141. # 1. Read a line from each file
  142. while True:
  143. keys = []
  144. values = []
  145. for f in files:
  146. linenum += 1
  147. try:
  148. line = next(f)
  149. except StopIteration:
  150. raise RuntimeError(f"{uid} is not found in the files")
  151. sps = line.rstrip().split(maxsplit=1)
  152. if len(sps) != 2:
  153. raise RuntimeError(
  154. f"This line doesn't include a space:"
  155. f" {f}:L{linenum}: {line})"
  156. )
  157. key, value = sps
  158. keys.append(key)
  159. values.append(value)
  160. for k_idx, k in enumerate(keys):
  161. if k != keys[0]:
  162. raise RuntimeError(
  163. f"Keys are mismatched. Text files (idx={k_idx}) is "
  164. f"not sorted or not having same keys at L{linenum}"
  165. )
  166. # If the key is matched, break the loop
  167. if len(keys) == 0 or keys[0] == uid:
  168. break
  169. # 2. Load the entry from each line and create a dict
  170. data = {}
  171. # 2.a. Load data streamingly
  172. for value, (path, name, _type) in zip(values, self.path_name_type_list):
  173. func = DATA_TYPES[_type]
  174. # Load entry
  175. array = func(value)
  176. data[name] = array
  177. if self.non_iterable_dataset is not None:
  178. # 2.b. Load data from non-iterable dataset
  179. _, from_non_iterable = self.non_iterable_dataset[uid]
  180. data.update(from_non_iterable)
  181. # 3. [Option] Apply preprocessing
  182. # e.g. funasr.train.preprocessor:CommonPreprocessor
  183. if self.preprocess is not None:
  184. data = self.preprocess(uid, data)
  185. # 4. Force data-precision
  186. for name in data:
  187. value = data[name]
  188. if not isinstance(value, np.ndarray):
  189. raise RuntimeError(
  190. f"All values must be converted to np.ndarray object "
  191. f'by preprocessing, but "{name}" is still {type(value)}.'
  192. )
  193. # Cast to desired type
  194. if value.dtype.kind == "f":
  195. value = value.astype(self.float_dtype)
  196. elif value.dtype.kind == "i":
  197. value = value.astype(self.int_dtype)
  198. else:
  199. raise NotImplementedError(f"Not supported dtype: {value.dtype}")
  200. data[name] = value
  201. yield uid, data
  202. if count == 0:
  203. raise RuntimeError("No iteration")