Przeglądaj źródła

update data2vec pretrain

jmwang66 3 lata temu
rodzic
commit
ded881802c

+ 1 - 0
funasr/datasets/large_datasets/datapipes/batch.py

@@ -46,6 +46,7 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
         batch = []
         bucket = []
         max_lengths = 0
+        min_lengths = 999999
         batch_lengths = 0
 
         if self.batch_mode == "clipping":

+ 4 - 3
funasr/datasets/large_datasets/dataset.py

@@ -158,9 +158,10 @@ def Dataset(data_list_file,
     filter_fn = partial(filter, **filter_conf)
     dataset = FilterIterDataPipe(dataset, fn=filter_fn)
 
-    vocab = {'vocab': dict, 'seg_dict': seg_dict}
-    tokenize_fn = partial(tokenize, **vocab)
-    dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
+    if "text" in data_names:
+        vocab = {'vocab': dict, 'seg_dict': seg_dict}
+        tokenize_fn = partial(tokenize, **vocab)
+        dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
 
     if shuffle:
         buffer_conf = conf.get('shuffle_conf', {})