|
|
@@ -13,16 +13,16 @@ def padding(data, float_pad_value=0.0, int_pad_value=-1):
|
|
|
batch = {}
|
|
|
data_names = data[0].keys()
|
|
|
for data_name in data_names:
|
|
|
- if data_name == "key" or data_name == "sampling_rate" or data_name == 'hotword_indxs':
|
|
|
- batch[data_name] = data[0][data_name]
|
|
|
+ if data_name == "key" or data_name == "sampling_rate":
|
|
|
continue
|
|
|
else:
|
|
|
- if data[0][data_name].dtype.kind == "i":
|
|
|
- pad_value = int_pad_value
|
|
|
- tensor_type = torch.int64
|
|
|
- else:
|
|
|
- pad_value = float_pad_value
|
|
|
- tensor_type = torch.float32
|
|
|
+ if data_name != 'hotword_indxs':
|
|
|
+ if data[0][data_name].dtype.kind == "i":
|
|
|
+ pad_value = int_pad_value
|
|
|
+ tensor_type = torch.int64
|
|
|
+ else:
|
|
|
+ pad_value = float_pad_value
|
|
|
+ tensor_type = torch.float32
|
|
|
|
|
|
tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
|
|
|
tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
|