# padding.py
  1. import numpy as np
  2. import torch
  3. from torch.nn.utils.rnn import pad_sequence
  4. def padding(data, float_pad_value=0.0, int_pad_value=-1):
  5. assert isinstance(data, list)
  6. assert "key" in data[0]
  7. assert "speech" in data[0] or "text" in data[0]
  8. keys = [x["key"] for x in data]
  9. batch = {}
  10. data_names = data[0].keys()
  11. for data_name in data_names:
  12. if data_name == "key" or data_name == "sampling_rate":
  13. continue
  14. else:
  15. if data_name != 'hotword_indxs':
  16. if data[0][data_name].dtype.kind == "i":
  17. pad_value = int_pad_value
  18. tensor_type = torch.int64
  19. else:
  20. pad_value = float_pad_value
  21. tensor_type = torch.float32
  22. tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
  23. tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
  24. tensor_pad = pad_sequence(tensor_list,
  25. batch_first=True,
  26. padding_value=pad_value)
  27. batch[data_name] = tensor_pad
  28. batch[data_name + "_lengths"] = tensor_lengths
  29. # SAC LABEL INCLUDE
  30. if "hotword_indxs" in batch:
  31. # if hotword indxs in batch
  32. # use it to slice hotwords out
  33. hotword_list = []
  34. hotword_lengths = []
  35. text = batch['text']
  36. text_lengths = batch['text_lengths']
  37. hotword_indxs = batch['hotword_indxs']
  38. dha_pad = torch.ones_like(text) * -1
  39. _, t1 = text.shape
  40. t1 += 1 # TODO: as parameter which is same as predictor_bias
  41. nth_hw = 0
  42. for b, (hotword_indx, one_text, length) in enumerate(zip(hotword_indxs, text, text_lengths)):
  43. dha_pad[b][:length] = 8405
  44. if hotword_indx[0] != -1:
  45. start, end = int(hotword_indx[0]), int(hotword_indx[1])
  46. hotword = one_text[start: end+1]
  47. hotword_list.append(hotword)
  48. hotword_lengths.append(end-start+1)
  49. dha_pad[b][start: end+1] = one_text[start: end+1]
  50. nth_hw += 1
  51. if len(hotword_indx) == 4 and hotword_indx[2] != -1:
  52. # the second hotword if exist
  53. start, end = int(hotword_indx[2]), int(hotword_indx[3])
  54. hotword_list.append(one_text[start: end+1])
  55. hotword_lengths.append(end-start+1)
  56. dha_pad[b][start: end+1] = one_text[start: end+1]
  57. nth_hw += 1
  58. hotword_list.append(torch.tensor([1]))
  59. hotword_lengths.append(1)
  60. hotword_pad = pad_sequence(hotword_list,
  61. batch_first=True,
  62. padding_value=0)
  63. batch["hotword_pad"] = hotword_pad
  64. batch["hotword_lengths"] = torch.tensor(hotword_lengths, dtype=torch.int32)
  65. batch['dha_pad'] = dha_pad
  66. del batch['hotword_indxs']
  67. del batch['hotword_indxs_lengths']
  68. return keys, batch