bug fix (#667)

Co-authored-by: aky15 <ankeyu.aky@11.17.44.249>
aky15 2 years ago
Parent commit: cdf117b974

+ 2 - 2
egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml

@@ -6,7 +6,7 @@ encoder_conf:
       unified_model_training: true
       default_chunk_size: 16
       jitter_range: 4
-      left_chunk_size: 0
+      left_chunk_size: 1
       embed_vgg_like: false
       subsampling_factor: 4
       linear_units: 2048
@@ -51,7 +51,7 @@ use_amp: true
 # optimization related
 accum_grad: 1
 grad_clip: 5
-max_epoch: 200
+max_epoch: 120
 val_scheduler_criterion:
     - valid
     - loss
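
The config change above bumps left_chunk_size from 0 to 1, i.e. each attention chunk now also sees one preceding chunk of left context, and shortens training to 120 epochs. A minimal sketch of how such a chunk attention mask is typically built (an illustration of the general technique, not FunASR's actual make_chunk_mask implementation):

import torch

def chunk_attention_mask(seq_len: int, chunk_size: int, left_chunks: int) -> torch.Tensor:
    # Boolean (seq_len, seq_len) mask: True where a query frame may attend
    # to a key frame. Each frame sees its own chunk plus `left_chunks`
    # preceding chunks; a sketch of the usual construction only.
    chunk_idx = torch.arange(seq_len) // chunk_size
    q = chunk_idx[:, None]   # chunk index of the query frame
    k = chunk_idx[None, :]   # chunk index of the key frame
    return (k <= q) & (k >= q - left_chunks)

# left_chunk_size: 0 -> a chunk attends only to itself;
# left_chunk_size: 1 -> one extra chunk (16 frames here) of left context.
mask = chunk_attention_mask(seq_len=64, chunk_size=16, left_chunks=1)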

+ 69 - 1
funasr/fileio/sound_scp.py

@@ -1,6 +1,6 @@
 import collections.abc
 from pathlib import Path
-from typing import Union
+from typing import List, Tuple, Union
 
 import random
 import numpy as np
@@ -13,6 +13,74 @@ import torchaudio
 
 from funasr.fileio.read_text import read_2column_text
 
+def soundfile_read(
+    wavs: Union[str, List[str]],
+    dtype=None,
+    always_2d: bool = False,
+    concat_axis: int = 1,
+    start: int = 0,
+    end: int = None,
+    return_subtype: bool = False,
+) -> Tuple[np.ndarray, int]:
+    if isinstance(wavs, str):
+        wavs = [wavs]
+
+    arrays = []
+    subtypes = []
+    prev_rate = None
+    prev_wav = None
+    for wav in wavs:
+        with soundfile.SoundFile(wav) as f:
+            f.seek(start)
+            if end is not None:
+                frames = end - start
+            else:
+                frames = -1
+            if dtype == "float16":
+                array = f.read(
+                    frames,
+                    dtype="float32",
+                    always_2d=always_2d,
+                ).astype(dtype)
+            else:
+                array = f.read(frames, dtype=dtype, always_2d=always_2d)
+            rate = f.samplerate
+            subtype = f.subtype
+            subtypes.append(subtype)
+
+        if len(wavs) > 1 and array.ndim == 1 and concat_axis == 1:
+            # array: (Time, Channel)
+            array = array[:, None]
+
+        if prev_wav is not None:
+            if prev_rate != rate:
+                raise RuntimeError(
+                    f"'{prev_wav}' and '{wav}' have mismatched sampling rate: "
+                    f"{prev_rate} != {rate}"
+                )
+
+            dim1 = arrays[0].shape[1 - concat_axis]
+            dim2 = array.shape[1 - concat_axis]
+            if dim1 != dim2:
+                raise RuntimeError(
+                    "Shapes must match along axis "
+                    f"{1 - concat_axis}, but got {dim1} and {dim2}"
+                )
+
+        prev_rate = rate
+        prev_wav = wav
+        arrays.append(array)
+
+    if len(arrays) == 1:
+        array = arrays[0]
+    else:
+        array = np.concatenate(arrays, axis=concat_axis)
+
+    if return_subtype:
+        return array, rate, subtypes
+    else:
+        return array, rate
+
 
 class SoundScpReader(collections.abc.Mapping):
     """Reader class for 'wav.scp'.

+ 50 - 5
funasr/models/encoder/conformer_encoder.py

@@ -1081,7 +1081,10 @@ class ConformerChunkEncoder(AbsEncoder):
         mask = make_source_mask(x_len).to(x.device)
 
         if self.unified_model_training:
-            chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+            if self.training:
+                chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+            else:
+                chunk_size = self.default_chunk_size
             x, mask = self.embed(x, mask, chunk_size)
             pos_enc = self.pos_enc(x)
             chunk_mask = make_chunk_mask(
@@ -1113,12 +1116,15 @@ class ConformerChunkEncoder(AbsEncoder):
 
         elif self.dynamic_chunk_training:
             max_len = x.size(1)
-            chunk_size = torch.randint(1, max_len, (1,)).item()
+            if self.training:
+                chunk_size = torch.randint(1, max_len, (1,)).item()
 
-            if chunk_size > (max_len * self.short_chunk_threshold):
-                chunk_size = max_len
+                if chunk_size > (max_len * self.short_chunk_threshold):
+                    chunk_size = max_len
+                else:
+                    chunk_size = (chunk_size % self.short_chunk_size) + 1
             else:
-                chunk_size = (chunk_size % self.short_chunk_size) + 1
+                chunk_size = self.default_chunk_size
 
             x, mask = self.embed(x, mask, chunk_size)
             pos_enc = self.pos_enc(x)
@@ -1147,6 +1153,45 @@ class ConformerChunkEncoder(AbsEncoder):
 
         return x, olens, None
 
+    def full_utt_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+    ) -> torch.Tensor:
+        """Encode input sequences.
+        Args:
+            x: Encoder input features. (B, T_in, F)
+            x_len: Encoder input features lengths. (B,)
+        Returns:
+           x_utt: Encoder outputs. (B, T_out, D_enc)
+        """
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len).to(x.device)
+        x, mask = self.embed(x, mask, None)
+        pos_enc = self.pos_enc(x)
+        x_utt = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=None,
+        )
+
+        if self.time_reduction_factor > 1:
+            x_utt = x_utt[:,::self.time_reduction_factor,:]
+        return x_utt
+
     def simu_chunk_forward(
         self,
         x: torch.Tensor,
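
A rough sketch of calling the new full_utt_forward during evaluation; encoder is assumed to be an already constructed ConformerChunkEncoder (its constructor arguments are not part of this diff), and the 80-dim feature shape is only illustrative:

import torch

feats = torch.randn(2, 200, 80)          # (B, T_in, F) acoustic features
feats_len = torch.tensor([200, 180])     # (B,) valid lengths

encoder.eval()
with torch.no_grad():
    # No chunk mask is applied here, so the whole utterance is encoded
    # with full self-attention context (non-streaming decoding path).
    x_utt = encoder.full_utt_forward(feats, feats_len)   # (B, T_out, D_enc)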

+ 11 - 10
funasr/modules/subsampling.py

@@ -427,6 +427,7 @@ class StreamingConvInput(torch.nn.Module):
         conv_size: Union[int, Tuple],
         subsampling_factor: int = 4,
         vgg_like: bool = True,
+        conv_kernel_size: int = 3,
         output_size: Optional[int] = None,
     ) -> None:
         """Construct a ConvInput object."""
@@ -436,14 +437,14 @@ class StreamingConvInput(torch.nn.Module):
                 conv_size1, conv_size2 = conv_size
 
                 self.conv = torch.nn.Sequential(
-                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
                     torch.nn.MaxPool2d((1, 2)),
-                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
                     torch.nn.MaxPool2d((1, 2)),
                 )
@@ -462,14 +463,14 @@ class StreamingConvInput(torch.nn.Module):
                 kernel_1 = int(subsampling_factor / 2)
 
                 self.conv = torch.nn.Sequential(
-                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
                     torch.nn.MaxPool2d((kernel_1, 2)),
-                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
                     torch.nn.MaxPool2d((2, 2)),
                 )
@@ -487,14 +488,14 @@ class StreamingConvInput(torch.nn.Module):
                 self.conv = torch.nn.Sequential(
                     torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
+                    torch.nn.Conv2d(conv_size, conv_size, conv_kernel_size, [1,2], [1,0]),
                     torch.nn.ReLU(),
                 )
 
                 output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
 
                 self.subsampling_factor = subsampling_factor
-                self.kernel_2 = 3
+                self.kernel_2 = conv_kernel_size
                 self.stride_2 = 1
 
                 self.create_new_mask = self.create_new_conv2d_mask
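
A construction sketch showing the new conv_kernel_size knob; the first positional argument (the input feature dimension) and the channel sizes below are assumptions, since only part of the signature appears in this hunk:

from funasr.modules.subsampling import StreamingConvInput

conv_input = StreamingConvInput(
    input_size=80,              # assumed name for the feature-dim argument
    conv_size=(64, 128),        # (conv_size1, conv_size2) for the vgg_like branch
    subsampling_factor=4,
    vgg_like=True,
    conv_kernel_size=5,         # previously hard-coded to 3
    output_size=256,
)

Note that the padding term (conv_kernel_size - 1) // 2 only preserves the frame count exactly for odd kernel sizes, and the first convolution of the non-VGG branch still uses a fixed kernel of 3 in this patch.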