Quellcode durchsuchen

add en sv model

志浩 vor 3 Jahren
Ursprung
Commit
777ae05adb

+ 39 - 0
egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/infer.py

@@ -0,0 +1,39 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+import numpy as np
+
+if __name__ == '__main__':
+    inference_sv_pipline = pipeline(
+        task=Tasks.speaker_verification,
+        model='damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch'
+    )
+
+    # extract speaker embedding
+    # for url use "spk_embedding" as key
+    rec_result = inference_sv_pipline(
+        audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav')
+    enroll = rec_result["spk_embedding"]
+
+    # for local file use "spk_embedding" as key
+    rec_result = inference_sv_pipline(audio_in='sv_example_same.wav')["test1"]
+    same = rec_result["spk_embedding"]
+
+    import soundfile
+    wav = soundfile.read('sv_example_enroll.wav')[0]
+    # for raw inputs use "spk_embedding" as key
+    spk_embedding = inference_sv_pipline(audio_in=wav)["spk_embedding"]
+
+    rec_result = inference_sv_pipline(
+        audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav')
+    different = rec_result["spk_embedding"]
+
+    # calculate cosine similarity for same speaker
+    sv_threshold = 0.9465
+    same_cos = np.sum(enroll * same) / (np.linalg.norm(enroll) * np.linalg.norm(same))
+    same_cos = max(same_cos - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
+    print("Similarity:", same_cos)
+
+    # calculate cosine similarity for different speaker
+    diff_cos = np.sum(enroll * different) / (np.linalg.norm(enroll) * np.linalg.norm(different))
+    diff_cos = max(diff_cos - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
+    print("Similarity:", diff_cos)

+ 21 - 0
egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/infer_sv.py

@@ -0,0 +1,21 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+if __name__ == '__main__':
+    inference_sv_pipline = pipeline(
+        task=Tasks.speaker_verification,
+        model='speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch'
+    )
+
+    # the same speaker
+    rec_result = inference_sv_pipline(audio_in=(
+        'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav',
+        'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_same.wav'))
+    print("Similarity", rec_result["scores"])
+
+    # different speakers
+    rec_result = inference_sv_pipline(audio_in=(
+        'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav',
+        'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav'))
+
+    print("Similarity", rec_result["scores"])

+ 4 - 4
egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/infer.py

@@ -12,20 +12,20 @@ if __name__ == '__main__':
     # for url use "utt_id" as key
     rec_result = inference_sv_pipline(
         audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav')
-    enroll = rec_result["utt_id"]
+    enroll = rec_result["spk_embedding"]
 
     # for local file use "utt_id" as key
     rec_result = inference_sv_pipline(audio_in='sv_example_same.wav')["test1"]
-    same = rec_result["test1"]
+    same = rec_result["spk_embedding"]
 
     import soundfile
     wav = soundfile.read('sv_example_enroll.wav')[0]
     # for raw inputs use "utt_id" as key
-    spk_embedding = inference_sv_pipline(audio_in=wav)["utt_id"]
+    spk_embedding = inference_sv_pipline(audio_in=wav)["spk_embedding"]
 
     rec_result = inference_sv_pipline(
         audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav')
-    different = rec_result["utt_id"]
+    different = rec_result["spk_embedding"]
 
     # 对相同的说话人计算余弦相似度
     sv_threshold = 0.9465

+ 0 - 1
funasr/models/encoder/resnet34_encoder.py

@@ -387,7 +387,6 @@ class ResNet34_SP_L2Reg(AbsEncoder):
         return var_dict_torch_update
 
 
-
 class ResNet34Diar(ResNet34):
     def __init__(
             self,

+ 107 - 3
funasr/tasks/sv.py

@@ -1,14 +1,18 @@
 import argparse
 import logging
+import os
+from pathlib import Path
 from typing import Callable
 from typing import Collection
 from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Tuple
+from typing import Union
 
 import numpy as np
 import torch
+import yaml
 from typeguard import check_argument_types
 from typeguard import check_return_type
 
@@ -21,7 +25,7 @@ from funasr.models.e2e_asr import ESPnetASRModel
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.encoder.rnn_encoder import RNNEncoder
-from funasr.models.encoder.resnet34_encoder import ResNet34
+from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
 from funasr.models.pooling.statistic_pooling import StatisticPooling
 from funasr.models.decoder.sv_decoder import DenseDecoder
 from funasr.models.e2e_sv import ESPnetSVModel
@@ -103,6 +107,7 @@ encoder_choices = ClassChoices(
     "encoder",
     classes=dict(
         resnet34=ResNet34,
+        resnet34_sp_l2reg=ResNet34_SP_L2Reg,
         rnn=RNNEncoder,
     ),
     type_check=AbsEncoder,
@@ -394,9 +399,16 @@ class SVTask(AbsTask):
 
         # 7. Pooling layer
         pooling_class = pooling_choices.get_class(args.pooling_type)
+        pooling_dim = (2, 3)
+        eps = 1e-12
+        if hasattr(args, "pooling_type_conf"):
+            if "pooling_dim" in args.pooling_type_conf:
+                pooling_dim = args.pooling_type_conf["pooling_dim"]
+            if "eps" in args.pooling_type_conf:
+                eps = args.pooling_type_conf["eps"]
         pooling_layer = pooling_class(
-            pooling_dim=(2, 3),
-            eps=1e-12,
+            pooling_dim=pooling_dim,
+            eps=eps,
         )
         if args.pooling_type == "statistic":
             encoder_output_size *= 2
@@ -435,3 +447,95 @@ class SVTask(AbsTask):
 
         assert check_return_type(model)
         return model
+
+    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
+    @classmethod
+    def build_model_from_file(
+            cls,
+            config_file: Union[Path, str] = None,
+            model_file: Union[Path, str] = None,
+            cmvn_file: Union[Path, str] = None,
+            device: str = "cpu",
+    ):
+        """Build model from the files.
+
+        This method is used for inference or fine-tuning.
+
+        Args:
+            config_file: The yaml file saved when training.
+            model_file: The model file saved when training.
+            cmvn_file: The cmvn file for front-end
+            device: Device type, "cpu", "cuda", or "cuda:N".
+
+        """
+        assert check_argument_types()
+        if config_file is None:
+            assert model_file is not None, (
+                "The argument 'model_file' must be provided "
+                "if the argument 'config_file' is not specified."
+            )
+            config_file = Path(model_file).parent / "config.yaml"
+        else:
+            config_file = Path(config_file)
+
+        with config_file.open("r", encoding="utf-8") as f:
+            args = yaml.safe_load(f)
+        if cmvn_file is not None:
+            args["cmvn_file"] = cmvn_file
+        args = argparse.Namespace(**args)
+        model = cls.build_model(args)
+        if not isinstance(model, AbsESPnetModel):
+            raise RuntimeError(
+                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+            )
+        model.to(device)
+        model_dict = dict()
+        model_name_pth = None
+        if model_file is not None:
+            logging.info("model_file is {}".format(model_file))
+            if device == "cuda":
+                device = f"cuda:{torch.cuda.current_device()}"
+            model_dir = os.path.dirname(model_file)
+            model_name = os.path.basename(model_file)
+            if "model.ckpt-" in model_name or ".bin" in model_name:
+                if ".bin" in model_name:
+                    model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
+                else:
+                    model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
+                if os.path.exists(model_name_pth):
+                    logging.info("model_file is load from pth: {}".format(model_name_pth))
+                    model_dict = torch.load(model_name_pth, map_location=device)
+                else:
+                    model_dict = cls.convert_tf2torch(model, model_file)
+                model.load_state_dict(model_dict)
+            else:
+                model_dict = torch.load(model_file, map_location=device)
+        model.load_state_dict(model_dict)
+        if model_name_pth is not None and not os.path.exists(model_name_pth):
+            torch.save(model_dict, model_name_pth)
+            logging.info("model_file is saved to pth: {}".format(model_name_pth))
+
+        return model, args
+
+    @classmethod
+    def convert_tf2torch(
+            cls,
+            model,
+            ckpt,
+    ):
+        logging.info("start convert tf model to torch model")
+        from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
+        var_dict_tf = load_tf_dict(ckpt)
+        var_dict_torch = model.state_dict()
+        var_dict_torch_update = dict()
+        # speech encoder
+        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # pooling layer
+        var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # decoder
+        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+
+        return var_dict_torch_update