嘉渊 2 سال پیش
والد
کامیت
27fddb4982

+ 2 - 1
funasr/bin/sv_infer.py

@@ -50,7 +50,8 @@ class Speech2Xvector:
             model_file=sv_model_file,
             cmvn_file=None,
             device=device,
-            task_name="sv"
+            task_name="sv",
+            mode="sv",
         )
         logging.info("sv_model: {}".format(sv_model))
         logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))

+ 11 - 0
funasr/build_utils/build_args.py

@@ -81,6 +81,17 @@ def build_args(args, parser, extra_task_params):
         for class_choices in class_choices_list:
             class_choices.add_arguments(task_parser)
 
+    elif args.task_name == "sv":
+        from funasr.build_utils.build_sv_model import class_choices_list
+        for class_choices in class_choices_list:
+            class_choices.add_arguments(task_parser)
+        task_parser.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimension of the feature",
+        )
+
     else:
         raise NotImplementedError("Not supported task: {}".format(args.task_name))
 

+ 4 - 1
funasr/build_utils/build_model.py

@@ -1,9 +1,10 @@
 from funasr.build_utils.build_asr_model import build_asr_model
+from funasr.build_utils.build_diar_model import build_diar_model
 from funasr.build_utils.build_lm_model import build_lm_model
 from funasr.build_utils.build_pretrain_model import build_pretrain_model
 from funasr.build_utils.build_punc_model import build_punc_model
+from funasr.build_utils.build_sv_model import build_sv_model
 from funasr.build_utils.build_vad_model import build_vad_model
-from funasr.build_utils.build_diar_model import build_diar_model
 
 
 def build_model(args):
@@ -19,6 +20,8 @@ def build_model(args):
         model = build_vad_model(args)
     elif args.task_name == "diar":
         model = build_diar_model(args)
+    elif args.task_name == "sv":
+        model = build_sv_model(args)
     else:
         raise NotImplementedError("Not supported task: {}".format(args.task_name))
 

+ 16 - 3
funasr/build_utils/build_model_from_file.py

@@ -87,7 +87,7 @@ def convert_tf2torch(
         ckpt,
         mode,
 ):
-    assert mode == "paraformer" or mode == "uniasr" or mode == "sond"
+    assert mode == "paraformer" or mode == "uniasr" or mode == "sond" or mode == "sv"
     logging.info("start convert tf model to torch model")
     from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
     var_dict_tf = load_tf_dict(ckpt)
