|
|
@@ -1,78 +1,43 @@
|
|
|
import argparse
|
|
|
import logging
|
|
|
+import os
|
|
|
+from pathlib import Path
|
|
|
from typing import Callable
|
|
|
from typing import Collection
|
|
|
from typing import Dict
|
|
|
from typing import List
|
|
|
from typing import Optional
|
|
|
from typing import Tuple
|
|
|
-import os
|
|
|
-from pathlib import Path
|
|
|
-from typing import Tuple
|
|
|
from typing import Union
|
|
|
-import yaml
|
|
|
+
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
+import yaml
|
|
|
from typeguard import check_argument_types
|
|
|
from typeguard import check_return_type
|
|
|
|
|
|
from funasr.datasets.collate_fn import CommonCollateFn
|
|
|
-from funasr.datasets.preprocessor import CommonPreprocessor
|
|
|
-from funasr.models.ctc import CTC
|
|
|
-from funasr.models.decoder.abs_decoder import AbsDecoder
|
|
|
-from funasr.models.decoder.rnn_decoder import RNNDecoder
|
|
|
-from funasr.models.decoder.transformer_decoder import (
|
|
|
- DynamicConvolution2DTransformerDecoder, # noqa: H301
|
|
|
-)
|
|
|
-from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
|
|
|
-from funasr.models.decoder.transformer_decoder import (
|
|
|
- LightweightConvolution2DTransformerDecoder, # noqa: H301
|
|
|
-)
|
|
|
-from funasr.models.decoder.transformer_decoder import (
|
|
|
- LightweightConvolutionTransformerDecoder, # noqa: H301
|
|
|
-)
|
|
|
-from funasr.models.decoder.transformer_decoder import TransformerDecoder
|
|
|
-from funasr.models.encoder.abs_encoder import AbsEncoder
|
|
|
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
|
|
|
-from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
|
|
|
-from funasr.models.encoder.rnn_encoder import RNNEncoder
|
|
|
-from funasr.models.encoder.transformer_encoder import TransformerEncoder
|
|
|
+from funasr.layers.abs_normalize import AbsNormalize
|
|
|
+from funasr.layers.global_mvn import GlobalMVN
|
|
|
+from funasr.layers.utterance_mvn import UtteranceMVN
|
|
|
+from funasr.models.e2e_vad import E2EVadModel
|
|
|
+from funasr.models.encoder.fsmn_encoder import FSMN
|
|
|
from funasr.models.frontend.abs_frontend import AbsFrontend
|
|
|
from funasr.models.frontend.default import DefaultFrontend
|
|
|
from funasr.models.frontend.fused import FusedFrontends
|
|
|
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
|
|
|
from funasr.models.frontend.s3prl import S3prlFrontend
|
|
|
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
|
|
|
from funasr.models.frontend.windowing import SlidingWindow
|
|
|
-from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
|
|
|
-from funasr.models.postencoder.hugging_face_transformers_postencoder import (
|
|
|
- HuggingFaceTransformersPostEncoder, # noqa: H301
|
|
|
-)
|
|
|
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
|
|
-from funasr.models.preencoder.linear import LinearProjection
|
|
|
-from funasr.models.preencoder.sinc import LightweightSincConvs
|
|
|
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
|
|
from funasr.models.specaug.specaug import SpecAug
|
|
|
-from funasr.layers.abs_normalize import AbsNormalize
|
|
|
-from funasr.layers.global_mvn import GlobalMVN
|
|
|
-from funasr.layers.utterance_mvn import UtteranceMVN
|
|
|
+from funasr.models.specaug.specaug import SpecAugLFR
|
|
|
from funasr.tasks.abs_task import AbsTask
|
|
|
-from funasr.text.phoneme_tokenizer import g2p_choices
|
|
|
-from funasr.train.abs_espnet_model import AbsESPnetModel
|
|
|
from funasr.train.class_choices import ClassChoices
|
|
|
from funasr.train.trainer import Trainer
|
|
|
-from funasr.utils.get_default_kwargs import get_default_kwargs
|
|
|
-from funasr.utils.nested_dict_action import NestedDictAction
|
|
|
from funasr.utils.types import float_or_none
|
|
|
from funasr.utils.types import int_or_none
|
|
|
-from funasr.utils.types import str2bool
|
|
|
from funasr.utils.types import str_or_none
|
|
|
|
|
|
-from funasr.models.specaug.specaug import SpecAugLFR
|
|
|
-from funasr.models.predictor.cif import CifPredictor, CifPredictorV2
|
|
|
-from funasr.modules.subsampling import Conv1dSubsampling
|
|
|
-from funasr.models.e2e_vad import E2EVadModel
|
|
|
-from funasr.models.encoder.fsmn_encoder import FSMN
|
|
|
-
|
|
|
frontend_choices = ClassChoices(
|
|
|
name="frontend",
|
|
|
classes=dict(
|
|
|
@@ -292,7 +257,7 @@ class VADTask(AbsTask):
|
|
|
model_class = model_choices.get_class(args.model)
|
|
|
except AttributeError:
|
|
|
model_class = model_choices.get_class("e2evad")
|
|
|
-
|
|
|
+
|
|
|
# 1. frontend
|
|
|
if args.input_size is None:
|
|
|
# Extract features in the model
|
|
|
@@ -308,7 +273,7 @@ class VADTask(AbsTask):
|
|
|
args.frontend_conf = {}
|
|
|
frontend = None
|
|
|
input_size = args.input_size
|
|
|
-
|
|
|
+
|
|
|
model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
|
|
|
|
|
|
return model
|
|
|
@@ -344,7 +309,7 @@ class VADTask(AbsTask):
|
|
|
|
|
|
with config_file.open("r", encoding="utf-8") as f:
|
|
|
args = yaml.safe_load(f)
|
|
|
- #if cmvn_file is not None:
|
|
|
+ # if cmvn_file is not None:
|
|
|
args["cmvn_file"] = cmvn_file
|
|
|
args = argparse.Namespace(**args)
|
|
|
model = cls.build_model(args)
|