aky15 3 лет назад
Родитель
Сommit
fc95956258

+ 8 - 7
funasr/models_transducer/encoder/blocks/conv_input.py

@@ -146,30 +146,31 @@ class ConvInput(torch.nn.Module):
         if mask is not None:
             mask = self.create_new_mask(mask)
             olens = max(mask.eq(0).sum(1))
-        
-        b, t_input, f = x.size()
+
+        b, t, f = x.size()
         x = x.unsqueeze(1) # (b. 1. t. f)
+
         if chunk_size is not None:
             max_input_length = int(
-                chunk_size * self.subsampling_factor * (math.ceil(float(t_input) / (chunk_size * self.subsampling_factor) ))
+                chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor) ))
             )
             x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x)
             x = list(x)
             x = torch.stack(x, dim=0)
             N_chunks = max_input_length // ( chunk_size * self.subsampling_factor)
             x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f)
+
         x = self.conv(x)
 
-        _, c, t, f = x.size()
-        
+        _, c, _, f = x.size()
         if chunk_size is not None:
             x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:]
         else:
-            x = x.transpose(1, 2).contiguous().view(b, t, c * f)
+            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)
 
         if self.output is not None:
             x = self.output(x)
-        
+
         return x, mask[:,:olens][:,:x.size(1)]
 
     def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:

+ 7 - 5
funasr/models_transducer/encoder/encoder.py

@@ -134,14 +134,11 @@ class Encoder(torch.nn.Module):
             )
 
         mask = make_source_mask(x_len)
-        if self.unified_model_training:
-            x, mask = self.embed(x, mask, self.default_chunk_size)
-        else:
-            x, mask = self.embed(x, mask)
-        pos_enc = self.pos_enc(x)
 
         if self.unified_model_training:
             chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+            x, mask = self.embed(x, mask, chunk_size)
+            pos_enc = self.pos_enc(x)
             chunk_mask = make_chunk_mask(
                 x.size(1),
                 chunk_size,
@@ -178,6 +175,9 @@ class Encoder(torch.nn.Module):
             else:
                 chunk_size = (chunk_size % self.short_chunk_size) + 1
 
+            x, mask = self.embed(x, mask, chunk_size)
+            pos_enc = self.pos_enc(x)
+
             chunk_mask = make_chunk_mask(
                 x.size(1),
                 chunk_size,
@@ -185,6 +185,8 @@ class Encoder(torch.nn.Module):
                 device=x.device,
             )
         else:
+            x, mask = self.embed(x, mask, None)
+            pos_enc = self.pos_enc(x)
             chunk_mask = None
         x = self.encoders(
             x,

+ 1 - 2
funasr/models_transducer/espnet_transducer_model_unified.py

@@ -455,8 +455,7 @@ class ESPnetASRUnifiedTransducerModel(AbsESPnetModel):
                 gather=True,
         )
 
-        #if not self.training and (self.report_cer or self.report_wer):
-        if self.report_cer or self.report_wer:
+        if not self.training and (self.report_cer or self.report_wer):
             if self.error_calculator is None:
                 self.error_calculator = ErrorCalculator(
                     self.decoder,