| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657 |
- import logging
- import kaldiio
- import numpy as np
- import torch
- from torch.utils.data import DataLoader
- from torch.utils.data import Dataset
def custom_collate(batch):
    """Collate dataset samples into a tuple of keys and a dict of tensor lists.

    Each element of ``batch`` is a 4-tuple ``(key, speech, speaker_label,
    order)``. Samples are neither padded nor stacked: every field is returned
    as a list holding one tensor per sample.

    Returns:
        A pair ``(keys, batch)`` where ``keys`` is a tuple of the sample keys
        and ``batch`` is a dict with the entries ``speech`` (float32 tensors),
        ``speaker_labels`` (float32 tensors) and ``orders`` (int64 tensors).
    """

    def _as_tensors(arrays, dtype):
        # Copy first so the resulting tensor never aliases the input array.
        return [torch.from_numpy(np.copy(arr)).to(dtype) for arr in arrays]

    keys, speech, speaker_labels, orders = zip(*batch)
    collated = {
        "speech": _as_tensors(speech, torch.float32),
        "speaker_labels": _as_tensors(speaker_labels, torch.float32),
        "orders": _as_tensors(orders, torch.int64),
    }
    return keys, collated
class EENDOLADataset(Dataset):
    """Dataset yielding ``(key, speech, speaker_label, order)`` samples.

    ``data_file`` is a text file with one sample per line, each made of three
    whitespace-separated fields: a key, the path of the speech matrix and the
    path of the speaker-label matrix. Both paths are read with
    ``kaldiio.load_mat`` when a sample is requested.
    """

    def __init__(
        self,
        data_file,
    ):
        """Read the sample list from ``data_file``.

        Args:
            data_file: Path of the text file listing the samples.
        """
        self.data_file = data_file
        with open(data_file) as f:
            # Skip blank lines (e.g. a trailing empty line at the end of the
            # list). They would otherwise become empty samples that inflate
            # __len__ and fail to unpack in __getitem__.
            self.samples = [line.split() for line in f if line.strip()]
        logging.info("total samples: {}".format(len(self.samples)))

    def __len__(self):
        """Return the number of samples listed in the data file."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            A tuple ``(key, speech, speaker_label, order)``. ``speech`` is the
            loaded speech matrix, ``speaker_label`` is the loaded label matrix
            reshaped so that its first axis matches ``speech``, and ``order``
            is a random permutation (int64 NumPy array) of the indices along
            the first axis of ``speech``.

        Raises:
            ValueError: If the listed line does not hold exactly three fields.
        """
        key, speech_path, speaker_label_path = self.samples[idx]
        speech = kaldiio.load_mat(speech_path)
        speaker_label = kaldiio.load_mat(speaker_label_path).reshape(speech.shape[0], -1)
        # Draw the permutation from torch's RNG rather than NumPy's global one:
        # PyTorch seeds its own generator differently in every DataLoader
        # worker, whereas NumPy's global state can be duplicated across
        # workers, which makes them produce identical "random" orders.
        order = torch.randperm(speech.shape[0]).numpy()
        return key, speech, speaker_label, order
class EENDOLADataLoader:
    """Wrapper that builds a ``DataLoader`` over an ``EENDOLADataset``."""

    def __init__(self, data_file, batch_size, shuffle=True, num_workers=8):
        """Create the dataset for ``data_file`` and wrap it in a ``DataLoader``.

        Args:
            data_file: Path passed to ``EENDOLADataset``.
            batch_size: Number of samples per batch.
            shuffle: Forwarded to ``DataLoader`` as ``shuffle``.
            num_workers: Forwarded to ``DataLoader`` as ``num_workers``.
        """
        loader_options = dict(
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=custom_collate,
        )
        self.data_loader = DataLoader(EENDOLADataset(data_file), **loader_options)

    def build_iter(self, epoch):
        """Return the wrapped data loader; ``epoch`` is accepted but not used."""
        return self.data_loader
|