build_model_from_file.py

import argparse
import logging
import os
from pathlib import Path
from typing import Union

import torch
import yaml
from typeguard import check_argument_types

from funasr.build_utils.build_model import build_model
from funasr.models.base_model import FunASRModel

def build_model_from_file(
    config_file: Union[Path, str] = None,
    model_file: Union[Path, str] = None,
    cmvn_file: Union[Path, str] = None,
    device: str = "cpu",
    mode: str = "paraformer",
):
    """Build a model from the files saved at training time.

    This function is used for inference or fine-tuning.

    Args:
        config_file: The yaml config file saved when training.
        model_file: The model file saved when training.
        cmvn_file: The CMVN statistics file; if given, it overrides the path stored in the config.
        device: Device type, "cpu", "cuda", or "cuda:N".
        mode: Model type used when converting a TensorFlow checkpoint, "paraformer" or "uniasr".
    """
    assert check_argument_types()
    if config_file is None:
        assert model_file is not None, (
            "The argument 'model_file' must be provided "
            "if the argument 'config_file' is not specified."
        )
        config_file = Path(model_file).parent / "config.yaml"
    else:
        config_file = Path(config_file)

    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)
    if cmvn_file is not None:
        args["cmvn_file"] = cmvn_file
    args = argparse.Namespace(**args)

    model = build_model(args)
    if not isinstance(model, FunASRModel):
        raise RuntimeError(
            f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
        )
    model.to(device)
    model_dict = dict()
    model_name_pth = None
    if model_file is not None:
        logging.info("model_file is {}".format(model_file))
        if device == "cuda":
            device = f"cuda:{torch.cuda.current_device()}"
        model_dir = os.path.dirname(model_file)
        model_name = os.path.basename(model_file)
        if "model.ckpt-" in model_name or ".bin" in model_name:
            # TensorFlow checkpoint (or .bin weights): load the converted state dict
            # from the cached .pb file if it exists, otherwise convert it now.
            if ".bin" in model_name:
                model_name_pth = os.path.join(model_dir, model_name.replace(".bin", ".pb"))
            else:
                model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
            if os.path.exists(model_name_pth):
                logging.info("model_file is loaded from pth: {}".format(model_name_pth))
                model_dict = torch.load(model_name_pth, map_location=device)
            else:
                model_dict = convert_tf2torch(model, model_file, mode)
            model.load_state_dict(model_dict)
        else:
            # Native PyTorch checkpoint.
            model_dict = torch.load(model_file, map_location=device)
            model.load_state_dict(model_dict)

    # Cache the converted state dict next to the original checkpoint so the
    # conversion only has to run once.
    if model_name_pth is not None and not os.path.exists(model_name_pth):
        torch.save(model_dict, model_name_pth)
        logging.info("model_file is saved to pth: {}".format(model_name_pth))

    return model, args

def convert_tf2torch(
    model,
    ckpt,
    mode,
):
    """Convert a TensorFlow checkpoint into a PyTorch state dict matching the given model."""
    assert mode == "paraformer" or mode == "uniasr"
    logging.info("start convert tf model to torch model")
    from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict

    var_dict_tf = load_tf_dict(ckpt)
    var_dict_torch = model.state_dict()
    var_dict_torch_update = dict()
    if mode == "uniasr":
        # encoder
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor
        var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # encoder2
        var_dict_torch_update_local = model.encoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor2
        var_dict_torch_update_local = model.predictor2.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder2
        var_dict_torch_update_local = model.decoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # stride_conv
        var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
    else:
        # encoder
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor
        var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # bias_encoder
        var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)

    return var_dict_torch_update
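
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): how build_model_from_file is
# typically called for inference. The paths below are hypothetical placeholders
# for a training output directory; replace them with your own config/checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model, args = build_model_from_file(
        config_file="exp/paraformer/config.yaml",  # yaml written during training
        model_file="exp/paraformer/model.pb",      # checkpoint written during training
        device="cpu",                              # or "cuda" / "cuda:0"
        mode="paraformer",
    )
    model.eval()  # switch to inference mode before decoding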