padding.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
  4. def padding(data, float_pad_value=0.0, int_pad_value=-1):
  5. assert isinstance(data, list)
  6. assert "key" in data[0]
  7. assert "speech" in data[0]
  8. assert "text" in data[0]
  9. keys = [x["key"] for x in data]
  10. batch = {}
  11. data_names = data[0].keys()
  12. for data_name in data_names:
  13. if data_name == "key" or data_name =="sampling_rate":
  14. continue
  15. else:
  16. if data[0][data_name].dtype.kind == "i":
  17. pad_value = int_pad_value
  18. tensor_type = torch.int64
  19. else:
  20. pad_value = float_pad_value
  21. tensor_type = torch.float32
  22. tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
  23. tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
  24. tensor_pad = pad_sequence(tensor_list,
  25. batch_first=True,
  26. padding_value=pad_value)
  27. batch[data_name] = tensor_pad
  28. batch[data_name + "_lengths"] = tensor_lengths
  29. return keys, batch