# padding.py
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
  4. def padding(data, float_pad_value=0.0, int_pad_value=-1):
  5. assert isinstance(data, list)
  6. assert "key" in data[0]
  7. assert "speech" in data[0] or "text" in data[0]
  8. keys = [x["key"] for x in data]
  9. batch = {}
  10. data_names = data[0].keys()
  11. for data_name in data_names:
  12. if data_name == "key" or data_name == "sampling_rate":
  13. continue
  14. else:
  15. if data_name != 'hotword_indxs':
  16. if data[0][data_name].dtype.kind == "i":
  17. pad_value = int_pad_value
  18. tensor_type = torch.int64
  19. else:
  20. pad_value = float_pad_value
  21. tensor_type = torch.float32
  22. tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
  23. tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
  24. tensor_pad = pad_sequence(tensor_list,
  25. batch_first=True,
  26. padding_value=pad_value)
  27. batch[data_name] = tensor_pad
  28. batch[data_name + "_lengths"] = tensor_lengths
  29. # DHA, EAHC NOT INCLUDED
  30. if "hotword_indxs" in batch:
  31. # if hotword indxs in batch
  32. # use it to slice hotwords out
  33. hotword_list = []
  34. hotword_lengths = []
  35. text = batch['text']
  36. text_lengths = batch['text_lengths']
  37. hotword_indxs = batch['hotword_indxs']
  38. num_hw = sum([int(i) for i in batch['hotword_indxs_lengths'] if i != 1]) // 2
  39. B, t1 = text.shape
  40. t1 += 1 # TODO: as parameter which is same as predictor_bias
  41. ideal_attn = torch.zeros(B, t1, num_hw+1)
  42. nth_hw = 0
  43. for b, (hotword_indx, one_text, length) in enumerate(zip(hotword_indxs, text, text_lengths)):
  44. ideal_attn[b][:,-1] = 1
  45. if hotword_indx[0] != -1:
  46. start, end = int(hotword_indx[0]), int(hotword_indx[1])
  47. hotword = one_text[start: end+1]
  48. hotword_list.append(hotword)
  49. hotword_lengths.append(end-start+1)
  50. ideal_attn[b][start:end+1, nth_hw] = 1
  51. ideal_attn[b][start:end+1, -1] = 0
  52. nth_hw += 1
  53. if len(hotword_indx) == 4 and hotword_indx[2] != -1:
  54. # the second hotword if exist
  55. start, end = int(hotword_indx[2]), int(hotword_indx[3])
  56. hotword_list.append(one_text[start: end+1])
  57. hotword_lengths.append(end-start+1)
  58. ideal_attn[b][start:end+1, nth_hw-1] = 1
  59. ideal_attn[b][start:end+1, -1] = 0
  60. nth_hw += 1
  61. hotword_list.append(torch.tensor([1]))
  62. hotword_lengths.append(1)
  63. hotword_pad = pad_sequence(hotword_list,
  64. batch_first=True,
  65. padding_value=0)
  66. batch["hotword_pad"] = hotword_pad
  67. batch["hotword_lengths"] = torch.tensor(hotword_lengths, dtype=torch.int32)
  68. batch['ideal_attn'] = ideal_attn
  69. del batch['hotword_indxs']
  70. del batch['hotword_indxs_lengths']
  71. return keys, batch