# build_model_from_file.py
  1. import argparse
  2. import logging
  3. import os
  4. from pathlib import Path
  5. from typing import Union
  6. import torch
  7. import yaml
  8. from funasr.build_utils.build_model import build_model
  9. from funasr.models.base_model import FunASRModel
def build_model_from_file(
    config_file: Union[Path, str] = None,
    model_file: Union[Path, str] = None,
    cmvn_file: Union[Path, str] = None,
    device: str = "cpu",
    task_name: str = "asr",
    mode: str = "paraformer",
):
    """Build model from the files.

    This method is used for inference or fine-tuning.

    Args:
        config_file: The yaml file saved when training. If omitted,
            ``config.yaml`` next to ``model_file`` is used instead.
        model_file: The model file saved when training. May be a torch
            checkpoint, or a TF checkpoint / ``.bin`` file that gets
            converted (and cached as ``.pb``) on first load.
        cmvn_file: Optional CMVN statistics file; when given it overrides
            the ``cmvn_file`` entry of the loaded config.
        device: Device type, "cpu", "cuda", or "cuda:N".
        task_name: Task identifier stored on the returned args
            (e.g. "asr", "vad", "diar").
        mode: Model flavor, forwarded to ``convert_tf2torch`` for
            TF-checkpoint conversion (e.g. "paraformer", "sond").

    Returns:
        Tuple of (model, args): the loaded ``FunASRModel`` moved to
        ``device``, and the config as an ``argparse.Namespace``.

    Raises:
        RuntimeError: If the built model does not inherit ``FunASRModel``.
    """
    if config_file is None:
        assert model_file is not None, (
            "The argument 'model_file' must be provided "
            "if the argument 'config_file' is not specified."
        )
        # Fall back to the config sitting beside the checkpoint.
        config_file = Path(model_file).parent / "config.yaml"
    else:
        config_file = Path(config_file)
    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)
    if cmvn_file is not None:
        args["cmvn_file"] = cmvn_file
    args = argparse.Namespace(**args)
    args.task_name = task_name
    model = build_model(args)
    if not isinstance(model, FunASRModel):
        raise RuntimeError(
            f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
        )
    model.to(device)
    model_dict = dict()
    # Path of the cached ".pb" torch dump of a converted TF checkpoint;
    # stays None for plain torch checkpoints.
    model_name_pth = None
    if model_file is not None:
        logging.info("model_file is {}".format(model_file))
        if device == "cuda":
            # Pin to the current CUDA device so map_location is explicit.
            device = f"cuda:{torch.cuda.current_device()}"
        model_dir = os.path.dirname(model_file)
        model_name = os.path.basename(model_file)
        # TF checkpoints ("model.ckpt-*") and ".bin" files need conversion;
        # the converted state dict is cached as a ".pb" file next to them.
        if "model.ckpt-" in model_name or ".bin" in model_name:
            model_name_pth = os.path.join(model_dir, model_name.replace('.bin',
                '.pb')) if ".bin" in model_name else os.path.join(
                model_dir, "{}.pb".format(model_name))
            if os.path.exists(model_name_pth):
                # Cached conversion exists — load it directly.
                logging.info("model_file is load from pth: {}".format(model_name_pth))
                model_dict = torch.load(model_name_pth, map_location=device)
            else:
                model_dict = convert_tf2torch(model, model_file, mode)
            model.load_state_dict(model_dict)
        else:
            # Native torch checkpoint.
            model_dict = torch.load(model_file, map_location=device)
            if task_name == "diar" and mode == "sond":
                # Drop checkpoint entries the current model no longer has.
                model_dict = fileter_model_dict(model_dict, model.state_dict())
            if task_name == "vad":
                # VAD checkpoints only cover the encoder sub-module.
                model.encoder.load_state_dict(model_dict)
            else:
                model.load_state_dict(model_dict)
    if model_name_pth is not None and not os.path.exists(model_name_pth):
        # Cache the freshly converted TF weights for the next load.
        torch.save(model_dict, model_name_pth)
        logging.info("model_file is saved to pth: {}".format(model_name_pth))
    return model, args
  75. def convert_tf2torch(
  76. model,
  77. ckpt,
  78. mode,
  79. ):
  80. assert mode == "paraformer" or mode == "uniasr" or mode == "sond" or mode == "sv" or mode == "tp"
  81. logging.info("start convert tf model to torch model")
  82. from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
  83. var_dict_tf = load_tf_dict(ckpt)
  84. var_dict_torch = model.state_dict()
  85. var_dict_torch_update = dict()
  86. if mode == "uniasr":
  87. # encoder
  88. var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  89. var_dict_torch_update.update(var_dict_torch_update_local)
  90. # predictor
  91. var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
  92. var_dict_torch_update.update(var_dict_torch_update_local)
  93. # decoder
  94. var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  95. var_dict_torch_update.update(var_dict_torch_update_local)
  96. # encoder2
  97. var_dict_torch_update_local = model.encoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
  98. var_dict_torch_update.update(var_dict_torch_update_local)
  99. # predictor2
  100. var_dict_torch_update_local = model.predictor2.convert_tf2torch(var_dict_tf, var_dict_torch)
  101. var_dict_torch_update.update(var_dict_torch_update_local)
  102. # decoder2
  103. var_dict_torch_update_local = model.decoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
  104. var_dict_torch_update.update(var_dict_torch_update_local)
  105. # stride_conv
  106. var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
  107. var_dict_torch_update.update(var_dict_torch_update_local)
  108. elif mode == "paraformer":
  109. # encoder
  110. var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  111. var_dict_torch_update.update(var_dict_torch_update_local)
  112. # predictor
  113. var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
  114. var_dict_torch_update.update(var_dict_torch_update_local)
  115. # decoder
  116. var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  117. var_dict_torch_update.update(var_dict_torch_update_local)
  118. # bias_encoder
  119. var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
  120. var_dict_torch_update.update(var_dict_torch_update_local)
  121. elif "mode" == "sond":
  122. if model.encoder is not None:
  123. var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  124. var_dict_torch_update.update(var_dict_torch_update_local)
  125. # speaker encoder
  126. if model.speaker_encoder is not None:
  127. var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  128. var_dict_torch_update.update(var_dict_torch_update_local)
  129. # cd scorer
  130. if model.cd_scorer is not None:
  131. var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
  132. var_dict_torch_update.update(var_dict_torch_update_local)
  133. # ci scorer
  134. if model.ci_scorer is not None:
  135. var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
  136. var_dict_torch_update.update(var_dict_torch_update_local)
  137. # decoder
  138. if model.decoder is not None:
  139. var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  140. var_dict_torch_update.update(var_dict_torch_update_local)
  141. elif "mode" == "sv":
  142. # speech encoder
  143. var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  144. var_dict_torch_update.update(var_dict_torch_update_local)
  145. # pooling layer
  146. var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch)
  147. var_dict_torch_update.update(var_dict_torch_update_local)
  148. # decoder
  149. var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  150. var_dict_torch_update.update(var_dict_torch_update_local)
  151. else:
  152. # encoder
  153. var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  154. var_dict_torch_update.update(var_dict_torch_update_local)
  155. # predictor
  156. var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
  157. var_dict_torch_update.update(var_dict_torch_update_local)
  158. # decoder
  159. var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  160. var_dict_torch_update.update(var_dict_torch_update_local)
  161. # bias_encoder
  162. var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
  163. var_dict_torch_update.update(var_dict_torch_update_local)
  164. return var_dict_torch_update
  165. return var_dict_torch_update
  166. def fileter_model_dict(src_dict: dict, dest_dict: dict):
  167. from collections import OrderedDict
  168. new_dict = OrderedDict()
  169. for key, value in src_dict.items():
  170. if key in dest_dict:
  171. new_dict[key] = value
  172. else:
  173. logging.info("{} is no longer needed in this model.".format(key))
  174. for key, value in dest_dict.items():
  175. if key not in new_dict:
  176. logging.warning("{} is missed in checkpoint.".format(key))
  177. return new_dict