# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

from abc import ABC
from abc import abstractmethod
import collections
import copy
import functools
import logging
import numbers
import re
from typing import Any
from typing import Callable
from typing import Collection
from typing import Dict
from typing import Mapping
from typing import Tuple
from typing import Union

import h5py
import humanfriendly
import kaldiio
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.fileio.npy_scp import NpyScpReader
from funasr.fileio.rand_gen_dataset import FloatRandomGenerateDataset
from funasr.fileio.rand_gen_dataset import IntRandomGenerateDataset
from funasr.fileio.read_text import load_num_sequence_text
from funasr.fileio.read_text import read_2column_text
from funasr.fileio.sound_scp import SoundScpReader
from funasr.utils.sized_dict import SizedDict


class AdapterForSoundScpReader(collections.abc.Mapping):
    def __init__(self, loader, dtype=None):
        assert check_argument_types()
        self.loader = loader
        self.dtype = dtype
        self.rate = None

    def keys(self):
        return self.loader.keys()

    def __len__(self):
        return len(self.loader)

    def __iter__(self):
        return iter(self.loader)

    def __getitem__(self, key: str) -> np.ndarray:
        retval = self.loader[key]

        if isinstance(retval, tuple):
            assert len(retval) == 2, len(retval)
            if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
                # sound scp case
                rate, array = retval
            elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
                # Extended ark format case
                array, rate = retval
            else:
                raise RuntimeError(
                    f"Unexpected type: {type(retval[0])}, {type(retval[1])}"
                )

            if self.rate is not None and self.rate != rate:
                raise RuntimeError(
                    f"Sampling rates are mismatched: {self.rate} != {rate}"
                )
            self.rate = rate
            # Multichannel wave file
            # array: (NSample, Channel) or (NSample,)
            if self.dtype is not None:
                array = array.astype(self.dtype)

        else:
            # Normal ark case
            assert isinstance(retval, np.ndarray), type(retval)
            array = retval
            if self.dtype is not None:
                array = array.astype(self.dtype)

        assert isinstance(array, np.ndarray), type(array)
        return array
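
# Illustrative sketch (not executed): wrapping a kaldiio loader so that
# __getitem__() yields a plain ndarray instead of a (rate, array) tuple.
# "wav.scp" is a hypothetical path used only for this example.
#
#   >>> loader = kaldiio.load_scp("wav.scp")
#   >>> adapter = AdapterForSoundScpReader(loader, dtype="float32")
#   >>> array = adapter["utterance_id_A"]  # ndarray; the sampling rate is kept in adapter.rate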


class H5FileWrapper:
    def __init__(self, path: str):
        self.path = path
        self.h5_file = h5py.File(path, "r")

    def __repr__(self) -> str:
        return str(self.h5_file)

    def __len__(self) -> int:
        return len(self.h5_file)

    def __iter__(self):
        return iter(self.h5_file)

    def __getitem__(self, key) -> np.ndarray:
        value = self.h5_file[key]
        return value[()]
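
# Illustrative sketch (not executed): H5FileWrapper exposes the datasets of an
# HDF5 file as a read-only mapping. "dump.h5" is a hypothetical file name.
#
#   >>> wrapper = H5FileWrapper("dump.h5")
#   >>> array = wrapper["utterance_id_A"]  # dataset materialized as an ndarray via value[()]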


def sound_loader(path, float_dtype=None):
    # The file is as follows:
    #   utterance_id_A /some/where/a.wav
    #   utterance_id_B /some/where/a.flac

    # NOTE(kamo): SoundScpReader doesn't support pipe-fashion
    # like Kaldi, e.g. "cat a.wav |".
    # NOTE(kamo): The audio signal is normalized to the [-1, 1] range.
    loader = SoundScpReader(path, normalize=True, always_2d=False)

    # SoundScpReader.__getitem__() returns Tuple[int, ndarray],
    # but only the ndarray is desired, so the adapter class is inserted here.
    return AdapterForSoundScpReader(loader, float_dtype)


def kaldi_loader(path, float_dtype=None, max_cache_fd: int = 0):
    loader = kaldiio.load_scp(path, max_cache_fd=max_cache_fd)
    return AdapterForSoundScpReader(loader, float_dtype)


def rand_int_loader(filepath, loader_type):
    # e.g. loader_type="rand_int_3_10"
    try:
        low, high = map(int, loader_type[len("rand_int_"):].split("_"))
    except ValueError:
        raise RuntimeError(f"Expected a name like 'rand_int_3_10', but got: {loader_type}")
    return IntRandomGenerateDataset(filepath, low, high)
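
# Illustrative sketch (not executed): the loader_type string itself carries the value
# bounds. "shape.txt" is a hypothetical file whose lines are "<uttid> <dim1>,<dim2>,...",
# matching the format shown in the DATA_TYPES help below.
#
#   >>> dataset = rand_int_loader("shape.txt", "rand_int_3_10")
#   >>> # IntRandomGenerateDataset with lower/upper bounds 3 and 10 and per-utterance shapes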


DATA_TYPES = {
    "sound": dict(
        func=sound_loader,
        kwargs=["float_dtype"],
        help="Audio formats supported by sndfile: wav, flac, etc."
        "\n\n"
        "   utterance_id_a a.wav\n"
        "   utterance_id_b b.wav\n"
        "   ...",
    ),
    "kaldi_ark": dict(
        func=kaldi_loader,
        kwargs=["max_cache_fd"],
        help="Kaldi-ark file type."
        "\n\n"
        "   utterance_id_A /some/where/a.ark:123\n"
        "   utterance_id_B /some/where/a.ark:456\n"
        "   ...",
    ),
    "npy": dict(
        func=NpyScpReader,
        kwargs=[],
        help="Npy file format."
        "\n\n"
        "   utterance_id_A /some/where/a.npy\n"
        "   utterance_id_B /some/where/b.npy\n"
        "   ...",
    ),
    "text_int": dict(
        func=functools.partial(load_num_sequence_text, loader_type="text_int"),
        kwargs=[],
        help="A text file in which is written a sequence of integer numbers "
        "separated by spaces."
        "\n\n"
        "   utterance_id_A 12 0 1 3\n"
        "   utterance_id_B 3 3 1\n"
        "   ...",
    ),
    "csv_int": dict(
        func=functools.partial(load_num_sequence_text, loader_type="csv_int"),
        kwargs=[],
        help="A text file in which is written a sequence of integer numbers "
        "separated by commas."
        "\n\n"
        "   utterance_id_A 100,80\n"
        "   utterance_id_B 143,80\n"
        "   ...",
    ),
    "text_float": dict(
        func=functools.partial(load_num_sequence_text, loader_type="text_float"),
        kwargs=[],
        help="A text file in which is written a sequence of float numbers "
        "separated by spaces."
        "\n\n"
        "   utterance_id_A 12. 3.1 3.4 4.4\n"
        "   utterance_id_B 3. 3.12 1.1\n"
        "   ...",
    ),
    "csv_float": dict(
        func=functools.partial(load_num_sequence_text, loader_type="csv_float"),
        kwargs=[],
        help="A text file in which is written a sequence of float numbers "
        "separated by commas."
        "\n\n"
        "   utterance_id_A 12.,3.1,3.4,4.4\n"
        "   utterance_id_B 3.,3.12,1.1\n"
        "   ...",
    ),
    "text": dict(
        func=read_2column_text,
        kwargs=[],
        help="Return text as is. The text must be converted to ndarray "
        "by 'preprocess'."
        "\n\n"
        "   utterance_id_A hello world\n"
        "   utterance_id_B foo bar\n"
        "   ...",
    ),
    "hdf5": dict(
        func=H5FileWrapper,
        kwargs=[],
        help="An HDF5 file which contains arrays at the first or second level."
        "\n\n"
        "   >>> f = h5py.File('file.h5')\n"
        "   >>> array1 = f['utterance_id_A']\n"
        "   >>> array2 = f['utterance_id_B']\n",
    ),
    "rand_float": dict(
        func=FloatRandomGenerateDataset,
        kwargs=[],
        help="Generate random float ndarrays whose shapes are given "
        "in the file."
        "\n\n"
        "   utterance_id_A 3,4\n"
        "   utterance_id_B 10,4\n"
        "   ...",
    ),
    "rand_int_\\d+_\\d+": dict(
        func=rand_int_loader,
        kwargs=["loader_type"],
        help="e.g. 'rand_int_0_10'. Generate random int ndarrays whose shapes are given "
        "in the file. "
        "The lower and upper values are taken from the type name, e.g. "
        "rand_int_0_10 -> generate integers from 0 to 10."
        "\n\n"
        "   utterance_id_A 3,4\n"
        "   utterance_id_B 10,4\n"
        "   ...",
    ),
}
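
# Note: the keys of DATA_TYPES are treated as regular expressions by
# ESPnetDataset._build_loader(), so a literal key such as "sound" matches itself,
# while "rand_int_\\d+_\\d+" matches names like "rand_int_0_10".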


class AbsDataset(Dataset, ABC):
    @abstractmethod
    def has_name(self, name) -> bool:
        raise NotImplementedError

    @abstractmethod
    def names(self) -> Tuple[str, ...]:
        raise NotImplementedError

    @abstractmethod
    def __getitem__(self, uid) -> Tuple[Any, Dict[str, np.ndarray]]:
        raise NotImplementedError


class ESPnetDataset(AbsDataset):
    """PyTorch Dataset class for ESPnet.

    Examples:
        >>> dataset = ESPnetDataset([('wav.scp', 'input', 'sound'),
        ...                          ('token_int', 'output', 'text_int')],
        ...                         )
        >>> uttid, data = dataset['uttid']
        >>> data
        {'input': per_utt_array, 'output': per_utt_array}
    """

    def __init__(
        self,
        path_name_type_list: Collection[Tuple[str, str, str]],
        preprocess: Callable[
            [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
        ] = None,
        float_dtype: str = "float32",
        int_dtype: str = "long",
        max_cache_size: Union[float, int, str] = 0.0,
        max_cache_fd: int = 0,
    ):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                'One or more elements are required for "path_name_type_list"'
            )

        path_name_type_list = copy.deepcopy(path_name_type_list)
        self.preprocess = preprocess

        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.max_cache_fd = max_cache_fd

        self.loader_dict = {}
        self.debug_info = {}
        for path, name, _type in path_name_type_list:
            if name in self.loader_dict:
                raise RuntimeError(f'"{name}" is duplicated for data-key')

            loader = self._build_loader(path, _type)
            self.loader_dict[name] = loader
            self.debug_info[name] = path, _type
            if len(self.loader_dict[name]) == 0:
                raise RuntimeError(f"{path} has no samples")

            # TODO(kamo): Should check consistency of each utt-keys?

        if isinstance(max_cache_size, str):
            max_cache_size = humanfriendly.parse_size(max_cache_size)
        self.max_cache_size = max_cache_size
        if max_cache_size > 0:
            self.cache = SizedDict(shared=True)
        else:
            self.cache = None
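
    # Note on caching: max_cache_size may be given as a human-readable string;
    # e.g. humanfriendly.parse_size("1GB") parses to a byte count (10**9 with the
    # library's default decimal units), which becomes the budget of the shared
    # SizedDict consulted in __getitem__().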

    def _build_loader(
        self, path: str, loader_type: str
    ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
        """Helper function to instantiate a Loader.

        Args:
            path: The file path
            loader_type: The loader type: sound, npy, text_int, text_float, etc.
        """
        for key, dic in DATA_TYPES.items():
            # e.g. loader_type="sound"
            # -> return DATA_TYPES["sound"]["func"](path)
            if re.match(key, loader_type):
                kwargs = {}
                for key2 in dic["kwargs"]:
                    if key2 == "loader_type":
                        kwargs["loader_type"] = loader_type
                    elif key2 == "float_dtype":
                        kwargs["float_dtype"] = self.float_dtype
                    elif key2 == "int_dtype":
                        kwargs["int_dtype"] = self.int_dtype
                    elif key2 == "max_cache_fd":
                        kwargs["max_cache_fd"] = self.max_cache_fd
                    else:
                        raise RuntimeError(f"Not implemented keyword argument: {key2}")

                func = dic["func"]
                try:
                    return func(path, **kwargs)
                except Exception:
                    if hasattr(func, "__name__"):
                        name = func.__name__
                    else:
                        name = str(func)
                    logging.error(f"An error happened with {name}({path})")
                    raise
        else:
            raise RuntimeError(f"Not supported: loader_type={loader_type}")

    def has_name(self, name) -> bool:
        return name in self.loader_dict

    def names(self) -> Tuple[str, ...]:
        return tuple(self.loader_dict)

    def __iter__(self):
        return iter(next(iter(self.loader_dict.values())))

    def __repr__(self):
        _mes = self.__class__.__name__
        _mes += "("
        for name, (path, _type) in self.debug_info.items():
            _mes += f'\n  {name}: {{"path": "{path}", "type": "{_type}"}}'
        _mes += f"\n  preprocess: {self.preprocess})"
        return _mes

    def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
        assert check_argument_types()

        # Change integer-id to string-id
        if isinstance(uid, int):
            d = next(iter(self.loader_dict.values()))
            uid = list(d)[uid]

        if self.cache is not None and uid in self.cache:
            data = self.cache[uid]
            return uid, data

        data = {}
        # 1. Load data from each loader
        for name, loader in self.loader_dict.items():
            try:
                value = loader[uid]
                if isinstance(value, (list, tuple)):
                    value = np.array(value)
                if not isinstance(
                    value, (np.ndarray, torch.Tensor, str, numbers.Number)
                ):
                    raise TypeError(
                        f"Must be ndarray, torch.Tensor, str or Number: {type(value)}"
                    )
            except Exception:
                path, _type = self.debug_info[name]
                logging.error(
                    f"Error happened with path={path}, type={_type}, id={uid}"
                )
                raise

            # torch.Tensor is converted to ndarray
            if isinstance(value, torch.Tensor):
                value = value.numpy()
            elif isinstance(value, numbers.Number):
                value = np.array([value])
            data[name] = value

        # 2. [Option] Apply preprocessing
        #    e.g. funasr.train.preprocessor:CommonPreprocessor
        if self.preprocess is not None:
            data = self.preprocess(uid, data)

        # 3. Force data-precision
        for name in data:
            value = data[name]
            if not isinstance(value, np.ndarray):
                raise RuntimeError(
                    f"All values must be converted to np.ndarray object "
                    f'by preprocessing, but "{name}" is still {type(value)}.'
                )

            # Cast to the desired type
            if value.dtype.kind == "f":
                value = value.astype(self.float_dtype)
            elif value.dtype.kind == "i":
                value = value.astype(self.int_dtype)
            else:
                raise NotImplementedError(f"Not supported dtype: {value.dtype}")
            data[name] = value

        if self.cache is not None and self.cache.size < self.max_cache_size:
            self.cache[uid] = data

        retval = uid, data
        assert check_return_type(retval)
        return retval
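
# Illustrative sketch (not executed): typical usage of ESPnetDataset.
# "wav.scp" and "text" are hypothetical file paths, and my_preprocess_fn is a
# hypothetical callable that converts the raw text string into an int ndarray
# (required, since the "text" loader returns str and step 3 of __getitem__()
# only accepts ndarrays). A real training setup also needs a collate_fn that
# pads the variable-length arrays in each mini-batch.
#
#   >>> dataset = ESPnetDataset(
#   ...     [("wav.scp", "speech", "sound"), ("text", "text", "text")],
#   ...     preprocess=my_preprocess_fn,
#   ... )
#   >>> uttid, data = dataset["utterance_id_A"]
#   >>> sorted(data)
#   ['speech', 'text']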