嘉渊 2 vuotta sitten
vanhempi
sitoutus
7d6177b43f
2 muutettua tiedostoa jossa 149 lisäystä ja 30 poistoa
  1. 21 30
      funasr/bin/asr_infer.py
  2. 128 0
      funasr/build_utils/build_model_from_file.py

+ 21 - 30
funasr/bin/asr_infer.py

@@ -24,7 +24,7 @@ import torch
 from packaging.version import parse as V
 from typeguard import check_argument_types
 from typeguard import check_return_type
-
+from  funasr.build_utils.build_model_from_file import build_model_from_file
 from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
 from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
 from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
@@ -35,9 +35,7 @@ from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransduc
 from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
 from funasr.modules.scorers.ctc import CTCPrefixScorer
 from funasr.modules.scorers.length_bonus import LengthBonus
-from funasr.tasks.asr import ASRTask
-from funasr.tasks.asr import frontend_choices
-from funasr.tasks.lm import LMTask
+from funasr.build_utils.build_asr_model import frontend_choices
 from funasr.text.build_tokenizer import build_tokenizer
 from funasr.text.token_id_converter import TokenIDConverter
 from funasr.torch_utils.device_funcs import to_device
@@ -84,15 +82,14 @@ class Speech2Text:
 
         # 1. Build ASR model
         scorers = {}
-        asr_model, asr_train_args = ASRTask.build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, device
+        asr_model, asr_train_args = build_model_from_file(
+            asr_train_config, asr_model_file, cmvn_file, device, mode="asr"
         )
         frontend = None
         if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
             if asr_train_args.frontend == 'wav_frontend':
                 frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
             else:
-                from funasr.tasks.asr import frontend_choices
                 frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                 frontend = frontend_class(**asr_train_args.frontend_conf).eval()
 
@@ -112,7 +109,7 @@ class Speech2Text:
 
         # 2. Build Language model
         if lm_train_config is not None:
