speech_asr 2 سال پیش
والد
کامیت
eac9f111b5
5فایلهای تغییر یافته به همراه128 افزوده شده و 34 حذف شده
  1. 56 1
      funasr/utils/build_args.py
  2. 1 0
      funasr/utils/build_asr_model.py
  3. 34 0
      funasr/utils/build_lm_model.py
  4. 2 0
      funasr/utils/build_model.py
  5. 35 33
      funasr/utils/build_pretrain_model.py

+ 56 - 1
funasr/utils/build_args.py

@@ -79,7 +79,62 @@ def build_args(args):
             default=None,
             help="The file path of noise scp file.",
         )
-
+    elif args.task_name == "pretrain":
+        from funasr.utils.build_pretrain_model import class_choices_list
+        for class_choices in class_choices_list:
+            # Append --<name> and --<name>_conf.
+            # e.g. --encoder and --encoder_conf
+            class_choices.add_arguments(parser)
+        parser.add_argument(
+            "--init",
+            type=lambda x: str_or_none(x.lower()),
+            default=None,
+            help="The initialization method",
+            choices=[
+                "chainer",
+                "xavier_uniform",
+                "xavier_normal",
+                "kaiming_uniform",
+                "kaiming_normal",
+                None,
+            ],
+        )
+        parser.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimension of the feature",
+        )
+        parser.add_argument(
+            "--feats_type",
+            type=str,
+            default='fbank',
+            help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)",
+        )
+        parser.add_argument(
+            "--noise_db_range",
+            type=str,
+            default="13_15",
+            help="The range of noise decibel level.",
+        )
+        parser.add_argument(
+            "--pred_masked_weight",
+            type=float,
+            default=1.0,
+            help="weight for predictive loss for masked frames",
+        )
+        parser.add_argument(
+            "--pred_nomask_weight",
+            type=float,
+            default=0.0,
+            help="weight for predictive loss for unmasked frames",
+        )
+        parser.add_argument(
+            "--loss_weights",
+            type=float,
+            default=0.0,
+            help="weights for additional loss terms (not first one)",
+        )
     else:
         raise NotImplementedError("Not supported task: {}".format(args.task_name))
 

+ 1 - 0
funasr/utils/build_asr_model.py

@@ -345,6 +345,7 @@ def build_asr_model(args):
     else:
         raise NotImplementedError("Not supported model: {}".format(args.model))
 
+    # initialize
     if args.init is not None:
         initialize(model, args.init)
 

+ 34 - 0
funasr/utils/build_lm_model.py

@@ -0,0 +1,34 @@
+from funasr.lm.abs_model import AbsLM
+from funasr.lm.seq_rnn_lm import SequentialRNNLM
+from funasr.lm.transformer_lm import TransformerLM
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+lm_choices = ClassChoices(
+    "lm",
+    classes=dict(
+        seq_rnn=SequentialRNNLM,
+        transformer=TransformerLM,
+    ),
+    type_check=AbsLM,
+    default="seq_rnn",
+)
+
+class_choices_list = [
+    # --lm and --lm_conf
+    lm_choices
+]
+
+
+def build_pretrain_model(args):
+    # token_list
+    if args.token_list is not None:
+        with open(args.token_list) as f:
+            token_list = [line.rstrip() for line in f]
+        args.token_list = list(token_list)
+        vocab_size = len(token_list)
+        logging.info(f"Vocabulary size: {vocab_size}")
+    else:
+        vocab_size = None
+
+    return model

+ 2 - 0
funasr/utils/build_model.py

@@ -7,6 +7,8 @@ def build_model(args):
         model = build_asr_model(args)
     elif args.task_name == "pretrain":
         model = build_pretrain_model(args)
+    elif args.task_name == "lm":
+        model = build_lm_model(args)
     else:
         raise NotImplementedError("Not supported task: {}".format(args.task_name))
 

+ 35 - 33
funasr/utils/build_pretrain_model.py

@@ -57,39 +57,39 @@ class_choices_list = [
 
 
 def build_pretrain_model(args):
-    if args.model_name == "data2vec":
-        # frontend
-        if args.input_size is None:
-            frontend_class = frontend_choices.get_class(args.frontend)
-            frontend = frontend_class(**args.frontend_conf)
-            input_size = frontend.output_size()
-        else:
-            args.frontend = None
-            args.frontend_conf = {}
-            frontend = None
-            input_size = args.input_size
+    # frontend
+    if args.input_size is None:
+        frontend_class = frontend_choices.get_class(args.frontend)
+        frontend = frontend_class(**args.frontend_conf)
+        input_size = frontend.output_size()
+    else:
+        args.frontend = None
+        args.frontend_conf = {}
+        frontend = None
+        input_size = args.input_size
 
-        # data augmentation for spectrogram
-        if args.specaug is not None:
-            specaug_class = specaug_choices.get_class(args.specaug)
-            specaug = specaug_class(**args.specaug_conf)
-        else:
-            specaug = None
+    # data augmentation for spectrogram
+    if args.specaug is not None:
+        specaug_class = specaug_choices.get_class(args.specaug)
+        specaug = specaug_class(**args.specaug_conf)
+    else:
+        specaug = None
 
-        # normalization layer
-        if args.normalize is not None:
-            normalize_class = normalize_choices.get_class(args.normalize)
-            normalize = normalize_class(**args.normalize_conf)
-        else:
-            normalize = None
+    # normalization layer
+    if args.normalize is not None:
+        normalize_class = normalize_choices.get_class(args.normalize)
+        normalize = normalize_class(**args.normalize_conf)
+    else:
+        normalize = None
 
-        # encoder
-        encoder_class = encoder_choices.get_class(args.encoder)
-        encoder = encoder_class(
-            input_size=input_size,
-            **args.encoder_conf,
-        )
+    # encoder
+    encoder_class = encoder_choices.get_class(args.encoder)
+    encoder = encoder_class(
+        input_size=input_size,
+        **args.encoder_conf,
+    )
 
+    if args.model_name == "data2vec":
         model_class = model_choices.get_class("data2vec")
         model = model_class(
             frontend=frontend,
@@ -97,9 +97,11 @@ def build_pretrain_model(args):
             normalize=normalize,
             encoder=encoder,
         )
+    else:
+        raise NotImplementedError("Not supported model: {}".format(args.model))
 
-        # 7. Initialize
-        if args.init is not None:
-            initialize(model, args.init)
+    # initialize
+    if args.init is not None:
+        initialize(model, args.init)
 
-        return model
+    return model