speech_asr 3 yıl önce
ebeveyn
işleme
e27de5aa6b

+ 12 - 7
funasr/models/e2e_diar_eend_ola.py

@@ -11,7 +11,8 @@ import torch
 import torch.nn as  nn
 from typeguard import check_argument_types
 
-from funasr.modules.eend_ola.encoder import TransformerEncoder
+from funasr.models.frontend.wav_frontend import WavFrontendMel23
+from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
 from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
 from funasr.modules.eend_ola.utils.power import generate_mapping_dict
 from funasr.torch_utils.device_funcs import force_gatherable
@@ -34,12 +35,13 @@ def pad_attractor(att, max_n_speakers):
 
 
 class DiarEENDOLAModel(AbsESPnetModel):
-    """CTC-attention hybrid Encoder-Decoder model"""
+    """EEND-OLA diarization model"""
 
     def __init__(
             self,
-            encoder: TransformerEncoder,
-            eda: EncoderDecoderAttractor,
+            frontend: WavFrontendMel23,
+            encoder: EENDOLATransformerEncoder,
+            encoder_decoder_attractor: EncoderDecoderAttractor,
             n_units: int = 256,
             max_n_speaker: int = 8,
             attractor_loss_weight: float = 1.0,
@@ -49,8 +51,9 @@ class DiarEENDOLAModel(AbsESPnetModel):
         assert check_argument_types()
 
         super().__init__()
+        self.frontend = frontend
         self.encoder = encoder
-        self.eda = eda
+        self.encoder_decoder_attractor = encoder_decoder_attractor
         self.attractor_loss_weight = attractor_loss_weight
         self.max_n_speaker = max_n_speaker
         if mapping_dict is None:
@@ -187,16 +190,18 @@ class DiarEENDOLAModel(AbsESPnetModel):
                             shuffle: bool = True,
                             threshold: float = 0.5,
                             **kwargs):
+        if self.frontend is not None:
+            speech = self.frontend(speech)
         speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
         emb = self.forward_encoder(speech, speech_lengths)
         if shuffle:
             orders = [np.arange(e.shape[0]) for e in emb]
             for order in orders:
                 np.random.shuffle(order)
