游雁 1 年之前
父节点
当前提交
fa6f60fa76

+ 18 - 16
examples/aishell/llm_asr_nar/conf/template.yaml

@@ -24,11 +24,11 @@ llm_conf:
   init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
   freeze: true
 
-adaptor: linear
+adaptor: Linear
 adaptor_conf:
   downsample_rate: 1
   llm_dim: 4096
-  encoder_dim: 2048
+  encoder_dim: 512
 
 # frontend related
 frontend: WavFrontend
@@ -38,54 +38,56 @@ frontend_conf:
     n_mels: 80
     frame_length: 25
     frame_shift: 10
-    dither: 0.0
-    lfr_m: 1
-    lfr_n: 1
+    lfr_m: 7
+    lfr_n: 6
+    cmvn_file: "/root/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
 
-specaug: SpecAug
+specaug: SpecAugLFR
 specaug_conf:
-    apply_time_warp: true
+    apply_time_warp: false
     time_warp_window: 5
     time_warp_mode: bicubic
     apply_freq_mask: true
     freq_mask_width_range:
     - 0
     - 30
-    num_freq_mask: 2
+    lfr_rate: 6
+    num_freq_mask: 1
     apply_time_mask: true
     time_mask_width_range:
     - 0
-    - 40
-    num_time_mask: 2
+    - 12
+    num_time_mask: 1
 
 train_conf:
   accum_grad: 1
   grad_clip: 5
   max_epoch: 150
   keep_nbest_models: 10
-  log_interval: 50
+  log_interval: 10
 
-optim: adam
+optim: adamw
 optim_conf:
-   lr: 0.001
+   lr: 0.0001
    weight_decay: 0.000001
 scheduler: warmuplr
 scheduler_conf:
-   warmup_steps: 35000
+   warmup_steps: 1500
 
 dataset: AudioLLMDataset
 dataset_conf:
     index_ds: IndexDSJsonl
     batch_sampler: RankFullLocalShuffleBatchSampler
     batch_type: example # example or length
-    batch_size: 4 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+    batch_size: 8 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
     buffer_size: 500
     shuffle: True
     num_workers: 4
+    preprocessor_text: TextPreprocessRemovePunctuation
 
 tokenizer: HuggingfaceTokenizer
 tokenizer_conf:
   unk_symbol: <unk>
-  init_param_path: null
+  init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
 

+ 4 - 2
funasr/auto/auto_model.py

@@ -157,8 +157,10 @@ class AutoModel:
             tokenizer_class = tables.tokenizer_classes.get(tokenizer)
             tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
             kwargs["tokenizer"] = tokenizer
-            kwargs["token_list"] = tokenizer.token_list
-            vocab_size = len(tokenizer.token_list)
+
+            kwargs["token_list"] = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
+            kwargs["token_list"] = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"]
+            vocab_size = len(kwargs["token_list"])
         else:
             vocab_size = -1
         

+ 3 - 1
funasr/bin/train.py

@@ -85,7 +85,9 @@ def main(**kwargs):
 
     # build model
     model_class = tables.model_classes.get(kwargs["model"])
-    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
+    vocab_size = len(tokenizer.token_list) if hasattr(tokenizer, "token_list") else None
+    vocab_size = len(tokenizer.get_vocab()) if hasattr(tokenizer, "get_vocab") else vocab_size
+    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
 
 
 

+ 7 - 6
funasr/datasets/llm_datasets/datasets.py

@@ -24,12 +24,12 @@ class AudioLLMDataset(torch.utils.data.Dataset):
         preprocessor_speech = kwargs.get("preprocessor_speech", None)
         if preprocessor_speech:
             preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
-            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
+            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
         self.preprocessor_speech = preprocessor_speech
         preprocessor_text = kwargs.get("preprocessor_text", None)
         if preprocessor_text:
             preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
-            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
+            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
         self.preprocessor_text = preprocessor_text
         
         self.frontend = frontend
@@ -43,6 +43,7 @@ class AudioLLMDataset(torch.utils.data.Dataset):
         self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
             self.prompt)  # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
         self.prompt_af = ""