@@ -128,7 +128,7 @@ def convert_tf2torch(
         # bias_encoder
         var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
         var_dict_torch_update.update(var_dict_torch_update_local)
-    else:
+    elif "mode" == "sond":
         if model.encoder is not None:
             var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
             var_dict_torch_update.update(var_dict_torch_update_local)
@@ -148,9 +148,22 @@ def convert_tf2torch(
         if model.decoder is not None:
             var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
             var_dict_torch_update.update(var_dict_torch_update_local)
+    else:
+        # speech encoder
+        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # pooling layer
+        var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # decoder
+        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+
+        return var_dict_torch_update
 
     return var_dict_torch_update
 
+
 def fileter_model_dict(src_dict: dict, dest_dict: dict):
     from collections import OrderedDict
     new_dict = OrderedDict()
@@ -162,4 +175,4 @@ def fileter_model_dict(src_dict: dict, dest_dict: dict):
     for key, value in dest_dict.items():
         if key not in new_dict:
             logging.warning("{} is missed in checkpoint.".format(key))
-    return new_dict
+    return new_dict

+ 258 - 0
funasr/build_utils/build_sv_model.py

@@ -0,0 +1,258 @@
+import logging
+
+import torch
+from typeguard import check_return_type
+
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.base_model import FunASRModel
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.decoder.sv_decoder import DenseDecoder
+from funasr.models.e2e_sv import ESPnetSVModel
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
+from funasr.models.encoder.rnn_encoder import RNNEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.pooling.statistic_pooling import StatisticPooling
+from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.postencoder.hugging_face_transformers_postencoder import (
+    HuggingFaceTransformersPostEncoder,  # noqa: H301
+)
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.preencoder.linear import LinearProjection
+from funasr.models.preencoder.sinc import LightweightSincConvs
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+frontend_choices = ClassChoices(
+    name="frontend",
+    classes=dict(
+        default=DefaultFrontend,
+        sliding_window=SlidingWindow,
+        s3prl=S3prlFrontend,
+        fused=FusedFrontends,
+        wav_frontend=WavFrontend,
+    ),
+    type_check=AbsFrontend,
+    default="default",
+)
+specaug_choices = ClassChoices(
+    name="specaug",
+    classes=dict(
+        specaug=SpecAug,
+    ),
+    type_check=AbsSpecAug,
+    default=None,
+    optional=True,
+)
+normalize_choices = ClassChoices(
+    "normalize",
+    classes=dict(
+        global_mvn=GlobalMVN,
+        utterance_mvn=UtteranceMVN,
+    ),
+    type_check=AbsNormalize,
+    default=None,
+    optional=True,
+)
+model_choices = ClassChoices(
+    "model",
+    classes=dict(
+        espnet=ESPnetSVModel,
+    ),
+    type_check=FunASRModel,
+    default="espnet",
+)
+preencoder_choices = ClassChoices(
+    name="preencoder",
+    classes=dict(
+        sinc=LightweightSincConvs,
+        linear=LinearProjection,
+    ),
+    type_check=AbsPreEncoder,
+    default=None,
+    optional=True,
+)
+encoder_choices = ClassChoices(
+    "encoder",
+    classes=dict(
+        resnet34=ResNet34,
+        resnet34_sp_l2reg=ResNet34_SP_L2Reg,
+        rnn=RNNEncoder,
+    ),
+    type_check=AbsEncoder,
+    default="resnet34",
+)
+postencoder_choices = ClassChoices(
+    name="postencoder",
+    classes=dict(
+        hugging_face_transformers=HuggingFaceTransformersPostEncoder,
+    ),
+    type_check=AbsPostEncoder,
+    default=None,
+    optional=True,
+)
+pooling_choices = ClassChoices(
+    name="pooling_type",
+    classes=dict(
+        statistic=StatisticPooling,
+    ),
+    type_check=torch.nn.Module,
+    default="statistic",
+)
+decoder_choices = ClassChoices(
+    "decoder",
+    classes=dict(
+        dense=DenseDecoder,
+    ),
+    type_check=AbsDecoder,
+    default="dense",
+)
+
+class_choices_list = [
+    # --frontend and --frontend_conf
+    frontend_choices,
+    # --specaug and --specaug_conf
+    specaug_choices,
+    # --normalize and --normalize_conf
+    normalize_choices,
+    # --model and --model_conf
+    model_choices,
+    # --preencoder and --preencoder_conf
+    preencoder_choices,
+    # --encoder and --encoder_conf
+    encoder_choices,
+    # --postencoder and --postencoder_conf
+    postencoder_choices,
+    # --pooling and --pooling_conf
+    pooling_choices,
+    # --decoder and --decoder_conf
+    decoder_choices,
+]
+
+
+def build_sv_model(args):
+    # token_list
+    if isinstance(args.token_list, str):
+        with open(args.token_list, encoding="utf-8") as f:
+            token_list = [line.rstrip() for line in f]
+
+        # Overwriting token_list to keep it as "portable".
+        args.token_list = list(token_list)
+    elif isinstance(args.token_list, (tuple, list)):
+        token_list = list(args.token_list)
+    else:
+        raise RuntimeError("token_list must be str or list")
+    vocab_size = len(token_list)
+    logging.info(f"Speaker number: {vocab_size}")
+
+    # 1. frontend
+    if args.input_size is None:
+        # Extract features in the model
+        frontend_class = frontend_choices.get_class(args.frontend)
+        frontend = frontend_class(**args.frontend_conf)
+        input_size = frontend.output_size()
+    else:
+        # Give features from data-loader
+        args.frontend = None
+        args.frontend_conf = {}
+        frontend = None
+        input_size = args.input_size
+
+    # 2. Data augmentation for spectrogram
+    if args.specaug is not None:
+        specaug_class = specaug_choices.get_class(args.specaug)
+        specaug = specaug_class(**args.specaug_conf)
+    else:
+        specaug = None
+
+    # 3. Normalization layer
+    if args.normalize is not None:
+        normalize_class = normalize_choices.get_class(args.normalize)
+        normalize = normalize_class(**args.normalize_conf)
+    else:
+        normalize = None
+
+    # 4. Pre-encoder input block
+    # NOTE(kan-bayashi): Use getattr to keep the compatibility
+    if getattr(args, "preencoder", None) is not None:
+        preencoder_class = preencoder_choices.get_class(args.preencoder)
+        preencoder = preencoder_class(**args.preencoder_conf)
+        input_size = preencoder.output_size()
+    else:
+        preencoder = None
+
+    # 5. Encoder
+    encoder_class = encoder_choices.get_class(args.encoder)
+    encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+    # 6. Post-encoder block
+    # NOTE(kan-bayashi): Use getattr to keep the compatibility
+    encoder_output_size = encoder.output_size()
+    if getattr(args, "postencoder", None) is not None:
+        postencoder_class = postencoder_choices.get_class(args.postencoder)
+        postencoder = postencoder_class(
+            input_size=encoder_output_size, **args.postencoder_conf
+        )
+        encoder_output_size = postencoder.output_size()
+    else:
+        postencoder = None
+
+    # 7. Pooling layer
+    pooling_class = pooling_choices.get_class(args.pooling_type)
+    pooling_dim = (2, 3)
+    eps = 1e-12
+    if hasattr(args, "pooling_type_conf"):
+        if "pooling_dim" in args.pooling_type_conf:
+            pooling_dim = args.pooling_type_conf["pooling_dim"]
+        if "eps" in args.pooling_type_conf:
+            eps = args.pooling_type_conf["eps"]
+    pooling_layer = pooling_class(
+        pooling_dim=pooling_dim,
+        eps=eps,
+    )
+    if args.pooling_type == "statistic":
+        encoder_output_size *= 2
+
+    # 8. Decoder
+    decoder_class = decoder_choices.get_class(args.decoder)
+    decoder = decoder_class(
+        vocab_size=vocab_size,
+        encoder_output_size=encoder_output_size,
+        **args.decoder_conf,
+    )
+
+    # 7. Build model
+    try:
+        model_class = model_choices.get_class(args.model)
+    except AttributeError:
+        model_class = model_choices.get_class("espnet")
+    model = model_class(
+        vocab_size=vocab_size,
+        token_list=token_list,
+        frontend=frontend,
+        specaug=specaug,
+        normalize=normalize,
+        preencoder=preencoder,
+        encoder=encoder,
+        postencoder=postencoder,
+        pooling_layer=pooling_layer,
+        decoder=decoder,
+        **args.model_conf,
+    )
+
+    # FIXME(kamo): Should be done in model?
+    # 8. Initialize
+    if args.init is not None:
+        initialize(model, args.init)
+
+    assert check_return_type(model)
+    return model