Ver Fonte

Funasr1.0 (#1362)

* funasr1.0.5

* funasr1.0.5 audio samples input

* batch_type token

* batch_type token

* huggingface model zoo

* dataloader

* dataloader

* fbank input

* vad is_final=True bugfix
zhifu gao há 1 ano atrás
pai
commit
d92cd5ae03

+ 2 - 1
funasr/auto/auto_model.py

@@ -171,7 +171,7 @@ class AutoModel:
         # build model
         model_class = tables.model_classes.get(kwargs["model"])
         model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
-        model.eval()
+        
         model.to(device)
         
         # init_param
@@ -206,6 +206,7 @@ class AutoModel:
         kwargs = self.kwargs if kwargs is None else kwargs
         kwargs.update(cfg)
         model = self.model if model is None else model
+        model.eval()
 
         batch_size = kwargs.get("batch_size", 1)
         # if kwargs.get("device", "cpu") == "cpu":

+ 52 - 2
funasr/datasets/audio_datasets/index_ds.py

@@ -6,8 +6,8 @@ import torch.distributed as dist
 from funasr.register import tables
 
 
-@tables.register("index_ds_classes", "IndexDSJsonl")
-class IndexDSJsonl(torch.utils.data.Dataset):
+@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
+class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
     
     def __init__(self, path):
         super().__init__()
@@ -66,3 +66,53 @@ class IndexDSJsonl(torch.utils.data.Dataset):
     def get_target_len(self, data_dict):
         
         return data_dict["target_len"] if "target_len" in data_dict else 0
+
+@tables.register("index_ds_classes", "IndexDSJsonl")
+@tables.register("index_ds_classes", "IndexDSJsonlRankFull")
+class IndexDSJsonlRankFull(torch.utils.data.Dataset):
+    
+    def __init__(self, path):
+        super().__init__()
+        
+        contents = []
+        with open(path, encoding='utf-8') as fin:
+            for line in fin:
+                data = json.loads(line.strip())
+                if "text" in data:  # for sft
+                    self.contents.append(data['text'])
+                if "source" in data:  # for speech lab pretrain
+                    prompt = data.get("prompt", "<ASR>")
+                    source = data["source"]
+                    target = data["target"]
+                    source_len = data.get("source_len", 1)
+                    target_len = data.get("target_len", 0)
+                    
+                    contents.append({"source": source,
+                                     "prompt": prompt,
+                                     "target": target,
+                                     "source_len": source_len,
+                                     "target_len": target_len,
+                                     }
+                                    )
+
+        self.contents = contents
+        
+        logging.info(
+            "total_num of samplers across ranks: {}".format(len(self.contents)))
+    
+    def __len__(self):
+        return len(self.contents)
+    
+    def __getitem__(self, index):
+        try:
+            data = self.contents[index]
+        except:
+            print(index)
+        return data
+    
+    def get_source_len(self, data_dict):
+        return data_dict.get("source_len", 1)
+    
+    def get_target_len(self, data_dict):
+        
+        return data_dict.get("target_len", 0)

+ 193 - 0
funasr/datasets/audio_datasets/samplers.py

@@ -1,5 +1,7 @@
 import torch
 import numpy as np
+import logging
+import torch.distributed as dist
 
 from funasr.register import tables
 
@@ -82,3 +84,194 @@ class BatchSampler(torch.utils.data.BatchSampler):
                     max_token = sample_len_cur_raw
                     num_sample = 1
 
+
+@tables.register("batch_sampler_classes", "BatchSampler")
+@tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler")
+class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler):
+    
+    def __init__(self, dataset,
+                 batch_type: str = "example",
+                 batch_size: int = 100,
+                 buffer_size: int = 30,
+                 drop_last: bool = True,
+                 shuffle: bool = True,
+                 is_training: bool = True,
+                 **kwargs):
+        
+        self.drop_last = drop_last
+        self.pre_idx = -1
+        self.dataset = dataset
+        self.total_samples = len(dataset)
+        self.batch_type = batch_type
+        self.batch_size = int(batch_size)
+        self.buffer_size = buffer_size
+        self.max_token_length = kwargs.get("max_token_length", 1500)
+        self.shuffle_idx = np.arange(self.total_samples)
+        self.shuffle = shuffle and is_training
+        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+        
+        try:
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+        except:
+            rank = 0
+            world_size = 1
+        self.rank = rank
+        self.world_size = world_size
+        
+    def __len__(self):
+        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
+    
+    def set_epoch(self, epoch):
+        np.random.seed(epoch)
+    
+    def __iter__(self):
+    
+        batch_size_total = self.batch_size * self.world_size
+        
+        if self.shuffle:
+            np.random.shuffle(self.shuffle_idx)
+        
+        batch = []
+        max_token = 0
+        num_sample = 0
+        
+        iter_num = (self.total_samples - 1) // self.buffer_size + 1
+        # print("iter_num: ", iter_num)
+        for iter in range(self.pre_idx + 1, iter_num):
+            # if iter == iter_num -1 and self.drop_last:
+            #     continue
+            datalen_with_index = []
+            for i in range(self.buffer_size):
+                idx = iter * self.buffer_size + i
+                if idx >= self.total_samples:
+                    continue
+                
+                idx_map = self.shuffle_idx[idx]
+                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
+                
+                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
+                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
+                sample_len_cur = source_len + target_len
+                
+                datalen_with_index.append([idx, sample_len_cur])
+            
+            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
+            for item in datalen_with_index_sort:
+                idx, sample_len_cur_raw = item
+                if sample_len_cur_raw > self.max_token_length:
+                    continue
+
+                max_token_cur = max(max_token, sample_len_cur_raw)
+                max_token_padding = 1 + num_sample
+                # if self.batch_type != 'example':
+                #     max_token_padding *= max_token_cur
+                if max_token_padding <= batch_size_total:
+                    batch.append(idx)
+                    max_token = max_token_cur
+                    num_sample += 1
+                else:
+                    batch_rank = batch[self.rank*self.batch_size: (self.rank+1)*self.batch_size]
+                    yield batch_rank
+                    batch = [idx]
+                    max_token = sample_len_cur_raw
+                    num_sample = 1
+
+
+@tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler")
+class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler):
+    
+    def __init__(self, dataset,
+                 batch_type: str = "example",
+                 batch_size: int = 100,
+                 buffer_size: int = 30,
+                 drop_last: bool = True,
+                 shuffle: bool = True,
+                 is_training: bool = True,
+                 **kwargs):
+        
+        self.drop_last = drop_last
+        self.pre_idx = -1
+        self.dataset = dataset
+        self.total_samples = len(dataset)
+        self.batch_type = batch_type
+        self.batch_size = int(batch_size)
+        self.buffer_size = buffer_size
+        self.max_token_length = kwargs.get("max_token_length", 1500)
+        self.shuffle_idx = np.arange(self.total_samples)
+        self.shuffle = shuffle and is_training
+        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+        
+        try:
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+        except:
+            rank = 0
+            world_size = 1
+        self.rank = rank
+        self.world_size = world_size
+    
+    def __len__(self):
+        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
+    
+    def set_epoch(self, epoch):
+        np.random.seed(epoch)
+    
+    def __iter__(self):
+        
+        batch_size_total = self.batch_size * self.world_size
+        if self.shuffle:
+            np.random.shuffle(self.shuffle_idx)
+        
+        batch_list_all_rank = []
+        batch_list_cur = []
+        max_token = 0
+        num_sample = 0
+        
+        iter_num = (self.total_samples - 1) // self.buffer_size + 1
+        # print("iter_num: ", iter_num)
+        for iter in range(self.pre_idx + 1, iter_num):
+            # if iter == iter_num - 1 and self.drop_last:
+            #     continue
+            datalen_with_index = []
+            for i in range(self.buffer_size):
+                idx = iter * self.buffer_size + i
+                if idx >= self.total_samples:
+                    continue
+                
+                idx_map = self.shuffle_idx[idx]
+                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
+                
+                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
+                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
+                sample_len_cur = source_len + target_len
+                
+                datalen_with_index.append([idx, sample_len_cur])
+            
+            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
+            for ii, item in enumerate(datalen_with_index_sort):
+                is_last_batch = iter == iter_num - 1 and ii == len(datalen_with_index_sort)
+                idx, sample_len_cur_raw = item
+                if sample_len_cur_raw > self.max_token_length:
+                    continue
+                
+                max_token_cur = max(max_token, sample_len_cur_raw)
+                max_token_padding = 1 + num_sample
+                
+                if self.batch_type != 'example':
+                    max_token_padding *= max_token_cur
+                if len(batch_list_all_rank) < self.world_size:
+                    
+                    if max_token_padding <= self.batch_size:
+                        batch_list_cur.append(idx)
+                        max_token = max_token_cur
+                        num_sample += 1
+                    else:
+                        batch_list_all_rank.append(batch_list_cur)
+                        batch_list_cur = []
+                else:
+                    batch_rank = batch_list_all_rank[self.rank]
+                    yield batch_rank
+                    batch_list_all_rank = [idx]
+                    max_token = sample_len_cur_raw
+                    num_sample = 1

