|
@@ -4,6 +4,7 @@ import soundfile
|
|
|
from functools import partial
|
|
from functools import partial
|
|
|
|
|
|
|
|
import torch
|
|
import torch
|
|
|
|
|
+import torchaudio
|
|
|
import torch.distributed as dist
|
|
import torch.distributed as dist
|
|
|
from kaldiio import ReadHelper
|
|
from kaldiio import ReadHelper
|
|
|
from torch.utils.data import IterableDataset
|
|
from torch.utils.data import IterableDataset
|
|
@@ -117,7 +118,8 @@ class AudioDataset(IterableDataset):
|
|
|
sample_dict["key"] = key
|
|
sample_dict["key"] = key
|
|
|
elif data_type == "sound":
|
|
elif data_type == "sound":
|
|
|
key, path = item.strip().split()
|
|
key, path = item.strip().split()
|
|
|
- mat, sampling_rate = soundfile.read(path)
|
|
|
|
|
|
|
+ waveform, sampling_rate = torchaudio.load(path)
|
|
|
|
|
+ mat = waveform[0]
|
|
|
sample_dict[data_name] = mat
|
|
sample_dict[data_name] = mat
|
|
|
sample_dict["sampling_rate"] = sampling_rate
|
|
sample_dict["sampling_rate"] = sampling_rate
|
|
|
if data_name == "speech":
|
|
if data_name == "speech":
|