speech_asr пре 2 година
родитељ
комит
d5a80d642a

+ 0 - 21
funasr/models/encoder/abs_encoder.py

@@ -1,21 +0,0 @@
-from abc import ABC
-from abc import abstractmethod
-from typing import Optional
-from typing import Tuple
-
-import torch
-
-
-class AbsEncoder(torch.nn.Module, ABC):
-    @abstractmethod
-    def output_size(self) -> int:
-        raise NotImplementedError
-
-    @abstractmethod
-    def forward(
-        self,
-        xs_pad: torch.Tensor,
-        ilens: torch.Tensor,
-        prev_states: torch.Tensor = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        raise NotImplementedError

+ 1 - 2
funasr/models/encoder/conformer_encoder.py

@@ -14,7 +14,6 @@ from torch import nn
 from typeguard import check_argument_types
 
 from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.modules.attention import (
     MultiHeadedAttention,  # noqa: H301
     RelPositionMultiHeadedAttention,  # noqa: H301
@@ -277,7 +276,7 @@ class EncoderLayer(nn.Module):
         return x, mask
 
 
-class ConformerEncoder(AbsEncoder):
+class ConformerEncoder(torch.nn.Module):
     """Conformer encoder module.
 
     Args:

+ 1 - 2
funasr/models/encoder/data2vec_encoder.py

@@ -12,7 +12,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typeguard import check_argument_types
 
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.modules.data2vec.data_utils import compute_mask_indices
 from funasr.modules.data2vec.ema_module import EMAModule
 from funasr.modules.data2vec.grad_multiply import GradMultiply
@@ -29,7 +28,7 @@ def get_annealed_rate(start, end, curr_step, total_steps):
     return end - r * pct_remaining
 
 
-class Data2VecEncoder(AbsEncoder):
+class Data2VecEncoder(torch.nn.Module):
     def __init__(
             self,
             # for ConvFeatureExtractionModel

+ 1 - 3
funasr/models/encoder/mfcca_encoder.py

@@ -34,8 +34,6 @@ from funasr.modules.subsampling import Conv2dSubsampling6
 from funasr.modules.subsampling import Conv2dSubsampling8
 from funasr.modules.subsampling import TooShortUttError
 from funasr.modules.subsampling import check_short_utt
-from funasr.models.encoder.abs_encoder import AbsEncoder
-import pdb
 import math
 
 class ConvolutionModule(nn.Module):
@@ -108,7 +106,7 @@ class ConvolutionModule(nn.Module):
 
 
 
-class MFCCAEncoder(AbsEncoder):
+class MFCCAEncoder(torch.nn.Module):
     """Conformer encoder module.
 
     Args:

+ 1 - 2
funasr/models/encoder/resnet34_encoder.py

@@ -1,6 +1,5 @@
 import torch
 from torch.nn import functional as F
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from typing import Tuple, Optional
 from funasr.models.pooling.statistic_pooling import statistic_pooling, windowed_statistic_pooling
 from collections import OrderedDict
@@ -76,7 +75,7 @@ class BasicBlock(torch.nn.Module):
         return xs_pad, ilens
 
 
-class ResNet34(AbsEncoder):
+class ResNet34(torch.nn.Module):
     def __init__(
             self,
             input_size,

+ 1 - 2
funasr/models/encoder/rnn_encoder.py

@@ -9,10 +9,9 @@ from typeguard import check_argument_types
 from funasr.modules.nets_utils import make_pad_mask
 from funasr.modules.rnn.encoders import RNN
 from funasr.modules.rnn.encoders import RNNP
-from funasr.models.encoder.abs_encoder import AbsEncoder
 
 
-class RNNEncoder(AbsEncoder):
+class RNNEncoder(torch.nn.Module):
     """RNNEncoder class.
 
     Args:

+ 3 - 4
funasr/models/encoder/sanm_encoder.py

@@ -26,7 +26,6 @@ from funasr.modules.subsampling import Conv2dSubsampling8
 from funasr.modules.subsampling import TooShortUttError
 from funasr.modules.subsampling import check_short_utt
 from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.modules.mask import subsequent_mask, vad_mask
 
 class EncoderLayerSANM(nn.Module):
@@ -115,7 +114,7 @@ class EncoderLayerSANM(nn.Module):
 
         return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
 
-class SANMEncoder(AbsEncoder):
+class SANMEncoder(torch.nn.Module):
     """
     author: Speech Lab, Alibaba Group, China
     San-m: Memory equipped self-attention for end-to-end speech recognition
@@ -547,7 +546,7 @@ class SANMEncoder(AbsEncoder):
         return var_dict_torch_update
 
 
-class SANMEncoderChunkOpt(AbsEncoder):
+class SANMEncoderChunkOpt(torch.nn.Module):
     """
     author: Speech Lab, Alibaba Group, China
     SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
@@ -960,7 +959,7 @@ class SANMEncoderChunkOpt(AbsEncoder):
         return var_dict_torch_update
 
 
-class SANMVadEncoder(AbsEncoder):
+class SANMVadEncoder(torch.nn.Module):
     """
     author: Speech Lab, Alibaba Group, China
 

+ 1 - 2
funasr/models/encoder/transformer_encoder.py

@@ -13,7 +13,6 @@ from typeguard import check_argument_types
 import logging
 
 from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.modules.attention import MultiHeadedAttention
 from funasr.modules.embedding import PositionalEncoding
 from funasr.modules.layer_norm import LayerNorm
@@ -144,7 +143,7 @@ class EncoderLayer(nn.Module):
         return x, mask
 
 
-class TransformerEncoder(AbsEncoder):
+class TransformerEncoder(torch.nn.Module):
     """Transformer encoder module.
 
     Args:

+ 0 - 55
funasr/train/abs_espnet_model.py

@@ -1,55 +0,0 @@
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
-
-from abc import ABC
-from abc import abstractmethod
-from typing import Dict
-from typing import Tuple
-
-import torch
-
-
-class AbsESPnetModel(torch.nn.Module, ABC):
-    """The common abstract class among each tasks
-
-    "ESPnetModel" is referred to a class which inherits torch.nn.Module,
-    and makes the dnn-models forward as its member field,
-    a.k.a delegate pattern,
-    and defines "loss", "stats", and "weight" for the task.
-
-    If you intend to implement new task in ESPNet,
-    the model must inherit this class.
-    In other words, the "mediator" objects between
-    our training system and the your task class are
-    just only these three values, loss, stats, and weight.
-
-    Example:
-        >>> from funasr.tasks.abs_task import AbsTask
-        >>> class YourESPnetModel(AbsESPnetModel):
-        ...     def forward(self, input, input_lengths):
-        ...         ...
-        ...         return loss, stats, weight
-        >>> class YourTask(AbsTask):
-        ...     @classmethod
-        ...     def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.num_updates = 0
-
-    @abstractmethod
-    def forward(
-        self, **batch: torch.Tensor
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
-        raise NotImplementedError
-
-    def set_num_updates(self, num_updates):
-        self.num_updates = num_updates
-
-    def get_num_updates(self):
-        return self.num_updates