+        self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
     
     def get_source_len(self, index):
         item = self.index_ds[index]
@@ -64,7 +65,7 @@ class AudioLLMDataset(torch.utils.data.Dataset):
         if self.preprocessor_speech:
             data_src = self.preprocessor_speech(data_src, fs=self.fs)
         speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d]
-        speech = speech.sequeeze(0)
+        speech = speech.squeeze(0)
 
         target = item["target"]
         if self.preprocessor_text:
@@ -91,10 +92,10 @@ class AudioLLMDataset(torch.utils.data.Dataset):
         label_mask = labels_ids.ge(0)  # [False,False,True,True]
         labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100,-100,input,eos]
         
-        audio_mask = [0] * prompt_pre_length + [1] * audio_length
-        torch.tensor(audio_mask, dtype=torch.float32)
+        audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
+        audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
         
-        ids = self.tokenizer.encode(target)
+        ids = self.tokenizer.encode(target) # token ids is different from labels_ids
         text = torch.tensor(ids, dtype=torch.int64)
         text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
         

+ 13 - 27
funasr/datasets/llm_datasets/preprocessor.py

@@ -11,41 +11,27 @@ import torchaudio
 from torch import nn
 import random
 import re
+import string
 from funasr.tokenizer.cleaner import TextCleaner
 from funasr.register import tables
 
 
-@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
-class SpeechPreprocessSpeedPerturb(nn.Module):
-	def __init__(self, speed_perturb: list=None, **kwargs):
-		super().__init__()
-		self.speed_perturb = speed_perturb
-		
-	def forward(self, waveform, fs, **kwargs):
-		if self.speed_perturb is None:
-			return waveform
-		speed = random.choice(self.speed_perturb)
-		if speed != 1.0:
-			if not isinstance(waveform, torch.Tensor):
-				waveform = torch.tensor(waveform)
-			waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
-				waveform.view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
-			waveform = waveform.view(-1)
-			
-		return waveform
 
-
-@tables.register("preprocessor_classes", "TextPreprocessSegDict")
+@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
 class TextPreprocessSegDict(nn.Module):
