| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193 |
- import argparse
- import logging
- import os
- from pathlib import Path
- from typing import Union
- import torch
- import yaml
- from typeguard import check_argument_types
- from funasr.build_utils.build_model import build_model
- from funasr.models.base_model import FunASRModel
def build_model_from_file(
    config_file: Union[Path, str] = None,
    model_file: Union[Path, str] = None,
    cmvn_file: Union[Path, str] = None,
    device: str = "cpu",
    task_name: str = "asr",
    mode: str = "paraformer",
):
    """Build a model from a saved config and optionally load its weights.

    This method is used for inference or fine-tuning.

    Args:
        config_file: The yaml file saved when training. When omitted,
            ``config.yaml`` next to ``model_file`` is used.
        model_file: The model file saved when training. When omitted, the
            model is built from the config and returned without loading
            any weights.
        cmvn_file: If given, stored under the ``cmvn_file`` key of the loaded
            config before the model is built.
        device: Device type, "cpu", "cuda", or "cuda:N".
        task_name: Stored on the config namespace as ``task_name``. The value
            "vad" loads the weights into ``model.encoder`` only; "diar"
            together with ``mode == "sond"`` filters the checkpoint against
            the model's own state dict first.
        mode: Forwarded to ``convert_tf2torch`` when the checkpoint has to be
            converted.

    Returns:
        A ``(model, args)`` tuple, where ``args`` is the config as an
        ``argparse.Namespace``.

    Raises:
        RuntimeError: If the built model is not a ``FunASRModel``.
    """
    assert check_argument_types()
    if config_file is None:
        assert model_file is not None, (
            "The argument 'model_file' must be provided "
            "if the argument 'config_file' is not specified."
        )
        config_file = Path(model_file).parent / "config.yaml"
    else:
        config_file = Path(config_file)

    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)
    if cmvn_file is not None:
        args["cmvn_file"] = cmvn_file
    args = argparse.Namespace(**args)
    args.task_name = task_name

    model = build_model(args)
    if not isinstance(model, FunASRModel):
        raise RuntimeError(
            f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
        )
    model.to(device)

    # Without a checkpoint there is nothing to load. Passing an empty state
    # dict to a strict load_state_dict() fails for any model that has
    # parameters, so return the freshly built model as-is.
    if model_file is None:
        return model, args

    logging.info("model_file is {}".format(model_file))
    if device == "cuda":
        # Pin map_location to the concrete current CUDA device.
        device = f"cuda:{torch.cuda.current_device()}"
    model_dir = os.path.dirname(model_file)
    model_name = os.path.basename(model_file)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model_name_pth = None
    if "model.ckpt-" in model_name or ".bin" in model_name:
        # These names are routed through convert_tf2torch; the converted
        # state dict is cached next to the checkpoint as a ".pb" file.
        if ".bin" in model_name:
            model_name_pth = os.path.join(model_dir, model_name.replace(".bin", ".pb"))
        else:
            model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
        if os.path.exists(model_name_pth):
            logging.info("model_file is load from pth: {}".format(model_name_pth))
            model_dict = torch.load(model_name_pth, map_location=device)
        else:
            model_dict = convert_tf2torch(model, model_file, mode)
    else:
        model_dict = torch.load(model_file, map_location=device)

    # The state dict is loaded exactly once, below, so that the diar/sond
    # filtering and the vad encoder-only routing apply to converted
    # checkpoints as well.
    if task_name == "diar" and mode == "sond":
        model_dict = fileter_model_dict(model_dict, model.state_dict())
    if task_name == "vad":
        model.encoder.load_state_dict(model_dict)
    else:
        model.load_state_dict(model_dict)

    # Write the cache only after a successful load, and only if it is missing.
    if model_name_pth is not None and not os.path.exists(model_name_pth):
        torch.save(model_dict, model_name_pth)
        logging.info("model_file is saved to pth: {}".format(model_name_pth))

    return model, args
def convert_tf2torch(
    model,
    ckpt,
    mode,
):
    """Convert a TensorFlow checkpoint into a state dict for ``model``.

    The checkpoint is read with ``load_tf_dict`` and each sub-module listed
    for the given ``mode`` converts its own variables through its
    ``convert_tf2torch(var_dict_tf, var_dict_torch)`` method. The partial
    results are merged in a fixed order.

    Args:
        model: Model exposing ``state_dict()`` and the sub-modules required
            by ``mode``.
        ckpt: Checkpoint location passed to ``load_tf_dict``.
        mode: One of "paraformer", "uniasr", "sond", "sv" or "tp".

    Returns:
        A dict containing the merged entries produced by the per-component
        converters.
    """
    assert mode in ("paraformer", "uniasr", "sond", "sv", "tp")
    logging.info("start convert tf model to torch model")
    from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict

    var_dict_tf = load_tf_dict(ckpt)
    var_dict_torch = model.state_dict()
    var_dict_torch_update = dict()

    # Sub-modules to convert, in merge order, for each mode.
    if mode == "uniasr":
        component_names = (
            "encoder",
            "predictor",
            "decoder",
            "encoder2",
            "predictor2",
            "decoder2",
            "stride_conv",
        )
    elif mode == "sond":
        component_names = (
            "encoder",
            "speaker_encoder",
            "cd_scorer",
            "ci_scorer",
            "decoder",
        )
    elif mode == "sv":
        component_names = ("encoder", "pooling_layer", "decoder")
    else:
        # "paraformer" and "tp" share the same component layout.
        component_names = ("encoder", "predictor", "decoder")

    for component_name in component_names:
        component = getattr(model, component_name)
        # Only the "sond" mode tolerates components that are set to None.
        if mode == "sond" and component is None:
            continue
        var_dict_torch_update.update(
            component.convert_tf2torch(var_dict_tf, var_dict_torch)
        )

    if mode not in ("uniasr", "sond", "sv"):
        # bias_encoder: converted by the model itself, after its sub-modules.
        var_dict_torch_update.update(
            model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
        )

    return var_dict_torch_update
def fileter_model_dict(src_dict: dict, dest_dict: dict):
    """Keep only the entries of ``src_dict`` whose keys exist in ``dest_dict``.

    Keys present in ``src_dict`` but not in ``dest_dict`` are dropped and
    reported at info level. Keys present in ``dest_dict`` but absent from the
    result are reported at warning level.

    Args:
        src_dict: Mapping to filter.
        dest_dict: Mapping whose keys decide what is kept.

    Returns:
        An ``OrderedDict`` with the retained entries in ``src_dict`` order.
    """
    from collections import OrderedDict

    kept = OrderedDict()
    for name, entry in src_dict.items():
        if name not in dest_dict:
            logging.info("{} is no longer needed in this model.".format(name))
            continue
        kept[name] = entry

    absent = [name for name in dest_dict if name not in kept]
    for name in absent:
        logging.warning("{} is missed in checkpoint.".format(name))

    return kept
|