# eend_ola_dataloader.py
  1. import logging
  2. import kaldiio
  3. import numpy as np
  4. import torch
  5. from torch.utils.data import DataLoader
  6. from torch.utils.data import Dataset
  7. def custom_collate(batch):
  8. keys, speech, speaker_labels, orders = zip(*batch)
  9. speech = [torch.from_numpy(np.copy(sph)).to(torch.float32) for sph in speech]
  10. speaker_labels = [torch.from_numpy(np.copy(spk)).to(torch.float32) for spk in speaker_labels]
  11. orders = [torch.from_numpy(np.copy(o)).to(torch.int64) for o in orders]
  12. batch = dict(speech=speech,
  13. speaker_labels=speaker_labels,
  14. orders=orders)
  15. return keys, batch
  16. class EENDOLADataset(Dataset):
  17. def __init__(
  18. self,
  19. data_file,
  20. ):
  21. self.data_file = data_file
  22. with open(data_file) as f:
  23. lines = f.readlines()
  24. self.samples = [line.strip().split() for line in lines]
  25. logging.info("total samples: {}".format(len(self.samples)))
  26. def __len__(self):
  27. return len(self.samples)
  28. def __getitem__(self, idx):
  29. key, speech_path, speaker_label_path = self.samples[idx]
  30. speech = kaldiio.load_mat(speech_path)
  31. speaker_label = kaldiio.load_mat(speaker_label_path).reshape(speech.shape[0], -1)
  32. order = np.arange(speech.shape[0])
  33. np.random.shuffle(order)
  34. return key, speech, speaker_label, order
  35. class EENDOLADataLoader():
  36. def __init__(self, data_file, batch_size, shuffle=True, num_workers=8):
  37. dataset = EENDOLADataset(data_file)
  38. self.data_loader = DataLoader(dataset,
  39. batch_size=batch_size,
  40. collate_fn=custom_collate,
  41. shuffle=shuffle,
  42. num_workers=num_workers)
  43. def build_iter(self, epoch):
  44. return self.data_loader