|
|
@@ -66,7 +66,7 @@ def load_pcm(input):
|
|
|
return load_bytes(bytes)
|
|
|
|
|
|
DATA_TYPES = {
|
|
|
- "sound": lambda x: torchaudio.load(x)[0][0].numpy(),
|
|
|
+ "sound": lambda x: torchaudio.load(x)[0].numpy(),
|
|
|
"pcm": load_pcm,
|
|
|
"kaldi_ark": load_kaldi,
|
|
|
"bytes": load_bytes,
|
|
|
@@ -106,6 +106,7 @@ class IterableESPnetDataset(IterableDataset):
|
|
|
] = None,
|
|
|
float_dtype: str = "float32",
|
|
|
fs: dict = None,
|
|
|
+ mc: bool = False,
|
|
|
int_dtype: str = "long",
|
|
|
key_file: str = None,
|
|
|
):
|
|
|
@@ -122,6 +123,7 @@ class IterableESPnetDataset(IterableDataset):
|
|
|
self.int_dtype = int_dtype
|
|
|
self.key_file = key_file
|
|
|
self.fs = fs
|
|
|
+ self.mc = mc
|
|
|
|
|
|
self.debug_info = {}
|
|
|
non_iterable_list = []
|
|
|
@@ -192,6 +194,7 @@ class IterableESPnetDataset(IterableDataset):
|
|
|
array = torchaudio.transforms.Resample(orig_freq=audio_fs,
|
|
|
new_freq=model_fs)(array)
|
|
|
array = array.squeeze(0).numpy()
|
|
|
+
|
|
|
data[name] = array
|
|
|
|
|
|
if self.preprocess is not None:
|
|
|
@@ -238,11 +241,12 @@ class IterableESPnetDataset(IterableDataset):
|
|
|
model_fs = self.fs["model_fs"]
|
|
|
if audio_fs is not None and model_fs is not None:
|
|
|
array = torch.from_numpy(array)
|
|
|
- array = array.unsqueeze(0)
|
|
|
array = torchaudio.transforms.Resample(orig_freq=audio_fs,
|
|
|
new_freq=model_fs)(array)
|
|
|
- array = array.squeeze(0).numpy()
|
|
|
- data[name] = array
|
|
|
+ if self.mc:
|
|
|
+ data[name] = array.transpose(0, 1).numpy()
|
|
|
+ else:
|
|
|
+ data[name] = array[0].numpy()
|
|
|
|
|
|
if self.preprocess is not None:
|
|
|
data = self.preprocess(uid, data)
|
|
|
@@ -340,11 +344,15 @@ class IterableESPnetDataset(IterableDataset):
|
|
|
model_fs = self.fs["model_fs"]
|
|
|
if audio_fs is not None and model_fs is not None:
|
|
|
array = torch.from_numpy(array)
|
|
|
- array = array.unsqueeze(0)
|
|
|
array = torchaudio.transforms.Resample(orig_freq=audio_fs,
|
|
|
new_freq=model_fs)(array)
|
|
|
- array = array.squeeze(0).numpy()
|
|
|
- data[name] = array
|
|
|
+ if _type == "sound":
|
|
|
+ if self.mc:
|
|
|
+ data[name] = array.transpose(0, 1).numpy()
|
|
|
+ else:
|
|
|
+ data[name] = array[0].numpy()
|
|
|
+ else:
|
|
|
+ data[name] = array
|
|
|
if self.non_iterable_dataset is not None:
|
|
|
# 2.b. Load data from non-iterable dataset
|
|
|
_, from_non_iterable = self.non_iterable_dataset[uid]
|