build_model_from_file.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import argparse
  2. import logging
  3. import os
  4. from pathlib import Path
  5. from typing import Union
  6. import torch
  7. import yaml
  8. from typeguard import check_argument_types
  9. from funasr.build_utils.build_model import build_model
  10. from funasr.models.base_model import FunASRModel
  11. def build_model_from_file(
  12. config_file: Union[Path, str] = None,
  13. model_file: Union[Path, str] = None,
  14. cmvn_file: Union[Path, str] = None,
  15. device: str = "cpu",
  16. task_name: str = "asr",
  17. mode: str = "paraformer",
  18. ):
  19. """Build model from the files.
  20. This method is used for inference or fine-tuning.
  21. Args:
  22. config_file: The yaml file saved when training.
  23. model_file: The model file saved when training.
  24. device: Device type, "cpu", "cuda", or "cuda:N".
  25. """
  26. assert check_argument_types()
  27. if config_file is None:
  28. assert model_file is not None, (
  29. "The argument 'model_file' must be provided "
  30. "if the argument 'config_file' is not specified."
  31. )
  32. config_file = Path(model_file).parent / "config.yaml"
  33. else:
  34. config_file = Path(config_file)
  35. with config_file.open("r", encoding="utf-8") as f:
  36. args = yaml.safe_load(f)
  37. if cmvn_file is not None:
  38. args["cmvn_file"] = cmvn_file
  39. args = argparse.Namespace(**args)
  40. args.task_name = task_name
  41. model = build_model(args)
  42. if not isinstance(model, FunASRModel):
  43. raise RuntimeError(
  44. f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
  45. )
  46. model.to(device)
  47. model_dict = dict()
  48. model_name_pth = None
  49. if model_file is not None:
  50. logging.info("model_file is {}".format(model_file))
  51. if device == "cuda":
  52. device = f"cuda:{torch.cuda.current_device()}"
  53. model_dir = os.path.dirname(model_file)
  54. model_name = os.path.basename(model_file)
  55. if "model.ckpt-" in model_name or ".bin" in model_name:
  56. model_name_pth = os.path.join(model_dir, model_name.replace('.bin',
  57. '.pb')) if ".bin" in model_name else os.path.join(
  58. model_dir, "{}.pb".format(model_name))
  59. if os.path.exists(model_name_pth):
  60. logging.info("model_file is load from pth: {}".format(model_name_pth))
  61. model_dict = torch.load(model_name_pth, map_location=device)
  62. else:
  63. model_dict = convert_tf2torch(model, model_file, mode)
  64. model.load_state_dict(model_dict)
  65. else:
  66. model_dict = torch.load(model_file, map_location=device)
  67. if task_name == "diar" and mode == "sond":
  68. model_dict = fileter_model_dict(model_dict, model.state_dict())
  69. if task_name == "vad":
  70. model.encoder.load_state_dict(model_dict)
  71. else:
  72. model.load_state_dict(model_dict)
  73. if model_name_pth is not None and not os.path.exists(model_name_pth):
  74. torch.save(model_dict, model_name_pth)
  75. logging.info("model_file is saved to pth: {}".format(model_name_pth))
  76. return model, args
  77. def convert_tf2torch(
  78. model,
  79. ckpt,
  80. mode,
  81. ):
  82. assert mode == "paraformer" or mode == "uniasr" or mode == "sond" or mode == "sv" or mode == "tp"
  83. logging.info("start convert tf model to torch model")
  84. from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
  85. var_dict_tf = load_tf_dict(ckpt)
  86. var_dict_torch = model.state_dict()
  87. var_dict_torch_update = dict()
  88. if mode == "uniasr":
  89. # encoder
  90. var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  91. var_dict_torch_update.update(var_dict_torch_update_local)
  92. # predictor
  93. var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
  94. var_dict_torch_update.update(var_dict_torch_update_local)
  95. # decoder
  96. var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  97. var_dict_torch_update.update(var_dict_torch_update_local)
  98. # encoder2
  99. var_dict_torch_update_local = model.encoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
  100. var_dict_torch_update.update(var_dict_torch_update_local)
  101. # predictor2
  102. var_dict_torch_update_local = model.predictor2.convert_tf2torch(var_dict_tf, var_dict_torch)
  103. var_dict_torch_update.update(var_dict_torch_update_local)
  104. # decoder2
  105. var_dict_torch_update_local = model.decoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
  106. var_dict_torch_update.update(var_dict_torch_update_local)
  107. # stride_conv
  108. var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
  109. var_dict_torch_update.update(var_dict_torch_update_local)
  110. elif mode == "paraformer":
  111. # encoder
  112. var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  113. var_dict_torch_update.update(var_dict_torch_update_local)
  114. # predictor
  115. var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
  116. var_dict_torch_update.update(var_dict_torch_update_local)
  117. # decoder
  118. var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  119. var_dict_torch_update.update(var_dict_torch_update_local)
  120. # bias_encoder
  121. var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
  122. var_dict_torch_update.update(var_dict_torch_update_local)
  123. elif "mode" == "sond":
  124. if model.encoder is not None:
  125. var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  126. var_dict_torch_update.update(var_dict_torch_update_local)
  127. # speaker encoder
  128. if model.speaker_encoder is not None:
  129. var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  130. var_dict_torch_update.update(var_dict_torch_update_local)
  131. # cd scorer
  132. if model.cd_scorer is not None:
  133. var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
  134. var_dict_torch_update.update(var_dict_torch_update_local)
  135. # ci scorer
  136. if model.ci_scorer is not None:
  137. var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
  138. var_dict_torch_update.update(var_dict_torch_update_local)
  139. # decoder
  140. if model.decoder is not None:
  141. var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  142. var_dict_torch_update.update(var_dict_torch_update_local)
  143. elif "mode" == "sv":
  144. # speech encoder
  145. var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  146. var_dict_torch_update.update(var_dict_torch_update_local)
  147. # pooling layer
  148. var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch)
  149. var_dict_torch_update.update(var_dict_torch_update_local)
  150. # decoder
  151. var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  152. var_dict_torch_update.update(var_dict_torch_update_local)
  153. else:
  154. # encoder
  155. var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  156. var_dict_torch_update.update(var_dict_torch_update_local)
  157. # predictor
  158. var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
  159. var_dict_torch_update.update(var_dict_torch_update_local)
  160. # decoder
  161. var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  162. var_dict_torch_update.update(var_dict_torch_update_local)
  163. # bias_encoder
  164. var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
  165. var_dict_torch_update.update(var_dict_torch_update_local)
  166. return var_dict_torch_update
  167. return var_dict_torch_update
  168. def fileter_model_dict(src_dict: dict, dest_dict: dict):
  169. from collections import OrderedDict
  170. new_dict = OrderedDict()
  171. for key, value in src_dict.items():
  172. if key in dest_dict:
  173. new_dict[key] = value
  174. else:
  175. logging.info("{} is no longer needed in this model.".format(key))
  176. for key, value in dest_dict.items():
  177. if key not in new_dict:
  178. logging.warning("{} is missed in checkpoint.".format(key))
  179. return new_dict