|
|
@@ -239,6 +239,7 @@ def build_asr_model(args):
|
|
|
vocab_size = len(token_list)
|
|
|
logging.info(f"Vocabulary size: {vocab_size}")
|
|
|
else:
|
|
|
+ token_list = None
|
|
|
vocab_size = None
|
|
|
|
|
|
# frontend
|
|
|
@@ -265,7 +266,10 @@ def build_asr_model(args):
|
|
|
# normalization layer
|
|
|
if args.normalize is not None:
|
|
|
normalize_class = normalize_choices.get_class(args.normalize)
|
|
|
- normalize = normalize_class(**args.normalize_conf)
|
|
|
+ if args.model == "mfcca":
|
|
|
+ normalize = normalize_class(stats_file=args.cmvn_file,**args.normalize_conf)
|
|
|
+ else:
|
|
|
+ normalize = normalize_class(**args.normalize_conf)
|
|
|
else:
|
|
|
normalize = None
|
|
|
|
|
|
@@ -300,7 +304,7 @@ def build_asr_model(args):
|
|
|
**args.model_conf,
|
|
|
)
|
|
|
elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer",
|
|
|
- "contextual_paraformer"]:
|
|
|
+ "contextual_paraformer", "neatcontextual_paraformer"]:
|
|
|
# predictor
|
|
|
predictor_class = predictor_choices.get_class(args.predictor)
|
|
|
predictor = predictor_class(**args.predictor_conf)
|