|
|
@@ -3,7 +3,7 @@ import logging
|
|
|
import torch
|
|
|
|
|
|
from funasr.layers.global_mvn import GlobalMVN
|
|
|
-from funasr.layers.label_aggregation import LabelAggregate
|
|
|
+from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling
|
|
|
from funasr.layers.utterance_mvn import UtteranceMVN
|
|
|
from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
|
|
|
from funasr.models.e2e_diar_sond import DiarSondModel
|
|
|
@@ -26,6 +26,8 @@ from funasr.models.frontend.wav_frontend import WavFrontendMel23
|
|
|
from funasr.models.frontend.windowing import SlidingWindow
|
|
|
from funasr.models.specaug.specaug import SpecAug
|
|
|
from funasr.models.specaug.specaug import SpecAugLFR
|
|
|
+from funasr.models.specaug.abs_profileaug import AbsProfileAug
|
|
|
+from funasr.models.specaug.profileaug import ProfileAug
|
|
|
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
|
|
|
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
|
|
|
from funasr.torch_utils.initialize import initialize
|
|
|
@@ -52,6 +54,15 @@ specaug_choices = ClassChoices(
|
|
|
default=None,
|
|
|
optional=True,
|
|
|
)
|
|
|
+profileaug_choices = ClassChoices(
|
|
|
+ name="profileaug",
|
|
|
+ classes=dict(
|
|
|
+ profileaug=ProfileAug,
|
|
|
+ ),
|
|
|
+ type_check=AbsProfileAug,
|
|
|
+ default=None,
|
|
|
+ optional=True,
|
|
|
+)
|
|
|
normalize_choices = ClassChoices(
|
|
|
"normalize",
|
|
|
classes=dict(
|
|
|
@@ -64,7 +75,8 @@ normalize_choices = ClassChoices(
|
|
|
label_aggregator_choices = ClassChoices(
|
|
|
"label_aggregator",
|
|
|
classes=dict(
|
|
|
- label_aggregator=LabelAggregate
|
|
|
+ label_aggregator=LabelAggregate,
|
|
|
+ label_aggregator_max_pool=LabelAggregateMaxPooling,
|
|
|
),
|
|
|
default=None,
|
|
|
optional=True,
|
|
|
@@ -155,6 +167,8 @@ class_choices_list = [
|
|
|
frontend_choices,
|
|
|
# --specaug and --specaug_conf
|
|
|
specaug_choices,
|
|
|
+ # --profileaug and --profileaug_conf
|
|
|
+ profileaug_choices,
|
|
|
# --normalize and --normalize_conf
|
|
|
normalize_choices,
|
|
|
# --label_aggregator and --label_aggregator_conf
|
|
|
@@ -217,6 +231,13 @@ def build_diar_model(args):
|
|
|
else:
|
|
|
specaug = None
|
|
|
|
|
|
+ # Data augmentation for Profiles
|
|
|
+ if hasattr(args, "profileaug") and args.profileaug is not None:
|
|
|
+ profileaug_class = profileaug_choices.get_class(args.profileaug)
|
|
|
+ profileaug = profileaug_class(**args.profileaug_conf)
|
|
|
+ else:
|
|
|
+ profileaug = None
|
|
|
+
|
|
|
# normalization layer
|
|
|
if args.normalize is not None:
|
|
|
normalize_class = normalize_choices.get_class(args.normalize)
|
|
|
@@ -261,6 +282,7 @@ def build_diar_model(args):
|
|
|
vocab_size=vocab_size,
|
|
|
frontend=frontend,
|
|
|
specaug=specaug,
|
|
|
+ profileaug=profileaug,
|
|
|
normalize=normalize,
|
|
|
label_aggregator=label_aggregator,
|
|
|
encoder=encoder,
|