|
|
@@ -11,18 +11,6 @@ from funasr.build_utils.build_model import build_model
|
|
|
from funasr.models.base_model import FunASRModel
|
|
|
|
|
|
|
|
|
-def load_checkpoint(checkpoint_path, use_cuda=1):
|
|
|
- if use_cuda:
|
|
|
- checkpoint = torch.load(checkpoint_path)
|
|
|
- else:
|
|
|
- checkpoint = torch.load(
|
|
|
- checkpoint_path, map_location=lambda storage, loc: storage)
|
|
|
- return checkpoint
|
|
|
-
|
|
|
-def reload_ss_for_eval(model, checkpoint_path, use_cuda=False):
|
|
|
- checkpoint = load_checkpoint(checkpoint_path, use_cuda)
|
|
|
- model.load_state_dict(checkpoint['model'], strict=False)
|
|
|
-
|
|
|
def build_model_from_file(
|
|
|
config_file: Union[Path, str] = None,
|
|
|
model_file: Union[Path, str] = None,
|
|
|
@@ -82,9 +70,8 @@ def build_model_from_file(
|
|
|
model.load_state_dict(model_dict)
|
|
|
else:
|
|
|
model_dict = torch.load(model_file, map_location=device)
|
|
|
- if task_name == 'ss':
|
|
|
- reload_ss_for_eval(model, model_file, use_cuda=True)
|
|
|
- logging.info("model is loaded from path: {}".format(model_file))
|
|
|
+ if task_name == "ss":
|
|
|
+ model_dict = model_dict['model']
|
|
|
if task_name == "diar" and mode == "sond":
|
|
|
model_dict = fileter_model_dict(model_dict, model.state_dict())
|
|
|
if task_name == "vad":
|