aky15 2 年 前
コミット
77045e7bb7

+ 8 - 11
funasr/bin/asr_inference_rnnt.py

@@ -188,18 +188,15 @@ class Speech2Text:
         self.frontend = frontend
         self.window_size = self.chunk_size + self.right_context
         
-        self._ctx = self.asr_model.encoder.get_encoder_input_size(
-            self.window_size
-        )
+        if self.streaming:
+            self._ctx = self.asr_model.encoder.get_encoder_input_size(
+                self.window_size
+            )
        
-        #self.last_chunk_length = (
-        #    self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
-        #) * self.hop_length
-
-        self.last_chunk_length = (
-            self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
-        )
-        self.reset_inference_cache()
+            self.last_chunk_length = (
+                self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
+            )
+            self.reset_inference_cache()
 
     def reset_inference_cache(self) -> None:
         """Reset Speech2Text parameters."""

+ 12 - 0
funasr/models/decoder/rnnt_decoder.py

@@ -33,6 +33,7 @@ class RNNTDecoder(torch.nn.Module):
         dropout_rate: float = 0.0,
         embed_dropout_rate: float = 0.0,
         embed_pad: int = 0,
+        use_embed_mask: bool = False,
     ) -> None:
         """Construct a RNNDecoder object."""
         super().__init__()
@@ -66,6 +67,15 @@ class RNNTDecoder(torch.nn.Module):
 
         self.device = next(self.parameters()).device
         self.score_cache = {}
+
+        self.use_embed_mask = use_embed_mask
+        if self.use_embed_mask:
+            self._embed_mask = SpecAug(
+                time_mask_width_range=3,
+                num_time_mask=4,
+                apply_freq_mask=False,
+                apply_time_warp=False
+            )
     
     def forward(
         self,
@@ -88,6 +98,8 @@ class RNNTDecoder(torch.nn.Module):
             states = self.init_state(labels.size(0))
 
         dec_embed = self.dropout_embed(self.embed(labels))
+        if self.use_embed_mask and self.training:
+            dec_embed = self._embed_mask(dec_embed, label_lens)[0]
         dec_out, states = self.rnn_forward(dec_embed, states)
         return dec_out
 

+ 4 - 4
funasr/models/e2e_asr_transducer.py

@@ -12,7 +12,7 @@ from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.models.decoder.rnnt_decoder import RNNTDecoder
 from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
-from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.joint_net.joint_network import JointNetwork
 from funasr.modules.nets_utils import get_transducer_task_io
 from funasr.layers.abs_normalize import AbsNormalize
@@ -62,7 +62,7 @@ class TransducerModel(AbsESPnetModel):
         frontend: Optional[AbsFrontend],
         specaug: Optional[AbsSpecAug],
         normalize: Optional[AbsNormalize],
-        encoder: Encoder,
+        encoder: AbsEncoder,
         decoder: RNNTDecoder,
         joint_network: JointNetwork,
         att_decoder: Optional[AbsAttDecoder] = None,
@@ -286,7 +286,7 @@ class TransducerModel(AbsESPnetModel):
                 feats, feats_lengths = self.normalize(feats, feats_lengths)
 
         # 4. Forward encoder
-        encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths)
+        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
 
         assert encoder_out.size(0) == speech.size(0), (
             encoder_out.size(),
@@ -515,7 +515,7 @@ class UnifiedTransducerModel(AbsESPnetModel):
         frontend: Optional[AbsFrontend],
         specaug: Optional[AbsSpecAug],
         normalize: Optional[AbsNormalize],
-        encoder: Encoder,
+        encoder: AbsEncoder,
         decoder: RNNTDecoder,
         joint_network: JointNetwork,
         att_decoder: Optional[AbsAttDecoder] = None,

+ 2 - 2
funasr/models/encoder/conformer_encoder.py

@@ -307,7 +307,7 @@ class ChunkEncoderLayer(torch.nn.Module):
         feed_forward: torch.nn.Module,
         feed_forward_macaron: torch.nn.Module,
         conv_mod: torch.nn.Module,
-        norm_class: torch.nn.Module = torch.nn.LayerNorm,
+        norm_class: torch.nn.Module = LayerNorm,
         norm_args: Dict = {},
         dropout_rate: float = 0.0,
     ) -> None:
@@ -1145,7 +1145,7 @@ class ConformerChunkEncoder(AbsEncoder):
             x = x[:,::self.time_reduction_factor,:]
             olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
 
-        return x, olens
+        return x, olens, None
 
     def simu_chunk_forward(
         self,

+ 30 - 5
funasr/modules/nets_utils.py

@@ -485,14 +485,39 @@ def rename_state_dict(
         new_k = k.replace(old_prefix, new_prefix)
         state_dict[new_k] = v
 
-
 class Swish(torch.nn.Module):
-    """Construct an Swish object."""
+    """Swish activation definition.
+
+    Swish(x) = (beta * x) * sigmoid(x)
+                 where beta = 1 defines standard Swish activation.
+
+    References:
+        https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
+        E-swish variant: https://arxiv.org/abs/1801.07145.
 
-    def forward(self, x):
-        """Return Swich activation function."""
-        return x * torch.sigmoid(x)
+    Args:
+        beta: Beta parameter for E-Swish.
+                (beta >= 1. If beta < 1, use standard Swish).
+        use_builtin: Whether to use PyTorch function if available.
+
+    """
+
+    def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
+        super().__init__()
+
+        self.beta = beta
+
+        if beta > 1:
+            self.swish = lambda x: (self.beta * x) * torch.sigmoid(x)
+        else:
+            if use_builtin:
+                self.swish = torch.nn.SiLU()
+            else:
+                self.swish = lambda x: x * torch.sigmoid(x)
 
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward computation."""
+        return self.swish(x)
 
 def get_activation(act):
     """Return activation function."""

+ 2 - 2
funasr/modules/repeat.py

@@ -7,7 +7,7 @@
 """Repeat the same layer definition."""
 
 from typing import Dict, List, Optional
-
+from funasr.modules.layer_norm import LayerNorm
 import torch
 
 
@@ -48,7 +48,7 @@ class MultiBlocks(torch.nn.Module):
         self,
         block_list: List[torch.nn.Module],
         output_size: int,
-        norm_class: torch.nn.Module = torch.nn.LayerNorm,
+        norm_class: torch.nn.Module = LayerNorm,
     ) -> None:
         """Construct a MultiBlocks object."""
         super().__init__()

+ 1 - 1
funasr/tasks/asr.py

@@ -1682,7 +1682,7 @@ class ASRTransducerTask(AbsTask):
 
         # 7. Build model
 
-        if encoder.unified_model_training:
+        if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training:
             model = UnifiedTransducerModel(
                 vocab_size=vocab_size,
                 token_list=token_list,