shixian.shi пре 2 година
родитељ
комит
e5528b586d
1 измењених фајлова са 7 додато и 11 уклоњено
  1. 7 11
      funasr/datasets/large_datasets/utils/padding.py

+ 7 - 11
funasr/datasets/large_datasets/utils/padding.py

@@ -32,7 +32,7 @@ def padding(data, float_pad_value=0.0, int_pad_value=-1):
             batch[data_name] = tensor_pad
             batch[data_name + "_lengths"] = tensor_lengths
 
 
-    # DHA, EAHC NOT INCLUDED
+    # SAC LABEL INCLUDE
     if "hotword_indxs" in batch:
         # if hotword indxs in batch
         # use it to slice hotwords out
@@ -41,28 +41,25 @@ def padding(data, float_pad_value=0.0, int_pad_value=-1):
         text = batch['text']
         text_lengths = batch['text_lengths']
         hotword_indxs = batch['hotword_indxs']
-        num_hw = sum([int(i) for i in batch['hotword_indxs_lengths'] if i != 1]) // 2
-        B, t1 = text.shape
+        dha_pad = torch.ones_like(text) * -1
+        _, t1 = text.shape
         t1 += 1  # TODO: as parameter which is same as predictor_bias
-        ideal_attn = torch.zeros(B, t1, num_hw+1)
         nth_hw = 0
         for b, (hotword_indx, one_text, length) in enumerate(zip(hotword_indxs, text, text_lengths)):
-            ideal_attn[b][:,-1] = 1
+            dha_pad[b][:length] = 8405
             if hotword_indx[0] != -1:
                 start, end = int(hotword_indx[0]), int(hotword_indx[1])
                 hotword = one_text[start: end+1]
                 hotword_list.append(hotword)
                 hotword_lengths.append(end-start+1)
-                ideal_attn[b][start:end+1, nth_hw] = 1
-                ideal_attn[b][start:end+1, -1] = 0
+                dha_pad[b][start: end+1] = one_text[start: end+1]
                 nth_hw += 1
                 if len(hotword_indx) == 4 and hotword_indx[2] != -1:
                     # the second hotword if exist
                     start, end = int(hotword_indx[2]), int(hotword_indx[3])
                     hotword_list.append(one_text[start: end+1])
                     hotword_lengths.append(end-start+1)
-                    ideal_attn[b][start:end+1, nth_hw-1] = 1
-                    ideal_attn[b][start:end+1, -1] = 0
+                    dha_pad[b][start: end+1] = one_text[start: end+1]
                     nth_hw += 1
         hotword_list.append(torch.tensor([1]))
         hotword_lengths.append(1)
@@ -71,8 +68,7 @@ def padding(data, float_pad_value=0.0, int_pad_value=-1):
                                 padding_value=0)
         batch["hotword_pad"] = hotword_pad
         batch["hotword_lengths"] = torch.tensor(hotword_lengths, dtype=torch.int32)
-        batch['ideal_attn'] = ideal_attn
+        batch['dha_pad'] = dha_pad
         del batch['hotword_indxs']
         del batch['hotword_indxs_lengths']
-
     return keys, batch