Browse Source

update data2vec pretrain: add clipping

jmwang66 3 years ago
parent
commit
9befa9e508

+ 14 - 0
egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml

@@ -63,3 +63,17 @@ optim_conf:
 scheduler: tri_stage
 scheduler_conf:
     phase_ratio: [0.03,0.9,0.07]
+
+# for dataset
+dataset_conf:
+    batch_mode: clipping
+    data_names: speech,none
+    data_types: kaldi_ark,none
+    shuffle: true
+    shuffle_conf:
+        shuffle_size: 12800
+        sort_size: 12800
+    batch_conf:
+        batch_type: token
+        batch_size: 64000
+    num_workers: 8

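For reference, a minimal sketch of how the new dataset_conf block can be read back from this YAML (assuming PyYAML is available; the path is the file edited in this commit):

    # Sketch only: assumes PyYAML; the path is the config file edited above.
    import yaml

    with open("egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml") as f:
        config = yaml.safe_load(f)

    dataset_conf = config["dataset_conf"]
    print(dataset_conf["batch_mode"])   # "clipping" (the new mode; "padding" is the default)
    print(dataset_conf["batch_conf"])   # {"batch_type": "token", "batch_size": 64000}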
+ 3 - 2
funasr/datasets/large_datasets/build_dataloader.py

@@ -35,15 +35,16 @@ def load_seg_dict(seg_dict_file):
 
 class ArkDataLoader(AbsIterFactory):
     def __init__(self, data_list, dict_file, dataset_conf, seg_dict_file=None, mode="train"):
-        symbol_table = read_symbol_table(dict_file)
+        symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
         if seg_dict_file is not None:
             seg_dict = load_seg_dict(seg_dict_file)
         else:
             seg_dict = None
         self.dataset_conf = dataset_conf
         logging.info("dataloader config: {}".format(self.dataset_conf))
+        batch_mode = self.dataset_conf.get("batch_mode", "padding")
         self.dataset = Dataset(data_list, symbol_table, seg_dict,
-                               self.dataset_conf, mode=mode)
+                               self.dataset_conf, mode=mode, batch_mode=batch_mode)
 
     def build_iter(self, epoch, shuffle=True):
         self.dataset.set_epoch(epoch)

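A hypothetical usage sketch of the updated loader (the data list path is a placeholder; dataset_conf is the block from the YAML above). With this change dict_file may be None, which skips building the symbol table, and batch_mode falls back to "padding" when the config does not set it:

    # Hypothetical usage; the data list path is a placeholder, not part of this commit.
    from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader

    loader = ArkDataLoader(
        data_list="data/train/data.list",  # placeholder shard list
        dict_file=None,                    # no symbol table needed for data2vec pretraining
        dataset_conf=dataset_conf,         # dataset_conf block from the YAML diff above
        mode="train",
    )
    data_iter = loader.build_iter(epoch=0)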
+ 118 - 54
funasr/datasets/large_datasets/datapipes/batch.py

@@ -24,7 +24,8 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
             batch_size=8000,
             len_fn=_default_len_fn,
             buffer_size=10240,
-            sort_size=500
+            sort_size=500,
+            batch_mode="padding",
     ):
         assert batch_size > 0, "Batch size is required to be larger than 0!"
         assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
@@ -35,6 +36,7 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
         self.batch_size = batch_size
         self.buffer_size = buffer_size
         self.sort_size = sort_size
+        self.batch_mode = batch_mode
 
     def set_epoch(self, epoch):
         self.epoch = epoch
@@ -46,48 +48,8 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
         max_lengths = 0
         batch_lengths = 0
 
-        if self.buffer_size == -1:
-            for d in self.datapipe:
-                if d[0] > self.batch_size:
-                    continue
-                buffer.append(d)
-            buffer.sort()
-            for sample in buffer:
-                length, _, token = sample
-                if length > max_lengths:
-                    max_lengths = length
-                batch_lengths = max_lengths * (len(batch) + 1)
-                if batch_lengths > self.batch_size:
-                    bucket.append(batch)
-                    batch = []
-                    max_lengths = length
-                batch.append(token)
-            random.shuffle(bucket)
-            if bucket:
-                for batch_sample in bucket:
-                    yield batch_sample
-            if batch:
-                yield batch
-
-        elif self.buffer_size == 0:
-            for d in self.datapipe:
-                if d[0] > self.batch_size:
-                    continue
-                length, _, token = d
-                if length > self.batch_size:
-                    continue
-                if length > max_lengths:
-                    max_lengths = length
-                batch_lengths = max_lengths * (len(batch) + 1)
-                if batch_lengths > self.batch_size:
-                    yield batch
-                    batch = []
-                    max_lengths = length
-                batch.append(token)
-            if batch:
-                yield batch
-
-        else:
+        if self.batch_mode == "clipping":
+            assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 0"
+            min_lengths = float("inf")  # start high so the first sample defines the running minimum
             for d in self.datapipe:
                 if d[0] > self.batch_size:
                     continue
@@ -100,13 +62,13 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
                             bucket.sort()
                             for x in bucket:
                                 length, _, token = x
-                                if length > max_lengths:
-                                    max_lengths = length
-                                batch_lengths = max_lengths * (len(batch) + 1)
+                                if length < min_lengths:
+                                    min_lengths = length
+                                batch_lengths = min_lengths * (len(batch) + 1)
                                 if batch_lengths > self.batch_size:
                                     yield batch
                                     batch = []
-                                    max_lengths = length
+                                    min_lengths = length
                                 batch.append(token)
                             bucket = []
                     buffer = []
@@ -119,13 +81,13 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
                         bucket.sort()
                         for x in bucket:
                             length, _, token = x
-                            if length > max_lengths:
-                                max_lengths = length
-                            batch_lengths = max_lengths * (len(batch) + 1)
+                            if length < min_lengths:
+                                min_lengths = length
+                            batch_lengths = min_lengths * (len(batch) + 1)
                             if batch_lengths > self.batch_size:
                                 yield batch
                                 batch = []
-                                max_lengths = length
+                                min_lengths = length
                             batch.append(token)
                         bucket = []
                 buffer = []
@@ -134,6 +96,50 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
                 bucket.sort()
                 for x in bucket:
                     length, _, token = x
+                    if length < min_lengths:
+                        min_lengths = length
+                    batch_lengths = min_lengths * (len(batch) + 1)
+                    if batch_lengths > self.batch_size:
+                        yield batch
+                        batch = []
+                        min_lengths = length
+                    batch.append(token)
+                bucket = []
+
+            if batch:
+                yield batch
+
+        else:
+            if self.buffer_size == -1:
+                for d in self.datapipe:
+                    if d[0] > self.batch_size:
+                        continue
+                    buffer.append(d)
+                buffer.sort()
+                for sample in buffer:
+                    length, _, token = sample
+                    if length > max_lengths:
+                        max_lengths = length
+                    batch_lengths = max_lengths * (len(batch) + 1)
+                    if batch_lengths > self.batch_size:
+                        bucket.append(batch)
+                        batch = []
+                        max_lengths = length
+                    batch.append(token)
+                random.shuffle(bucket)
+                if bucket:
+                    for batch_sample in bucket:
+                        yield batch_sample
+                if batch:
+                    yield batch
+
+            elif self.buffer_size == 0:
+                for d in self.datapipe:
+                    if d[0] > self.batch_size:
+                        continue
+                    length, _, token = d
+                    if length > self.batch_size:
+                        continue
                     if length > max_lengths:
                         max_lengths = length
                     batch_lengths = max_lengths * (len(batch) + 1)
@@ -142,7 +148,65 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
                         batch = []
                         max_lengths = length
                     batch.append(token)
-                bucket = []
+                if batch:
+                    yield batch
 
-            if batch:
-                yield batch
+            else:
+                for d in self.datapipe:
+                    if d[0] > self.batch_size:
+                        continue
+                    buffer.append(d)
+                    if len(buffer) == self.buffer_size:
+                        random.shuffle(buffer)
+                        for sample in buffer:
+                            bucket.append(sample)
+                            if len(bucket) == self.sort_size:
+                                bucket.sort()
+                                for x in bucket:
+                                    length, _, token = x
+                                    if length > max_lengths:
+                                        max_lengths = length
+                                    batch_lengths = max_lengths * (len(batch) + 1)
+                                    if batch_lengths > self.batch_size:
+                                        yield batch
+                                        batch = []
+                                        max_lengths = length
+                                    batch.append(token)
+                                bucket = []
+                        buffer = []
+
+                if buffer:
+                    random.shuffle(buffer)
+                    for sample in buffer:
+                        bucket.append(sample)
+                        if len(bucket) == self.sort_size:
+                            bucket.sort()
+                            for x in bucket:
+                                length, _, token = x
+                                if length > max_lengths:
+                                    max_lengths = length
+                                batch_lengths = max_lengths * (len(batch) + 1)
+                                if batch_lengths > self.batch_size:
+                                    yield batch
+                                    batch = []
+                                    max_lengths = length
+                                batch.append(token)
+                            bucket = []
+                    buffer = []
+
+                if bucket:
+                    bucket.sort()
+                    for x in bucket:
+                        length, _, token = x
+                        if length > max_lengths:
+                            max_lengths = length
+                        batch_lengths = max_lengths * (len(batch) + 1)
+                        if batch_lengths > self.batch_size:
+                            yield batch
+                            batch = []
+                            max_lengths = length
+                        batch.append(token)
+                    bucket = []
+
+                if batch:
+                    yield batch

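The behavioural difference between the two branches above: in padding mode the token budget of a batch is governed by its longest utterance (everything is padded up to it), while in clipping mode it is governed by its shortest utterance (everything is cropped down to it). A standalone illustration of just the budget rule, not the datapipe itself:

    def fits_budget(lengths, batch_size, batch_mode="padding"):
        # Padding pads every item up to the longest length, clipping crops down to the
        # shortest, so the effective per-item length charged against the budget differs.
        per_item = max(lengths) if batch_mode == "padding" else min(lengths)
        return per_item * len(lengths) <= batch_size

    print(fits_budget([900, 1000, 1100], 3100, "padding"))   # False: 1100 * 3 = 3300 tokens
    print(fits_budget([900, 1000, 1100], 3100, "clipping"))  # True:   900 * 3 = 2700 tokens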
+ 6 - 3
funasr/datasets/large_datasets/dataset.py

@@ -13,6 +13,7 @@ from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
 from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
 from funasr.datasets.large_datasets.utils.filter import filter
 from funasr.datasets.large_datasets.utils.padding import padding
+from funasr.datasets.large_datasets.utils.clipping import clipping
 from funasr.datasets.large_datasets.utils.tokenize import tokenize
 
 
@@ -143,7 +144,8 @@ def Dataset(data_list_file,
             dict,
             seg_dict,
             conf,
-            mode="train"):
+            mode="train",
+            batch_mode="padding"):
     scp_lists = read_lists(data_list_file)
     shuffle = conf.get('shuffle', True)
     data_names = conf.get("data_names", "speech,text")
@@ -180,8 +182,9 @@ def Dataset(data_list_file,
                                              batch_size=batch_size,
                                              len_fn=len_fn,
                                              buffer_size=buffer_size,
-                                             sort_size=sort_size)
+                                             sort_size=sort_size,
+                                             batch_mode=batch_mode)
 
-    dataset = MapperIterDataPipe(dataset, fn=padding)
+    dataset = MapperIterDataPipe(dataset, fn=padding if batch_mode == "padding" else clipping)
 
     return dataset

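For completeness, this is roughly what ArkDataLoader now does internally; a hypothetical direct call to Dataset with the new keyword (paths are placeholders):

    from funasr.datasets.large_datasets.dataset import Dataset

    dataset = Dataset(
        "data/train/data.list",  # placeholder data list
        None,                    # symbol table; may be None for pretraining
        None,                    # seg_dict; may be None
        dataset_conf,            # the dataset_conf block from the YAML diff
        mode="train",
        batch_mode="clipping",   # selects the clipping collate fn instead of padding
    )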
+ 40 - 0
funasr/datasets/large_datasets/utils/clipping.py

@@ -0,0 +1,40 @@
+import numpy as np
+import torch
+
+from funasr.datasets.collate_fn import crop_to_max_size
+
+
+def clipping(data):
+    assert isinstance(data, list)
+    assert "key" in data[0]
+
+    keys = [x["key"] for x in data]
+
+    batch = {}
+    data_names = data[0].keys()
+    for data_name in data_names:
+        if data_name == "key":
+            continue
+        else:
+            if data[0][data_name].dtype.kind == "i":
+                tensor_type = torch.int64
+            else:
+                tensor_type = torch.float32
+
+            tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
+            tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
+
+            length_clip = min(tensor_lengths)
+            tensor_clip = tensor_list[0].new_zeros(len(tensor_list), length_clip, tensor_list[0].shape[1])
+            for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)):
+                diff = length - length_clip
+                assert diff >= 0
+                if diff == 0:
+                    tensor_clip[i] = tensor
+                else:
+                    tensor_clip[i] = crop_to_max_size(tensor, length_clip)
+
+            batch[data_name] = tensor_clip
+            batch[data_name + "_lengths"] = torch.tensor([tensor.shape[0] for tensor in tensor_clip], dtype=torch.long)
+
+    return keys, batch
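A minimal usage sketch of the new collate function with dummy 2-D feature arrays (shapes are made up; crop_to_max_size is assumed to cut each feature matrix down to the target length along the time axis):

    # Dummy example; feature shapes are illustrative only.
    import numpy as np
    from funasr.datasets.large_datasets.utils.clipping import clipping

    samples = [
        {"key": "utt1", "speech": np.random.randn(120, 80).astype(np.float32)},
        {"key": "utt2", "speech": np.random.randn(100, 80).astype(np.float32)},
    ]
    keys, batch = clipping(samples)
    print(keys)                     # ["utt1", "utt2"]
    print(batch["speech"].shape)    # torch.Size([2, 100, 80]) -- cropped to the shortest utterance
    print(batch["speech_lengths"])  # tensor([100, 100])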