
Dev gzf funasr2 (#1123)

* setup jamo
zhifu gao 2 years ago
parent
commit
269554388c
4 changed files with 63 additions and 32 deletions
  1. funasr/datasets/data_sampler.py (+1 -1)
  2. funasr/datasets/dataloader_fn.py (+32 -19)
  3. funasr/datasets/dataset_jsonl.py (+29 -11)
  4. setup.py (+1 -1)

+ 1 - 1
funasr/datasets/data_sampler.py

@@ -4,7 +4,7 @@ import numpy as np
 
 class BatchSampler(torch.utils.data.BatchSampler):
 	
-	def __init__(self, dataset, batch_size_type: str="example", batch_size: int=14, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
+	def __init__(self, dataset, batch_size_type: str="example", batch_size: int=100, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
 		
 		self.drop_last = drop_last
 		self.pre_idx = -1
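The default batch_size jumps from 14 to 100. A minimal sketch of what example-count batching could mean under batch_size_type="example" (the sampler body is not shown in this diff, so the grouping below is illustrative, not the actual implementation):

def example_batches(num_samples: int, batch_size: int = 100, drop_last: bool = False):
    # Group consecutive indices into fixed-size batches of `batch_size` examples.
    indices = list(range(num_samples))
    for beg in range(0, num_samples, batch_size):
        batch = indices[beg:beg + batch_size]
        if drop_last and len(batch) < batch_size:
            continue  # drop the short tail batch when requested
        yield batch

# e.g. 250 samples -> batches of 100, 100 and 50 (the tail kept when drop_last=False)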

+ 32 - 19
funasr/datasets/dataloader_fn.py

@@ -1,4 +1,4 @@
-
+import time
 import torch
 from funasr.datasets.dataset_jsonl import AudioDataset
 from funasr.datasets.data_sampler import BatchSampler
@@ -8,7 +8,7 @@ from funasr.tokenizer.token_id_converter import TokenIDConverter
 collate_fn = None
 # collate_fn = collate_fn,
 
-jsonl = "/Users/zhifu/funasr_github/test_local/all_task_debug_len.jsonl"
+jsonl = "/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl"
 
 frontend = WavFrontend()
 token_type = 'char'
@@ -26,7 +26,7 @@ tokenizer = build_tokenizer(
     non_linguistic_symbols=non_linguistic_symbols,
     g2p_type=g2p_type,
 )
-token_list = ""
+token_list = "/Users/zhifu/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.txt"
 unk_symbol = "<unk>"
 
 token_id_converter = TokenIDConverter(
@@ -34,20 +34,33 @@ token_id_converter = TokenIDConverter(
     unk_symbol=unk_symbol,
 )
 
-dataset = AudioDataset(jsonl, frontend=frontend, tokenizer=tokenizer)
+dataset = AudioDataset(jsonl, frontend=frontend, tokenizer=tokenizer, token_id_converter=token_id_converter)
 batch_sampler = BatchSampler(dataset)
-dataloader_tr = torch.utils.data.DataLoader(dataset,
-                           collate_fn=dataset.collator,
-                           batch_sampler=batch_sampler,
-                           shuffle=False,
-                           num_workers=0,
-                           pin_memory=True)
-
-print(len(dataset))
-for i in range(3):
-    print(i)
-    for data in dataloader_tr:
-        print(len(data), data)
-# data_iter = iter(dataloader_tr)
-# data = next(data_iter)
-pass
+
+
+def collator(samples: list = None):
+    return samples
+
+if __name__ == "__main__":
+    
+    dataloader_tr = torch.utils.data.DataLoader(dataset,
+                                                collate_fn=dataset.collator,
+                                                batch_sampler=batch_sampler,
+                                                shuffle=False,
+                                                num_workers=8,
+                                                pin_memory=True)
+    
+    print(len(dataset))
+    for i in range(3):
+        print(i)
+        beg = time.time()
+        for j, data in enumerate(dataloader_tr):
+            end = time.time()
+            time_cost = end - beg
+            beg = end
+            print(j, time_cost)
+    # data_iter = iter(dataloader_tr)
+    # data = next(data_iter)
+    pass
+
+    
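The script now builds the DataLoader under a __main__ guard with num_workers=8 and times each batch. The printed per-batch times mostly reflect worker prefetching: the first batch of an epoch typically pays the startup and queue-fill cost, while later batches arrive almost instantly if the workers keep up. For orientation, a hedged sketch of what one collated batch should look like (field names come from AudioDataset.__getitem__; the shapes are placeholders that depend on the jsonl contents):

# Illustrative only; run inside the __main__ block after dataloader_tr is built.
batch = next(iter(dataloader_tr))
print(batch["speech"].shape)          # e.g. torch.Size([100, max_frames, feat_dim])
print(batch["speech_lengths"].shape)  # e.g. torch.Size([100, 1])
print(batch["text"].shape)            # e.g. torch.Size([100, max_text_len])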

+ 29 - 11
funasr/datasets/dataset_jsonl.py

@@ -4,8 +4,8 @@ import torch.distributed as dist
 import numpy as np
 import kaldiio
 import librosa
-
-
+import torchaudio
+import time
 
 def load_audio(audio_path: str, fs: int=16000):
 	audio = None
@@ -17,12 +17,19 @@ def load_audio(audio_path: str, fs: int=16000):
 		if ".ark:" in audio_path:
 			audio = kaldiio.load_mat(audio_path)
 		else:
-			audio, fs = librosa.load(audio_path, sr=fs)
+			# audio, fs = librosa.load(audio_path, sr=fs)
+			audio, fs = torchaudio.load(audio_path)
+			audio = audio[0, :]
 	return audio
 
 def extract_features(data, date_type: str="sound", frontend=None):
 	if date_type == "sound":
-		feat, feats_lens = frontend(data, len(data))
+
+		if isinstance(data, np.ndarray):
+			data = torch.from_numpy(data).to(torch.float32)
+		data_len = torch.tensor([data.shape[0]]).to(torch.int32)
+		feat, feats_lens = frontend(data[None, :], data_len)
+
 		feat = feat[0, :, :]
 	else:
 		feat, feats_lens = torch.from_numpy(data).to(torch.float32), torch.tensor([data.shape[0]]).to(torch.int32)
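One behavioural difference worth noting: the replaced librosa.load call resampled to the requested fs, while torchaudio.load returns the waveform at the file's native rate (and the diff rebinds fs to that rate). If inputs are not guaranteed to be 16 kHz already, an explicit resample step would be needed; a hedged sketch using torchaudio's functional API:

import torch
import torchaudio

def load_audio_resampled(audio_path: str, fs: int = 16000) -> torch.Tensor:
    # torchaudio.load returns (channels, samples) at the file's native rate.
    audio, native_fs = torchaudio.load(audio_path)
    audio = audio[0, :]  # first channel, matching the diff
    if native_fs != fs:
        # Convert to the target rate; torchaudio.load itself never resamples.
        audio = torchaudio.functional.resample(audio, orig_freq=native_fs, new_freq=fs)
    return audio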
@@ -74,13 +81,16 @@ class IndexedDatasetJsonl(torch.utils.data.Dataset):
 
 
 class AudioDataset(torch.utils.data.Dataset):
-	def __init__(self, path, frontend=None, tokenizer=None):
+	def __init__(self, path, frontend=None, tokenizer=None, token_id_converter=None):
+
 		super().__init__()
 		self.indexed_dataset = IndexedDatasetJsonl(path)
 		self.frontend = frontend.forward
 		self.fs = 16000 if frontend is None else frontend.fs
 		self.data_type = "sound"
 		self.tokenizer = tokenizer
+		self.token_id_converter = token_id_converter
+
 		self.int_pad_value = -1
 		self.float_pad_value = 0.0
 
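The constructor now threads token_id_converter through to __getitem__, where (next hunk) the single tokenizer.encode call becomes an explicit two-step conversion: text2tokens followed by tokens2ids. A hedged sketch of that pipeline, reusing the tokenizer and token_id_converter built in dataloader_fn.py (the ids are placeholders, since the real mapping depends on tokens.txt):

import torch

target = "你好"
tokens = tokenizer.text2tokens(target)       # e.g. ["你", "好"] for char tokenization
ids = token_id_converter.tokens2ids(tokens)  # e.g. [1234, 5678] (illustrative ids)
text = torch.tensor(ids, dtype=torch.int64)
text_lengths = torch.tensor([len(ids)], dtype=torch.int32)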
@@ -92,13 +102,17 @@ class AudioDataset(torch.utils.data.Dataset):
 	
 	def __getitem__(self, index):
 		item = self.indexed_dataset[index]
+		# return item
+
 		source = item["source"]
 		data_src = load_audio(source, fs=self.fs)
 		speech, speech_lengths = extract_features(data_src, self.data_type, self.frontend)
 		target = item["target"]
-		text = self.tokenizer.encode(target)
-		text_lengths = len(text)
-		text, text_lengths = torch.tensor(text, dtype=torch.int64), torch.tensor([text_lengths], dtype=torch.int32)
+		text = self.tokenizer.text2tokens(target)
+		ids = self.token_id_converter.tokens2ids(text)
+		ids_lengths = len(ids)
+		text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
+
 		return {"speech": speech,
 		        "speech_lengths": speech_lengths,
 		        "text": text,
@@ -108,17 +122,21 @@ class AudioDataset(torch.utils.data.Dataset):
 	
 	def collator(self, samples: list=None):
 		
+		# return samples
+		
 		outputs = {}
 		for sample in samples:
 			for key in sample.keys():
 				if key not in outputs:
 					outputs[key] = []
 				outputs[key].append(sample[key])
-		
+
 		for key, data_list in outputs.items():
-			if data_list[0].dtype.kind == "i":
+			if data_list[0].dtype == torch.int64:
+
 				pad_value = self.int_pad_value
 			else:
 				pad_value = self.float_pad_value
 			outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
-		return samples
+		return outputs
+
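Two fixes land in this hunk: the dtype test switches from numpy's data_list[0].dtype.kind == "i" to torch.int64, matching the tensors __getitem__ now produces, and the collator returns the padded outputs dict instead of the raw samples it previously built and then discarded. A minimal illustration of the padding behaviour the collator relies on:

import torch

a = torch.tensor([1, 2, 3], dtype=torch.int64)
b = torch.tensor([4, 5], dtype=torch.int64)
padded = torch.nn.utils.rnn.pad_sequence([a, b], batch_first=True, padding_value=-1)
# padded == tensor([[ 1,  2,  3],
#                   [ 4,  5, -1]])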

+ 1 - 1
setup.py

@@ -14,7 +14,7 @@ requirements = {
         "humanfriendly",
         "scipy>=1.4.1",
         "librosa",
-        # "jamo",  # For kss
+        "jamo",  # For kss
         "PyYAML>=5.1.2",
         # "soundfile>=0.12.1",
         # "h5py>=3.1.0",