+ 2 - 1
funasr/models/fsmn_vad_streaming/model.py

@@ -575,7 +575,8 @@ class FsmnVADStreaming(nn.Module):
 		
 		time1 = time.perf_counter()
 		is_streaming_input = kwargs.get("is_streaming_input", False) if chunk_size >= 15000 else kwargs.get("is_streaming_input", True)
-		cfg = {"is_final": kwargs.get("is_final", False), "is_streaming_input": is_streaming_input}
+		is_final = kwargs.get("is_final", False) if is_streaming_input else kwargs.get("is_final", True)
+		cfg = {"is_final": is_final, "is_streaming_input": is_streaming_input}
 		audio_sample_list = load_audio_text_image_video(data_in,
 		                                                fs=frontend.fs,
 		                                                audio_fs=kwargs.get("fs", 16000),

+ 1 - 1
funasr/models/paraformer/cif_predictor.py

@@ -186,7 +186,7 @@ class CifPredictorV2(torch.nn.Module):
         alphas = alphas.squeeze(-1)
         mask = mask.squeeze(-1)
         if target_label_length is not None:
-            target_length = target_label_length
+            target_length = target_label_length.squeeze(-1)
         elif target_label is not None:
             target_length = (target_label != ignore_id).float().sum(-1)
         else:

+ 2 - 0
funasr/models/paraformer/model.py

@@ -491,6 +491,8 @@ class Paraformer(torch.nn.Module):
         b, n, d = decoder_out.size()
         if isinstance(key[0], (list, tuple)):
             key = key[0]
+        if len(key) < b:
+            key = key*b
         for i in range(b):
             x = encoder_out[i, :encoder_out_lens[i], :]
             am_scores = decoder_out[i, :pre_token_length[i], :]

+ 18 - 0
funasr/train_utils/trainer.py

@@ -204,7 +204,25 @@ class Trainer:
             my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
             with my_context():
                 time2 = time.perf_counter()
+                print("before, GPU, memory: {:.1} MB, "
+                      "{:.1} MB, "
+                      "{:.1} MB, "
+                      "{:.1} MB".format(torch.cuda.memory_allocated()/1024/1024/1024,
+                                     torch.cuda.max_memory_allocated()/1024/1024/1024,
+                                     torch.cuda.memory_reserved()/1024/1024/1024,
+                                     torch.cuda.max_memory_reserved()/1024/1024/1024,
+                                     ))
+
                 retval = self.model(**batch)
+                torch.cuda.empty_cache()
+                print("after, GPU, memory: {:.1} MB, "
+                      "{:.1} MB, "
+                      "{:.1} MB, "
+                      "{:.1} MB".format(torch.cuda.memory_allocated()/1024/1024/1024,
+                                     torch.cuda.max_memory_allocated()/1024/1024/1024,
+                                     torch.cuda.memory_reserved()/1024/1024/1024,
+                                     torch.cuda.max_memory_reserved()/1024/1024/1024,
+                                     ))
                 time3 = time.perf_counter()
                 speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                 loss, stats, weight = retval