
aishell example

游雁 1 year ago
parent commit ff4306346e

+ 4 - 3
examples/aishell/paraformer/run.sh

@@ -50,6 +50,7 @@ inference_scp="wav.scp"
 
 if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
     echo "stage -1: Data Download"
+    mkdir -p ${raw_data}
     local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
     local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
 fi
@@ -76,9 +77,8 @@ fi
 
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     echo "stage 1: Feature and CMVN Generation"
-#    utils/compute_cmvn.sh --fbankdir ${feats_dir}/data/${train_set} --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --config_file "$config" --scale 1.0
     python ../../../funasr/bin/compute_audio_cmvn.py \
-    --config-path "${workspace}" \
+    --config-path "${workspace}/conf" \
     --config-name "${config}" \
     ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
     ++cmvn_file="${feats_dir}/data/${train_set}/cmvn.json" \
@@ -109,13 +109,14 @@ fi
 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
   echo "stage 4: ASR Training"
 
+  mkdir -p ${exp_dir}/exp/${model_dir}
   log_file="${exp_dir}/exp/${model_dir}/train.log.txt"
   echo "log_file: ${log_file}"
   torchrun \
   --nnodes 1 \
   --nproc_per_node ${gpu_num} \
   ../../../funasr/bin/train.py \
-  --config-path "${workspace}" \
+  --config-path "${workspace}/conf" \
   --config-name "${config}" \
   ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
   ++tokenizer_conf.token_list="${token_list}" \
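Both --config-path fixes point Hydra at the directory that actually holds the YAML: --config-name is resolved relative to --config-path, and the configs live under conf/. The added mkdir -p calls guard against the download and training stages failing when the target directory does not exist yet. The ++key=value arguments are Hydra override syntax (++ adds the key if absent and overrides it if present); a minimal sketch of the merge semantics, assuming standard OmegaConf behavior and hypothetical config values:

    from omegaconf import OmegaConf

    # stand-in for the YAML selected via --config-path/--config-name
    base = OmegaConf.create({"dataset_conf": {"num_workers": 4}})
    # stand-in for ++dataset_conf.num_workers=0 ++cmvn_file=... overrides
    overrides = OmegaConf.from_dotlist([
        "dataset_conf.num_workers=0",
        "cmvn_file=data/cmvn.json",
    ])
    print(OmegaConf.to_yaml(OmegaConf.merge(base, overrides)))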

+ 12 - 11
funasr/bin/compute_audio_cmvn.py

@@ -79,8 +79,8 @@ def main(**kwargs):
 
         fbank = batch["speech"].numpy()[0, :, :]
         if total_frames == 0:
-            mean_stats = fbank
-            var_stats = np.square(fbank)
+            mean_stats = np.sum(fbank, axis=0)
+            var_stats = np.sum(np.square(fbank), axis=0)
         else:
             mean_stats += np.sum(fbank, axis=0)
             var_stats += np.sum(np.square(fbank), axis=0)
@@ -93,6 +93,7 @@ def main(**kwargs):
         'total_frames': total_frames
     }
     cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
     with open(cmvn_file, 'w') as fout:
         fout.write(json.dumps(cmvn_info))
         
@@ -110,14 +111,14 @@ def main(**kwargs):
         fout.write("</Nnet>" + '\n')
     
     
+"""
+python funasr/bin/compute_audio_cmvn.py \
+--config-path "/Users/zhifu/funasr1.0/examples/aishell/paraformer/conf" \
+--config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
+++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
+++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
+++dataset_conf.num_workers=0
+"""
 if __name__ == "__main__":
     main_hydra()
-    """
-    python funasr/bin/compute_status.py \
-    --config-path "/Users/zhifu/funasr1.0/examples/aishell/conf" \
-    --config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
-    ++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
-    ++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
-    ++dataset_conf.num_workers=32
-    """
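The accumulation fix is a real correctness bug: before it, the first batch stored the whole [T, D] feature matrix in mean_stats instead of its per-dimension sums, so every later += of a [D]-shaped sum broadcast against [T, D], leaving statistics with the wrong shape and values. A self-contained sketch of the corrected computation on toy data (the script itself stores the raw sums plus total_frames in cmvn.json):

    import numpy as np

    def global_mean_var(fbanks):
        """fbanks: iterable of [T, D] feature matrices."""
        mean_stats, var_stats, total_frames = 0.0, 0.0, 0
        for fbank in fbanks:
            mean_stats = mean_stats + np.sum(fbank, axis=0)            # [D]
            var_stats = var_stats + np.sum(np.square(fbank), axis=0)   # [D]
            total_frames += fbank.shape[0]
        mean = mean_stats / total_frames
        var = var_stats / total_frames - np.square(mean)  # E[x^2] - E[x]^2
        return mean, var

    mean, var = global_mean_var(np.random.randn(3, 100, 80))
    print(mean.shape, var.shape)  # (80,) (80,)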

+ 2 - 3
funasr/bin/train.py

@@ -79,9 +79,8 @@ def main(**kwargs):
         frontend = frontend_class(**kwargs["frontend_conf"])
         kwargs["frontend"] = frontend
         kwargs["input_size"] = frontend.output_size()
-    
-    # import pdb;
-    # pdb.set_trace()
+
+
     # build model
     model_class = tables.model_classes.get(kwargs["model"])
     model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
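The remaining lines show the lookup pattern this codebase uses for components: a class is fetched by name from a registry table, then instantiated from the merged config. A toy sketch of that pattern (simplified; not the actual funasr.register.tables implementation):

    model_classes = {}

    def register(name):
        def decorator(cls):
            model_classes[name] = cls
            return cls
        return decorator

    @register("Paraformer")
    class Paraformer:
        def __init__(self, vocab_size=0, **kwargs):
            self.vocab_size = vocab_size

    model_class = model_classes.get("Paraformer")
    model = model_class(vocab_size=4234)  # vocab size hypothetical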

+ 3 - 3
funasr/datasets/audio_datasets/datasets.py

@@ -22,12 +22,12 @@ class AudioDataset(torch.utils.data.Dataset):
         self.index_ds = index_ds_class(path, **kwargs)
         preprocessor_speech = kwargs.get("preprocessor_speech", None)
         if preprocessor_speech:
-            preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
+            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
             preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
         self.preprocessor_speech = preprocessor_speech
         preprocessor_text = kwargs.get("preprocessor_text", None)
         if preprocessor_text:
-            preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text)
+            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
             preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
         self.preprocessor_text = preprocessor_text
         
@@ -57,7 +57,7 @@ class AudioDataset(torch.utils.data.Dataset):
         source = item["source"]
         data_src = load_audio_text_image_video(source, fs=self.fs)
         if self.preprocessor_speech:
-            data_src = self.preprocessor_speech(data_src)
+            data_src = self.preprocessor_speech(data_src, fs=self.fs)
         speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d]
 
         target = item["target"]
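Two fixes here: both preprocessor kinds are now resolved through the single tables.preprocessor_classes registry, matching the @tables.register("preprocessor_classes", ...) decorators in the new preprocessor.py below, and the dataset's sampling rate is threaded into the speech preprocessor, which sox-based speed perturbation needs. A standalone illustration of why fs matters, assuming a torchaudio build with sox effects (noise stands in for speech):

    import torch
    import torchaudio

    fs = 16000
    waveform = torch.randn(fs)  # 1 second at 16 kHz
    out, out_fs = torchaudio.sox_effects.apply_effects_tensor(
        waveform.view(1, -1), fs, [["speed", "0.9"], ["rate", str(fs)]])
    print(out.shape, out_fs)  # roughly [1, 17778], still at 16000 Hz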

+ 83 - 0
funasr/datasets/audio_datasets/preprocessor.py

@@ -0,0 +1,83 @@
+import os
+import json
+import torch
+import logging
+import concurrent.futures
+import librosa
+import torch.distributed as dist
+from typing import Collection
+import torchaudio
+from torch import nn
+import random
+import re
+from funasr.tokenizer.cleaner import TextCleaner
+from funasr.register import tables
+
+
+@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
+class SpeechPreprocessSpeedPerturb(nn.Module):
+	def __init__(self, speed_perturb: list=None, **kwargs):
+		super().__init__()
+		self.speed_perturb = speed_perturb
+		
+	def forward(self, waveform, fs, **kwargs):
+		if self.speed_perturb is None:
+			return waveform
+		speed = random.choice(self.speed_perturb)
+		if speed != 1.0:
+			waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+				torch.tensor(waveform).view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
+			waveform = waveform.view(-1)
+			
+		return waveform
+
+
+@tables.register("preprocessor_classes", "TextPreprocessSegDict")
+class TextPreprocessSegDict(nn.Module):
+	def __init__(self, seg_dict: str = None,
+	             text_cleaner: Collection[str] = None,
+	             split_with_space: bool = False,
+	             **kwargs):
+		super().__init__()
+		
+		self.seg_dict = None
+		if seg_dict is not None:
+			self.seg_dict = {}
+			with open(seg_dict, "r", encoding="utf8") as f:
+				lines = f.readlines()
+			for line in lines:
+				s = line.strip().split()
+				key = s[0]
+				value = s[1:]
+				self.seg_dict[key] = " ".join(value)
+		self.text_cleaner = TextCleaner(text_cleaner)
+		self.split_with_space = split_with_space
+	
+	def forward(self, text, **kwargs):
+		if self.seg_dict is not None:
+			text = self.text_cleaner(text)
+			if self.split_with_space:
+				tokens = text.strip().split(" ")
+				text = seg_tokenize(tokens, self.seg_dict)
+
+		return text
+
+def seg_tokenize(txt, seg_dict):
+	pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
+	out_txt = ""
+	for word in txt:
+		word = word.lower()
+		if word in seg_dict:
+			out_txt += seg_dict[word] + " "
+		else:
+			if pattern.match(word):
+				for char in word:
+					if char in seg_dict:
+						out_txt += seg_dict[char] + " "
+					else:
+						out_txt += "<unk>" + " "
+			else:
+				out_txt += "<unk>" + " "
+	return out_txt.strip().split()
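A toy invocation of seg_tokenize with hypothetical dictionary entries: words found in seg_dict expand to their stored segmentation, unmatched CJK/digit strings fall back to per-character lookup, and anything else becomes <unk>:

    seg_dict = {"hello": "he@@ llo", "你": "你", "好": "好"}  # entries hypothetical
    print(seg_tokenize(["Hello", "你好", "world"], seg_dict))
    # -> ['he@@', 'llo', '你', '好', '<unk>']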