-	def __init__(self, seg_dict: str = None,
-	             text_cleaner: Collection[str] = None,
-	             split_with_space: bool = False,
+	def __init__(self,
 	             **kwargs):
 		super().__init__()
 		
-		self.text_cleaner = TextCleaner(text_cleaner)
 	
 	def forward(self, text, **kwargs):
-		text = self.text_cleaner(text)
-		
-		return text
+		# 定义英文标点符号
+		en_punct = string.punctuation
+		# 定义中文标点符号(部分常用的)
+		cn_punct = '。?!,、;:“”‘’()《》【】…—~·'
+		# 合并英文和中文标点符号
+		all_punct = en_punct + cn_punct
+		# 创建正则表达式模式,匹配任何在all_punct中的字符
+		punct_pattern = re.compile('[{}]'.format(re.escape(all_punct)))
+		# 使用正则表达式的sub方法替换掉这些字符
+		return punct_pattern.sub('', text)

+ 0 - 96
funasr/datasets/llm_datasets/scp2jsonl.py

@@ -1,96 +0,0 @@
-import os
-import json
-import torch
-import logging
-import hydra
-from omegaconf import DictConfig, OmegaConf
-import concurrent.futures
-import librosa
-import torch.distributed as dist
-
-
-
-def gen_jsonl_from_wav_text_list(path, data_type_list=("source", "target"), jsonl_file_out:str=None, **kwargs):
-    try:
-        rank = dist.get_rank()
-        world_size = dist.get_world_size()
-    except:
-        rank = 0
-        world_size = 1
-
-    cpu_cores = os.cpu_count() or 1
-    print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
-    if rank == 0:
-        json_dict = {}
-        for data_type, data_file in zip(data_type_list, path):
-            json_dict[data_type] = {}
-            with open(data_file, "r") as f:
-                
-                data_file_lists = f.readlines()
-                lines_for_each_th = (len(data_file_lists)-1)//cpu_cores + 1
-                task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
-                with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
-
-                    futures = [executor.submit(parse_context_length, data_file_lists[i*lines_for_each_th:(i+1)*lines_for_each_th], data_type) for i in range(task_num)]
-    
-                    for future in concurrent.futures.as_completed(futures):
-                        
-                        json_dict[data_type].update(future.result())
-            # print(json_dict)
-        
-        with open(jsonl_file_out, "w") as f:
-            for key in json_dict[data_type_list[0]].keys():
-                jsonl_line = {"key": key}
-                for data_file in data_type_list:
-                    jsonl_line.update(json_dict[data_file][key])
-                jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
-                f.write(jsonl_line+"\n")
-                f.flush()
-                
-    else:
-        pass
-        
-    if world_size > 1:
-        dist.barrier()
-    
-    
-def parse_context_length(data_list: list, data_type: str):
-    
-    res = {}
-    for i, line in enumerate(data_list):
-        key, line = line.strip().split(maxsplit=1)
-        line = line.strip()
-        if os.path.exists(line):
-            waveform, _ = librosa.load(line, sr=16000)
-            sample_num = len(waveform)
-            context_len = int(sample_num//16000*1000/10)
-        else:
-            context_len = len(line.split()) if " " in line else len(line)
-        res[key] = {data_type: line, f"{data_type}_len": context_len}
-    return res
-    
-
-@hydra.main(config_name=None, version_base=None)
-def main_hydra(cfg: DictConfig):
- 
-    kwargs = OmegaConf.to_container(cfg, resolve=True)
-
-    scp_file_list = kwargs.get("scp_file_list", ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"))
-    if isinstance(scp_file_list, str):
-        scp_file_list = eval(scp_file_list)
-    data_type_list = kwargs.get("data_type_list", ("source", "target"))
-    jsonl_file_out = kwargs.get("jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl")
-    gen_jsonl_from_wav_text_list(scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out)
-    
-
-"""
-python -m funasr.datasets.audio_datasets.scp2jsonl \
-++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
-++data_type_list='["source", "target"]' \
-++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
-"""
-
-if __name__ == "__main__":
-    main_hydra()
-
-    

+ 1 - 3
funasr/metrics/compute_acc.py

@@ -35,8 +35,6 @@ def compute_accuracy(pad_outputs, pad_targets, ignore_label):
 
     """
     mask = pad_targets != ignore_label
-    numerator = torch.sum(
-        pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
-    )
+    numerator = torch.sum(pad_outputs.masked_select(mask) == pad_targets.masked_select(mask))
     denominator = torch.sum(mask)
     return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type

+ 13 - 12
funasr/models/llm_asr/model.py

@@ -73,7 +73,7 @@ class LLMASRNAR(nn.Module):
         hub = encoder_conf.get("hub", None)
         if hub == "funasr":
             from funasr import AutoModel
-            init_param_path = encoder_conf.get("hub", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+            init_param_path = encoder_conf.get("init_param_path", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
             model = AutoModel(model=init_param_path, model_revision="v2.0.4")
             # frontend = model.kwargs.get("frontend")
             model.model.decoder = None
@@ -179,6 +179,7 @@ class LLMASRNAR(nn.Module):
 
         if input_ids is not None:
             input_ids[input_ids == -1] = 0
+            input_ids[input_ids == -100] = 0
             if hasattr(self.llm.model, "embed_tokens"):
                 inputs_embeds = self.llm.model.embed_tokens(input_ids)
             elif hasattr(self.llm.model.model, "embed_tokens"):
@@ -190,7 +191,7 @@ class LLMASRNAR(nn.Module):
                 batch_size, token_num, dims = inputs_embeds.shape
                 _, l, _ = encoder_out.shape
                 encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num-l-1, 1, 0, 0), value=0.0)
-                inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (~audio_mask[:, :, None])
+                inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0-audio_mask[:, :, None])
                 inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)
 
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
@@ -198,11 +199,10 @@ class LLMASRNAR(nn.Module):
 
 
         stats = {}
-        if self.metric:
-            with torch.no_grad():
-                preds = torch.argmax(model_outputs.logits, -1)
-                acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
-                stats["acc"] = acc_att
+        with torch.no_grad():
+            preds = torch.argmax(model_outputs.logits, -1)
+            acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
+            stats["acc"] = acc_att
 
         stats["loss"] = torch.clone(loss.detach())
 
@@ -221,11 +221,12 @@ class LLMASRNAR(nn.Module):
 
         batch = {"speech": speech, "speech_lengths": speech_lengths}
         enc, enc_lens = self.audio_encoder.encode(**batch)
-        enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
-        pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
-                                                                           mask=enc_mask,
-                                                                           target_label_length=audio_token_lengths,
-                                                                           )
+        with autocast(False):
+            enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
+            pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
+                                                                               mask=enc_mask,
+                                                                               target_label_length=audio_token_lengths,
+                                                                               )
 
         return pre_acoustic_embeds, pre_token_length
 

+ 74 - 126
funasr/models/paraformer/cif_predictor.py

@@ -10,7 +10,7 @@ import numpy as np
 from funasr.register import tables
 from funasr.train_utils.device_funcs import to_device
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
-
+from torch.cuda.amp import autocast
 
 @tables.register("predictor_classes", "CifPredictor")
 class CifPredictor(torch.nn.Module):
@@ -28,42 +28,44 @@ class CifPredictor(torch.nn.Module):
 
     def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                 target_label_length=None):
-        h = hidden
-        context = h.transpose(1, 2)
-        queries = self.pad(context)
-        memory = self.cif_conv1d(queries)
-        output = memory + context
-        output = self.dropout(output)
-        output = output.transpose(1, 2)
-        output = torch.relu(output)
-        output = self.cif_output(output)
-        alphas = torch.sigmoid(output)
-        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
-        if mask is not None:
-            mask = mask.transpose(-1, -2).float()
-            alphas = alphas * mask
-        if mask_chunk_predictor is not None:
-            alphas = alphas * mask_chunk_predictor
-        alphas = alphas.squeeze(-1)
-        mask = mask.squeeze(-1)
-        if target_label_length is not None:
-            target_length = target_label_length
-        elif target_label is not None:
-            target_length = (target_label != ignore_id).float().sum(-1)
-        else:
-            target_length = None
-        token_num = alphas.sum(-1)
-        if target_length is not None:
-            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
-        elif self.tail_threshold > 0.0:
-            hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
-            
-        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-        
-        if target_length is None and self.tail_threshold > 0.0:
-            token_num_int = torch.max(token_num).type(torch.int32).item()
-            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
+    
+        with autocast(False):
+            h = hidden
+            context = h.transpose(1, 2)
+            queries = self.pad(context)
+            memory = self.cif_conv1d(queries)
+            output = memory + context
+            output = self.dropout(output)
+            output = output.transpose(1, 2)
+            output = torch.relu(output)
+            output = self.cif_output(output)
+            alphas = torch.sigmoid(output)
+            alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+            if mask is not None:
+                mask = mask.transpose(-1, -2).float()
+                alphas = alphas * mask
+            if mask_chunk_predictor is not None:
+                alphas = alphas * mask_chunk_predictor
+            alphas = alphas.squeeze(-1)
+            mask = mask.squeeze(-1)
+            if target_label_length is not None:
+                target_length = target_label_length
+            elif target_label is not None:
+                target_length = (target_label != ignore_id).float().sum(-1)
+            else:
+                target_length = None
+            token_num = alphas.sum(-1)
+            if target_length is not None:
+                alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
+            elif self.tail_threshold > 0.0:
+                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+                
+            acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
             
+            if target_length is None and self.tail_threshold > 0.0:
+                token_num_int = torch.max(token_num).type(torch.int32).item()
+                acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
+                
         return acoustic_embeds, token_num, alphas, cif_peak
 
     def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
@@ -169,41 +171,43 @@ class CifPredictorV2(torch.nn.Module):
 
     def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                 target_label_length=None):
-        h = hidden
-        context = h.transpose(1, 2)
-        queries = self.pad(context)
-        output = torch.relu(self.cif_conv1d(queries))
-        output = output.transpose(1, 2)
-
-        output = self.cif_output(output)
-        alphas = torch.sigmoid(output)
-        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
-        if mask is not None:
-            mask = mask.transpose(-1, -2).float()
-            alphas = alphas * mask
-        if mask_chunk_predictor is not None:
-            alphas = alphas * mask_chunk_predictor
-        alphas = alphas.squeeze(-1)
-        mask = mask.squeeze(-1)
-        if target_label_length is not None:
-            target_length = target_label_length.squeeze(-1)
-        elif target_label is not None:
-            target_length = (target_label != ignore_id).float().sum(-1)
-        else:
-            target_length = None
-        token_num = alphas.sum(-1)
-        if target_length is not None:
-            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
-        elif self.tail_threshold > 0.0:
-            if self.tail_mask:
-                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+        
+        with autocast(False):
+            h = hidden
+            context = h.transpose(1, 2)
+            queries = self.pad(context)
+            output = torch.relu(self.cif_conv1d(queries))
+            output = output.transpose(1, 2)
+    
+            output = self.cif_output(output)
+            alphas = torch.sigmoid(output)
+            alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+            if mask is not None:
+                mask = mask.transpose(-1, -2).float()
+                alphas = alphas * mask
+            if mask_chunk_predictor is not None:
+                alphas = alphas * mask_chunk_predictor
+            alphas = alphas.squeeze(-1)
+            mask = mask.squeeze(-1)
+            if target_label_length is not None:
+                target_length = target_label_length.squeeze(-1)
+            elif target_label is not None:
+                target_length = (target_label != ignore_id).float().sum(-1)
             else:
-                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
-
-        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-        if target_length is None and self.tail_threshold > 0.0:
-            token_num_int = torch.max(token_num).type(torch.int32).item()
-            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
+                target_length = None
+            token_num = alphas.sum(-1)
+            if target_length is not None:
+                alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
+            elif self.tail_threshold > 0.0:
+                if self.tail_mask:
+                    hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+                else:
+                    hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
+    
+            acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+            if target_length is None and self.tail_threshold > 0.0:
+                token_num_int = torch.max(token_num).type(torch.int32).item()
+                acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
 
         return acoustic_embeds, token_num, alphas, cif_peak
 
@@ -371,62 +375,6 @@ class CifPredictorV2(torch.nn.Module):
         predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
         return predictor_alignments.detach(), predictor_alignments_length.detach()
 
-    def gen_tf2torch_map_dict(self):
-    
-        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
-        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
-        map_dict_local = {
-            ## predictor
-            "{}.cif_conv1d.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": (2, 1, 0),
-                 },  # (256,256,3),(3,256,256)
-            "{}.cif_conv1d.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.cif_output.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1,256),(1,256,1)
-            "{}.cif_output.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/conv1d_1/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1,),(1,)
-        }
-        return map_dict_local
-
-    def convert_tf2torch(self,
-                         var_dict_tf,
-                         var_dict_torch,
-                         ):
-        map_dict = self.gen_tf2torch_map_dict()
-        var_dict_torch_update = dict()
-        for name in sorted(var_dict_torch.keys(), reverse=False):
-            names = name.split('.')
-            if names[0] == self.tf2torch_tensor_name_prefix_torch:
-                name_tf = map_dict[name]["name"]
-                data_tf = var_dict_tf[name_tf]
-                if map_dict[name]["squeeze"] is not None:
-                    data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
-                if map_dict[name]["transpose"] is not None:
-                    data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
-                data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                var_dict_torch[
-                                                                                                    name].size(),
-                                                                                                data_tf.size())
-                var_dict_torch_update[name] = data_tf
-                logging.info(
-                    "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
-                                                                                  var_dict_tf[name_tf].shape))
-    
-        return var_dict_torch_update
-
 
 class mae_loss(torch.nn.Module):
 

+ 1 - 1
setup.py

@@ -40,11 +40,11 @@ requirements = {
         "umap_learn",
         "jaconv",
         "hydra-core>=1.3.2",
+        "tensorboardX",
     ],
     # train: The modules invoked when training only.
     "train": [
         "editdistance",
-        "tensorboardX",
     ],
     # all: The modules should be optionally installled due to some reason.
     #      Please consider moving them to "install" occasionally