-            lm, lm_train_args = LMTask.build_model_from_file(
+            lm, lm_train_args = build_model_from_file(
                 lm_train_config, lm_file, None, device
             )
             scorers["lm"] = lm.lm
@@ -295,9 +292,8 @@ class Speech2TextParaformer:
 
         # 1. Build ASR model
         scorers = {}
-        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-        asr_model, asr_train_args = ASRTask.build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, device
+        asr_model, asr_train_args = build_model_from_file(
+            asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
         )
         frontend = None
         if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -319,7 +315,7 @@ class Speech2TextParaformer:
 
         # 2. Build Language model
         if lm_train_config is not None:
-            lm, lm_train_args = LMTask.build_model_from_file(
+            lm, lm_train_args = build_model_from_file(
                 lm_train_config, lm_file, device
             )
             scorers["lm"] = lm.lm
@@ -616,9 +612,8 @@ class Speech2TextParaformerOnline:
 
         # 1. Build ASR model
         scorers = {}
-        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-        asr_model, asr_train_args = ASRTask.build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, device
+        asr_model, asr_train_args = build_model_from_file(
+            asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
         )
         frontend = None
         if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -640,7 +635,7 @@ class Speech2TextParaformerOnline:
 
         # 2. Build Language model
         if lm_train_config is not None:
-            lm, lm_train_args = LMTask.build_model_from_file(
+            lm, lm_train_args = build_model_from_file(
                 lm_train_config, lm_file, device
             )
             scorers["lm"] = lm.lm
@@ -873,9 +868,8 @@ class Speech2TextUniASR:
 
         # 1. Build ASR model
         scorers = {}
-        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
-        asr_model, asr_train_args = ASRTask.build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, device
+        asr_model, asr_train_args = build_model_from_file(
+            asr_train_config, asr_model_file, cmvn_file, device, mode="uniasr"
         )
         frontend = None
         if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -901,8 +895,8 @@ class Speech2TextUniASR:
 
         # 2. Build Language model
         if lm_train_config is not None:
-            lm, lm_train_args = LMTask.build_model_from_file(
-                lm_train_config, lm_file, device
+            lm, lm_train_args = build_model_from_file(
+                lm_train_config, lm_file, device, "lm"
             )
             scorers["lm"] = lm.lm
 
@@ -1104,9 +1098,8 @@ class Speech2TextMFCCA:
         assert check_argument_types()
 
         # 1. Build ASR model
-        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
         scorers = {}
-        asr_model, asr_train_args = ASRTask.build_model_from_file(
+        asr_model, asr_train_args = build_model_from_file(
             asr_train_config, asr_model_file, cmvn_file, device
         )
 
@@ -1126,7 +1119,7 @@ class Speech2TextMFCCA:
 
         # 2. Build Language model
         if lm_train_config is not None:
-            lm, lm_train_args = LMTask.build_model_from_file(
+            lm, lm_train_args = build_model_from_file(
                 lm_train_config, lm_file, device
             )
             lm.to(device)
@@ -1315,8 +1308,7 @@ class Speech2TextTransducer:
         super().__init__()
 
         assert check_argument_types()
-        from funasr.tasks.asr import ASRTransducerTask
-        asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
+        asr_model, asr_train_args = build_model_from_file(
             asr_train_config, asr_model_file, cmvn_file, device
         )
 
@@ -1350,7 +1342,7 @@ class Speech2TextTransducer:
             asr_model.to(dtype=getattr(torch, dtype)).eval()
 
         if lm_train_config is not None:
-            lm, lm_train_args = LMTask.build_model_from_file(
+            lm, lm_train_args = build_model_from_file(
                 lm_train_config, lm_file, device
             )
             lm_scorer = lm.lm
@@ -1638,9 +1630,8 @@ class Speech2TextSAASR:
         assert check_argument_types()
 
         # 1. Build ASR model
-        from funasr.tasks.sa_asr import ASRTask
         scorers = {}
-        asr_model, asr_train_args = ASRTask.build_model_from_file(
+        asr_model, asr_train_args = build_model_from_file(
             asr_train_config, asr_model_file, cmvn_file, device
         )
         frontend = None
@@ -1667,7 +1658,7 @@ class Speech2TextSAASR:
 
         # 2. Build Language model
         if lm_train_config is not None:
-            lm, lm_train_args = LMTask.build_model_from_file(
+            lm, lm_train_args = build_model_from_file(
                 lm_train_config, lm_file, None, device
             )
             scorers["lm"] = lm.lm

+ 128 - 0
funasr/build_utils/build_model_from_file.py

@@ -0,0 +1,128 @@
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Union
+
+import torch
+import yaml
+from typeguard import check_argument_types
+
+from funasr.build_utils.build_model import build_model
+from funasr.models.base_model import FunASRModel
+
+
+def build_model_from_file(
+        config_file: Union[Path, str] = None,
+        model_file: Union[Path, str] = None,
+        cmvn_file: Union[Path, str] = None,
+        device: str = "cpu",
+        mode: str = "paraformer",
+):
+    """Build model from the files.
+
+    This method is used for inference or fine-tuning.
+
+    Args:
+        config_file: The yaml file saved when training.
+        model_file: The model file saved when training.
+        device: Device type, "cpu", "cuda", or "cuda:N".
+
+    """
+    assert check_argument_types()
+    if config_file is None:
+        assert model_file is not None, (
+            "The argument 'model_file' must be provided "
+            "if the argument 'config_file' is not specified."
+        )
+        config_file = Path(model_file).parent / "config.yaml"
+    else:
+        config_file = Path(config_file)
+
+    with config_file.open("r", encoding="utf-8") as f:
+        args = yaml.safe_load(f)
+    if cmvn_file is not None:
+        args["cmvn_file"] = cmvn_file
+    args = argparse.Namespace(**args)
+    model = build_model(args)
+    if not isinstance(model, FunASRModel):
+        raise RuntimeError(
+            f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
+        )
+    model.to(device)
+    model_dict = dict()
+    model_name_pth = None
+    if model_file is not None:
+        logging.info("model_file is {}".format(model_file))
+        if device == "cuda":
+            device = f"cuda:{torch.cuda.current_device()}"
+        model_dir = os.path.dirname(model_file)
+        model_name = os.path.basename(model_file)
+        if "model.ckpt-" in model_name or ".bin" in model_name:
+            model_name_pth = os.path.join(model_dir, model_name.replace('.bin',
+                                                                        '.pb')) if ".bin" in model_name else os.path.join(
+                model_dir, "{}.pb".format(model_name))
+            if os.path.exists(model_name_pth):
+                logging.info("model_file is load from pth: {}".format(model_name_pth))
+                model_dict = torch.load(model_name_pth, map_location=device)
+            else:
+                model_dict = convert_tf2torch(model, model_file, mode)
+            model.load_state_dict(model_dict)
+        else:
+            model_dict = torch.load(model_file, map_location=device)
+    model.load_state_dict(model_dict)
+    if model_name_pth is not None and not os.path.exists(model_name_pth):
+        torch.save(model_dict, model_name_pth)
+        logging.info("model_file is saved to pth: {}".format(model_name_pth))
+
+    return model, args
+
+
+def convert_tf2torch(
+        model,
+        ckpt,
+        mode,
+):
+    assert mode == "paraformer" or mode == "uniasr"
+    logging.info("start convert tf model to torch model")
+    from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
+    var_dict_tf = load_tf_dict(ckpt)
+    var_dict_torch = model.state_dict()
+    var_dict_torch_update = dict()
+    if mode == "uniasr":
+        # encoder
+        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # predictor
+        var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # decoder
+        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # encoder2
+        var_dict_torch_update_local = model.encoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # predictor2
+        var_dict_torch_update_local = model.predictor2.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # decoder2
+        var_dict_torch_update_local = model.decoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # stride_conv
+        var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+    else:
+        # encoder
+        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # predictor
+        var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # decoder
+        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # bias_encoder
+        var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+
+    return var_dict_torch_update