|
|
@@ -1,14 +1,18 @@
|
|
|
import argparse
|
|
|
import logging
|
|
|
+import os
|
|
|
+from pathlib import Path
|
|
|
from typing import Callable
|
|
|
from typing import Collection
|
|
|
from typing import Dict
|
|
|
from typing import List
|
|
|
from typing import Optional
|
|
|
from typing import Tuple
|
|
|
+from typing import Union
|
|
|
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
+import yaml
|
|
|
from typeguard import check_argument_types
|
|
|
from typeguard import check_return_type
|
|
|
|
|
|
@@ -21,7 +25,7 @@ from funasr.models.e2e_asr import ESPnetASRModel
|
|
|
from funasr.models.decoder.abs_decoder import AbsDecoder
|
|
|
from funasr.models.encoder.abs_encoder import AbsEncoder
|
|
|
from funasr.models.encoder.rnn_encoder import RNNEncoder
|
|
|
-from funasr.models.encoder.resnet34_encoder import ResNet34
|
|
|
+from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
|
|
|
from funasr.models.pooling.statistic_pooling import StatisticPooling
|
|
|
from funasr.models.decoder.sv_decoder import DenseDecoder
|
|
|
from funasr.models.e2e_sv import ESPnetSVModel
|
|
|
@@ -103,6 +107,7 @@ encoder_choices = ClassChoices(
|
|
|
"encoder",
|
|
|
classes=dict(
|
|
|
resnet34=ResNet34,
|
|
|
+ resnet34_sp_l2reg=ResNet34_SP_L2Reg,
|
|
|
rnn=RNNEncoder,
|
|
|
),
|
|
|
type_check=AbsEncoder,
|
|
|
@@ -394,9 +399,16 @@ class SVTask(AbsTask):
|
|
|
|
|
|
# 7. Pooling layer
|
|
|
pooling_class = pooling_choices.get_class(args.pooling_type)
|
|
|
+ pooling_dim = (2, 3)
|
|
|
+ eps = 1e-12
|
|
|
+ if hasattr(args, "pooling_type_conf"):
|
|
|
+ if "pooling_dim" in args.pooling_type_conf:
|
|
|
+ pooling_dim = args.pooling_type_conf["pooling_dim"]
|
|
|
+ if "eps" in args.pooling_type_conf:
|
|
|
+ eps = args.pooling_type_conf["eps"]
|
|
|
pooling_layer = pooling_class(
|
|
|
- pooling_dim=(2, 3),
|
|
|
- eps=1e-12,
|
|
|
+ pooling_dim=pooling_dim,
|
|
|
+ eps=eps,
|
|
|
)
|
|
|
if args.pooling_type == "statistic":
|
|
|
encoder_output_size *= 2
|
|
|
@@ -435,3 +447,95 @@ class SVTask(AbsTask):
|
|
|
|
|
|
assert check_return_type(model)
|
|
|
return model
|
|
|
+
|
|
|
+ # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
|
|
|
@classmethod
def build_model_from_file(
    cls,
    config_file: Union[Path, str] = None,
    model_file: Union[Path, str] = None,
    cmvn_file: Union[Path, str] = None,
    device: str = "cpu",
):
    """Build a model from a saved config and load its weights.

    Used for inference or fine-tuning.

    Args:
        config_file: The yaml file saved when training. When omitted,
            ``config.yaml`` in the directory of ``model_file`` is read.
        model_file: The model file saved when training. If its file name
            contains ``model.ckpt-`` or ``.bin``, weights come from a
            previously cached torch file next to it when one exists, and
            otherwise from ``cls.convert_tf2torch``; any other file is
            passed straight to ``torch.load``.
        cmvn_file: The cmvn file for the front-end; when given it replaces
            the ``cmvn_file`` entry of the loaded config.
        device: Device type, "cpu", "cuda", or "cuda:N".

    Returns:
        A ``(model, args)`` tuple, ``args`` being the loaded config wrapped
        in an ``argparse.Namespace``.

    Raises:
        RuntimeError: If ``cls.build_model`` returns something that is not
            an ``AbsESPnetModel``.
    """
    assert check_argument_types()

    if config_file is not None:
        config_file = Path(config_file)
    else:
        assert model_file is not None, (
            "The argument 'model_file' must be provided "
            "if the argument 'config_file' is not specified."
        )
        config_file = Path(model_file).parent / "config.yaml"

    with config_file.open("r", encoding="utf-8") as stream:
        config = yaml.safe_load(stream)
    if cmvn_file is not None:
        config["cmvn_file"] = cmvn_file
    args = argparse.Namespace(**config)

    model = cls.build_model(args)
    if not isinstance(model, AbsESPnetModel):
        raise RuntimeError(
            f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
        )
    model.to(device)

    state_dict = dict()
    # Path of the cached torch weights; stays None for plain torch checkpoints.
    cache_path = None
    if model_file is not None:
        logging.info("model_file is {}".format(model_file))
        if device == "cuda":
            # Pin the bare "cuda" alias to the currently selected GPU index.
            device = f"cuda:{torch.cuda.current_device()}"
        ckpt_dir, ckpt_name = os.path.split(model_file)
        has_bin = ".bin" in ckpt_name
        if has_bin or "model.ckpt-" in ckpt_name:
            if has_bin:
                cached_name = ckpt_name.replace(".bin", ".pb")
            else:
                cached_name = "{}.pth".format(ckpt_name)
            cache_path = os.path.join(ckpt_dir, cached_name)
            if os.path.exists(cache_path):
                logging.info("model_file is load from pth: {}".format(cache_path))
                state_dict = torch.load(cache_path, map_location=device)
            else:
                state_dict = cls.convert_tf2torch(model, model_file)
        else:
            state_dict = torch.load(model_file, map_location=device)
        model.load_state_dict(state_dict)

    # Persist freshly converted weights so the next call can skip conversion.
    if cache_path is not None and not os.path.exists(cache_path):
        torch.save(state_dict, cache_path)
        logging.info("model_file is saved to pth: {}".format(cache_path))

    return model, args
|
|
|
+
|
|
|
@classmethod
def convert_tf2torch(
    cls,
    model,
    ckpt,
):
    """Gather the converted parameters of every sub-module of ``model``.

    The variables read from ``ckpt`` by ``load_tf_dict`` and the model's
    current ``state_dict()`` are handed to the ``convert_tf2torch`` method
    of the encoder, the pooling layer and the decoder, in that order. The
    dictionaries they return are merged, later entries overriding earlier
    ones on key collisions.

    Args:
        model: Object exposing ``state_dict()`` plus ``encoder``,
            ``pooling_layer`` and ``decoder`` attributes, each providing
            ``convert_tf2torch(var_dict_tf, var_dict_torch)``.
        ckpt: Checkpoint reference forwarded unchanged to ``load_tf_dict``.

    Returns:
        dict: The merged result of the three per-module conversions.
    """
    logging.info("start convert tf model to torch model")
    from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict

    tf_vars = load_tf_dict(ckpt)
    torch_vars = model.state_dict()

    merged = dict()
    # Speech encoder first, then pooling layer, then decoder; attributes are
    # resolved lazily so each module is looked up right before it is used.
    for part in ("encoder", "pooling_layer", "decoder"):
        merged.update(getattr(model, part).convert_tf2torch(tf_vars, torch_vars))

    return merged
|