| 12345678910111213141516171819202122232425262728293031323334353637383940 |
- import numpy as np
- import torch
- from funasr.datasets.collate_fn import crop_to_max_size
def clipping(data):
    """Collate samples into a batch by clipping every field to the shortest sample.

    For each field other than ``"key"``, the per-sample arrays are converted to
    tensors and stacked after trimming every sample to the minimum length along
    its first axis. The tensor dtype is ``torch.int64`` when the first sample's
    numpy dtype kind is ``"i"`` and ``torch.float32`` otherwise. Samples longer
    than the minimum are shortened with ``crop_to_max_size`` (imported from
    ``funasr.datasets.collate_fn``; how it chooses the window is defined there).

    Args:
        data: Non-empty list of dicts. Each dict must contain ``"key"``; every
            other value is read via ``.dtype`` and ``len()``, so it is expected
            to be a numpy array whose first axis is the length to clip. Trailing
            dimensions must agree across samples.

    Returns:
        ``(keys, batch)``. ``keys`` is the list of ``"key"`` values in input
        order. ``batch`` maps each field name to a tensor of shape
        ``(batch_size, min_length, *trailing_dims)`` and ``<name>_lengths`` to a
        ``torch.long`` tensor of size ``batch_size`` filled with ``min_length``.

    Raises:
        AssertionError: if ``data`` is not a list or its first sample lacks ``"key"``.
    """
    assert isinstance(data, list)
    assert "key" in data[0]

    keys = [x["key"] for x in data]
    batch = {}
    for data_name in data[0].keys():
        if data_name == "key":
            continue

        # Integer arrays (numpy kind "i") stay integral; everything else becomes float32.
        if data[0][data_name].dtype.kind == "i":
            tensor_type = torch.int64
        else:
            tensor_type = torch.float32

        tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
        lengths = [len(d[data_name]) for d in data]
        length_clip = min(lengths)

        # Keep all trailing dims so 1-D (e.g. id sequences) and N-D fields work,
        # not only 2-D (length, feature_dim) arrays.
        tensor_clip = tensor_list[0].new_zeros(len(tensor_list), length_clip, *tensor_list[0].shape[1:])
        for i, (tensor, length) in enumerate(zip(tensor_list, lengths)):
            if length == length_clip:
                tensor_clip[i] = tensor
            else:
                tensor_clip[i] = crop_to_max_size(tensor, length_clip)

        batch[data_name] = tensor_clip
        # Every row was clipped to exactly length_clip, so all lengths are equal.
        batch[data_name + "_lengths"] = torch.full((len(tensor_list),), length_clip, dtype=torch.long)

    return keys, batch
|