|
|
@@ -1,4 +1,7 @@
|
|
|
+import logging
|
|
|
+
|
|
|
from funasr.lm.abs_model import AbsLM
|
|
|
+from funasr.lm.abs_model import LanguageModel
|
|
|
from funasr.lm.seq_rnn_lm import SequentialRNNLM
|
|
|
from funasr.lm.transformer_lm import TransformerLM
|
|
|
from funasr.torch_utils.initialize import initialize
|
|
|
@@ -13,10 +16,19 @@ lm_choices = ClassChoices(
|
|
|
type_check=AbsLM,
|
|
|
default="seq_rnn",
|
|
|
)
|
|
|
+model_choices = ClassChoices(
|
|
|
+ "model",
|
|
|
+ classes=dict(
|
|
|
+ lm=LanguageModel,
|
|
|
+ ),
|
|
|
+ default="lm",
|
|
|
+)
|
|
|
|
|
|
class_choices_list = [
|
|
|
# --lm and --lm_conf
|
|
|
- lm_choices
|
|
|
+ lm_choices,
|
|
|
+ # --model and --model_conf
|
|
|
+ model_choices
|
|
|
]
|
|
|
|
|
|
|
|
|
@@ -31,4 +43,15 @@ def build_lm_model(args):
|
|
|
else:
|
|
|
vocab_size = None
|
|
|
|
|
|
+ # lm
|
|
|
+ lm_class = lm_choices.get_class(args.lm)
|
|
|
+ lm = lm_class(vocab_size=vocab_size, **args.lm_conf)
|
|
|
+
|
|
|
+ model_class = model_choices.get_class(args.model)
|
|
|
+ model = model_class(lm=lm, vocab_size=vocab_size, **args.model_conf)
|
|
|
+
|
|
|
+ # initialize
|
|
|
+ if args.init is not None:
|
|
|
+ initialize(model, args.init)
|
|
|
+
|
|
|
return model
|