-            attractors, probs = self.eda.estimate(
+            attractors, probs = self.encoder_decoder_attractor.estimate(
                 [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
         else:
-            attractors, probs = self.eda.estimate(emb)
+            attractors, probs = self.encoder_decoder_attractor.estimate(emb)
         attractors_active = []
         for p, att, e in zip(probs, attractors, emb):
             if n_speakers and n_speakers >= 0:

+ 11 - 5
funasr/modules/eend_ola/encoder.py

@@ -1,5 +1,5 @@
 import math
-import numpy as np
+
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -81,10 +81,16 @@ class PositionalEncoding(torch.nn.Module):
         return self.dropout(x)
 
 
-class TransformerEncoder(nn.Module):
-    def __init__(self, idim, n_layers, n_units,
-                 e_units=2048, h=8, dropout_rate=0.1, use_pos_emb=False):
-        super(TransformerEncoder, self).__init__()
+class EENDOLATransformerEncoder(nn.Module):
+    def __init__(self,
+                 idim: int,
+                 n_layers: int,
+                 n_units: int,
+                 e_units: int = 2048,
+                 h: int = 8,
+                 dropout_rate: float = 0.1,
+                 use_pos_emb: bool = False):
+        super(EENDOLATransformerEncoder, self).__init__()
         self.lnorm_in = nn.LayerNorm(n_units)
         self.n_layers = n_layers
         self.dropout = nn.Dropout(dropout_rate)

+ 312 - 15
funasr/tasks/diar.py

@@ -20,19 +20,18 @@ from funasr.datasets.collate_fn import CommonCollateFn
 from funasr.datasets.preprocessor import CommonPreprocessor
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.layers.global_mvn import GlobalMVN
-from funasr.layers.utterance_mvn import UtteranceMVN
 from funasr.layers.label_aggregation import LabelAggregate
-from funasr.models.ctc import CTC
-from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
-from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
-from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
-from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
-from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
-from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
+from funasr.layers.utterance_mvn import UtteranceMVN
 from funasr.models.e2e_diar_sond import DiarSondModel
 from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.encoder.conformer_encoder import ConformerEncoder
 from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
+from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
+from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
+from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
+from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
 from funasr.models.encoder.rnn_encoder import RNNEncoder
 from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
 from funasr.models.encoder.transformer_encoder import TransformerEncoder
@@ -41,17 +40,13 @@ from funasr.models.frontend.default import DefaultFrontend
 from funasr.models.frontend.fused import FusedFrontends
 from funasr.models.frontend.s3prl import S3prlFrontend
 from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.wav_frontend import WavFrontendMel23
 from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.postencoder.hugging_face_transformers_postencoder import (
-    HuggingFaceTransformersPostEncoder,  # noqa: H301
-)
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.preencoder.linear import LinearProjection
-from funasr.models.preencoder.sinc import LightweightSincConvs
 from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.models.specaug.specaug import SpecAug
 from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
+from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
 from funasr.tasks.abs_task import AbsTask
 from funasr.torch_utils.initialize import initialize
 from funasr.train.abs_espnet_model import AbsESPnetModel
@@ -70,6 +65,7 @@ frontend_choices = ClassChoices(
         s3prl=S3prlFrontend,
         fused=FusedFrontends,
         wav_frontend=WavFrontend,
+        wav_frontend_mel23=WavFrontendMel23,
     ),
     type_check=AbsFrontend,
     default="default",
@@ -126,6 +122,7 @@ encoder_choices = ClassChoices(
         sanm_chunk_opt=SANMEncoderChunkOpt,
         data2vec_encoder=Data2VecEncoder,
         ecapa_tdnn=ECAPA_TDNN,
+        eend_ola_transformer=EENDOLATransformerEncoder,
     ),
     type_check=torch.nn.Module,
     default="resnet34",
@@ -177,6 +174,15 @@ decoder_choices = ClassChoices(
     type_check=torch.nn.Module,
     default="fsmn",
 )
+# encoder_decoder_attractor is used for EEND-OLA
+encoder_decoder_attractor_choices = ClassChoices(
+    "encoder_decoder_attractor",
+    classes=dict(
+        eda=EncoderDecoderAttractor,
+    ),
+    type_check=torch.nn.Module,
+    default="eda",
+)
 
 
 class DiarTask(AbsTask):
@@ -594,3 +600,294 @@ class DiarTask(AbsTask):
             var_dict_torch_update.update(var_dict_torch_update_local)
 
         return var_dict_torch_update
+
+
+class EENDOLADiarTask(AbsTask):
+    # If you need more than 1 optimizer, change this value
+    num_optimizers: int = 1
+
+    # Add variable objects configurations
+    class_choices_list = [
+        # --frontend and --frontend_conf
+        frontend_choices,
+        # --specaug and --specaug_conf
+        model_choices,
+        # --encoder and --encoder_conf
+        encoder_choices,
+        # --speaker_encoder and --speaker_encoder_conf
+        encoder_decoder_attractor_choices,
+    ]
+
+    # If you need to modify train() or eval() procedures, change Trainer class here
+    trainer = Trainer
+
+    @classmethod
+    def add_task_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(description="Task related")
+
+        # NOTE(kamo): add_arguments(..., required=True) can't be used
+        # to provide --print_config mode. Instead of it, do as
+        # required = parser.get_default("required")
+        # required += ["token_list"]
+
+        group.add_argument(
+            "--token_list",
+            type=str_or_none,
+            default=None,
+            help="A text mapping int-id to token",
+        )
+        group.add_argument(
+            "--split_with_space",
+            type=str2bool,
+            default=True,
+            help="whether to split text using <space>",
+        )
+        group.add_argument(
+            "--seg_dict_file",
+            type=str,
+            default=None,
+            help="seg_dict_file for text processing",
+        )
+        group.add_argument(
+            "--init",
+            type=lambda x: str_or_none(x.lower()),
+            default=None,
+            help="The initialization method",
+            choices=[
+                "chainer",
+                "xavier_uniform",
+                "xavier_normal",
+                "kaiming_uniform",
+                "kaiming_normal",
+                None,
+            ],
+        )
+
+        group.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimension of the feature",
+        )
+
+        group = parser.add_argument_group(description="Preprocess related")
+        group.add_argument(
+            "--use_preprocessor",
+            type=str2bool,
+            default=True,
+            help="Apply preprocessing to data or not",
+        )
+        group.add_argument(
+            "--token_type",
+            type=str,
+            default="char",
+            choices=["char"],
+            help="The text will be tokenized in the specified level token",
+        )
+        parser.add_argument(
+            "--speech_volume_normalize",
+            type=float_or_none,
+            default=None,
+            help="Scale the maximum amplitude to the given value.",
+        )
+        parser.add_argument(
+            "--rir_scp",
+            type=str_or_none,
+            default=None,
+            help="The file path of rir scp file.",
+        )
+        parser.add_argument(
+            "--rir_apply_prob",
+            type=float,
+            default=1.0,
+            help="THe probability for applying RIR convolution.",
+        )
+        parser.add_argument(
+            "--cmvn_file",
+            type=str_or_none,
+            default=None,
+            help="The file path of noise scp file.",
+        )
+        parser.add_argument(
+            "--noise_scp",
+            type=str_or_none,
+            default=None,
+            help="The file path of noise scp file.",
+        )
+        parser.add_argument(
+            "--noise_apply_prob",
+            type=float,
+            default=1.0,
+            help="The probability applying Noise adding.",
+        )
+        parser.add_argument(
+            "--noise_db_range",
+            type=str,
+            default="13_15",
+            help="The range of noise decibel level.",
+        )
+
+        for class_choices in cls.class_choices_list:
+            # Append --<name> and --<name>_conf.
+            # e.g. --encoder and --encoder_conf
+            class_choices.add_arguments(group)
+
+    @classmethod
+    def build_collate_fn(
+            cls, args: argparse.Namespace, train: bool
+    ) -> Callable[
+        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
+        Tuple[List[str], Dict[str, torch.Tensor]],
+    ]:
+        assert check_argument_types()
+        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
+        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+
+    @classmethod
+    def build_preprocess_fn(
+            cls, args: argparse.Namespace, train: bool
+    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
+        assert check_argument_types()
+        if args.use_preprocessor:
+            retval = CommonPreprocessor(
+                train=train,
+                token_type=args.token_type,
+                token_list=args.token_list,
+                bpemodel=None,
+                non_linguistic_symbols=None,
+                text_cleaner=None,
+                g2p_type=None,
+                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
+                seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
+                # NOTE(kamo): Check attribute existence for backward compatibility
+                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+                rir_apply_prob=args.rir_apply_prob
+                if hasattr(args, "rir_apply_prob")
+                else 1.0,
+                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+                noise_apply_prob=args.noise_apply_prob
+                if hasattr(args, "noise_apply_prob")
+                else 1.0,
+                noise_db_range=args.noise_db_range
+                if hasattr(args, "noise_db_range")
+                else "13_15",
+                speech_volume_normalize=args.speech_volume_normalize
+                if hasattr(args, "rir_scp")
+                else None,
+            )
+        else:
+            retval = None
+        assert check_return_type(retval)
+        return retval
+
+    @classmethod
+    def required_data_names(
+            cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        if not inference:
+            retval = ("speech", "profile", "binary_labels")
+        else:
+            # Recognition mode
+            retval = ("speech")
+        return retval
+
+    @classmethod
+    def optional_data_names(
+            cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        retval = ()
+        assert check_return_type(retval)
+        return retval
+
+    @classmethod
+    def build_model(cls, args: argparse.Namespace):
+        assert check_argument_types()
+
+        # 1. frontend
+        if args.input_size is None or args.frontend == "wav_frontend_mel23":
+            # Extract features in the model
+            frontend_class = frontend_choices.get_class(args.frontend)
+            if args.frontend == 'wav_frontend':
+                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+            else:
+                frontend = frontend_class(**args.frontend_conf)
+            input_size = frontend.output_size()
+        else:
+            # Give features from data-loader
+            args.frontend = None
+            args.frontend_conf = {}
+            frontend = None
+            input_size = args.input_size
+
+        # 2. Encoder
+        encoder_class = encoder_choices.get_class(args.encoder)
+        encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+        # 3. EncoderDecoderAttractor
+        encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
+        encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
+
+        # 9. Build model
+        model_class = model_choices.get_class(args.model)
+        model = model_class(
+            frontend=frontend,
+            encoder=encoder,
+            encoder_decoder_attractor=encoder_decoder_attractor,
+            **args.model_conf,
+        )
+
+        # 10. Initialize
+        if args.init is not None:
+            initialize(model, args.init)
+
+        assert check_return_type(model)
+        return model
+
+    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
+    @classmethod
+    def build_model_from_file(
+            cls,
+            config_file: Union[Path, str] = None,
+            model_file: Union[Path, str] = None,
+            cmvn_file: Union[Path, str] = None,
+            device: str = "cpu",
+    ):
+        """Build model from the files.
+
+        This method is used for inference or fine-tuning.
+
+        Args:
+            config_file: The yaml file saved when training.
+            model_file: The model file saved when training.
+            cmvn_file: The cmvn file for front-end
+            device: Device type, "cpu", "cuda", or "cuda:N".
+
+        """
+        assert check_argument_types()
+        if config_file is None:
+            assert model_file is not None, (
+                "The argument 'model_file' must be provided "
+                "if the argument 'config_file' is not specified."
+            )
+            config_file = Path(model_file).parent / "config.yaml"
+        else:
+            config_file = Path(config_file)
+
+        with config_file.open("r", encoding="utf-8") as f:
+            args = yaml.safe_load(f)
+        args = argparse.Namespace(**args)
+        model = cls.build_model(args)
+        if not isinstance(model, AbsESPnetModel):
+            raise RuntimeError(
+                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+            )
+        if model_file is not None:
+            if device == "cuda":
+                device = f"cuda:{torch.cuda.current_device()}"
+            checkpoint = torch.load(model_file, map_location=device)
+            if "state_dict" in checkpoint.keys():
+                model.load_state_dict(checkpoint["state_dict"])
+            else:
+                model.load_state_dict(checkpoint)
+        model.to(device)
+        return model, args