clipping.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. import numpy as np
  2. import torch
  3. from funasr.datasets.collate_fn import crop_to_max_size
  4. def clipping(data):
  5. assert isinstance(data, list)
  6. assert "key" in data[0]
  7. keys = [x["key"] for x in data]
  8. batch = {}
  9. data_names = data[0].keys()
  10. for data_name in data_names:
  11. if data_name == "key":
  12. continue
  13. else:
  14. if data[0][data_name].dtype.kind == "i":
  15. tensor_type = torch.int64
  16. else:
  17. tensor_type = torch.float32
  18. tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
  19. tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
  20. length_clip = min(tensor_lengths)
  21. tensor_clip = tensor_list[0].new_zeros(len(tensor_list), length_clip, tensor_list[0].shape[1])
  22. for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)):
  23. diff = length - length_clip
  24. assert diff >= 0
  25. if diff == 0:
  26. tensor_clip[i] = tensor
  27. else:
  28. tensor_clip[i] = crop_to_max_size(tensor, length_clip)
  29. batch[data_name] = tensor_clip
  30. batch[data_name + "_lengths"] = torch.tensor([tensor.shape[0] for tensor in tensor_clip], dtype=torch.long)
  31. return keys, batch