|
|
@@ -32,7 +32,7 @@ def padding(data, float_pad_value=0.0, int_pad_value=-1):
|
|
|
batch[data_name] = tensor_pad
|
|
|
batch[data_name + "_lengths"] = tensor_lengths
|
|
|
|
|
|
- # DHA, EAHC NOT INCLUDED
|
|
|
+ # SAC LABEL INCLUDE
|
|
|
if "hotword_indxs" in batch:
|
|
|
# if hotword indxs in batch
|
|
|
# use it to slice hotwords out
|
|
|
@@ -41,28 +41,25 @@ def padding(data, float_pad_value=0.0, int_pad_value=-1):
|
|
|
text = batch['text']
|
|
|
text_lengths = batch['text_lengths']
|
|
|
hotword_indxs = batch['hotword_indxs']
|
|
|
- num_hw = sum([int(i) for i in batch['hotword_indxs_lengths'] if i != 1]) // 2
|
|
|
- B, t1 = text.shape
|
|
|
+ dha_pad = torch.ones_like(text) * -1
|
|
|
+ _, t1 = text.shape
|
|
|
t1 += 1 # TODO: as parameter which is same as predictor_bias
|
|
|
- ideal_attn = torch.zeros(B, t1, num_hw+1)
|
|
|
nth_hw = 0
|
|
|
for b, (hotword_indx, one_text, length) in enumerate(zip(hotword_indxs, text, text_lengths)):
|
|
|
- ideal_attn[b][:,-1] = 1
|
|
|
+ dha_pad[b][:length] = 8405
|
|
|
if hotword_indx[0] != -1:
|
|
|
start, end = int(hotword_indx[0]), int(hotword_indx[1])
|
|
|
hotword = one_text[start: end+1]
|
|
|
hotword_list.append(hotword)
|
|
|
hotword_lengths.append(end-start+1)
|
|
|
- ideal_attn[b][start:end+1, nth_hw] = 1
|
|
|
- ideal_attn[b][start:end+1, -1] = 0
|
|
|
+ dha_pad[b][start: end+1] = one_text[start: end+1]
|
|
|
nth_hw += 1
|
|
|
if len(hotword_indx) == 4 and hotword_indx[2] != -1:
|
|
|
# the second hotword if exist
|
|
|
start, end = int(hotword_indx[2]), int(hotword_indx[3])
|
|
|
hotword_list.append(one_text[start: end+1])
|
|
|
hotword_lengths.append(end-start+1)
|
|
|
- ideal_attn[b][start:end+1, nth_hw-1] = 1
|
|
|
- ideal_attn[b][start:end+1, -1] = 0
|
|
|
+ dha_pad[b][start: end+1] = one_text[start: end+1]
|
|
|
nth_hw += 1
|
|
|
hotword_list.append(torch.tensor([1]))
|
|
|
hotword_lengths.append(1)
|
|
|
@@ -71,8 +68,7 @@ def padding(data, float_pad_value=0.0, int_pad_value=-1):
|
|
|
padding_value=0)
|
|
|
batch["hotword_pad"] = hotword_pad
|
|
|
batch["hotword_lengths"] = torch.tensor(hotword_lengths, dtype=torch.int32)
|
|
|
- batch['ideal_attn'] = ideal_attn
|
|
|
+ batch['dha_pad'] = dha_pad
|
|
|
del batch['hotword_indxs']
|
|
|
del batch['hotword_indxs_lengths']
|
|
|
-
|
|
|
return keys, batch
|