游雁 пре 3 година
родитељ
комит
8a788ad0d9

+ 1 - 0
funasr/export/export_model.py

@@ -193,6 +193,7 @@ class ModelExport:
             model, vad_infer_args = VADTask.build_model_from_file(
                 config, model_file, 'cpu'
             )
+            self.export_config["feats_dim"] = 400
         self._export(model, tag_name)
             
 

+ 2 - 2
funasr/export/models/e2e_vad.py

@@ -11,7 +11,7 @@ from funasr.export.models.encoder.fsmn_encoder import FSMN as FSMN_export
 class E2EVadModel(nn.Module):
     def __init__(self, model,
                 max_seq_len=512,
-                feats_dim=560,
+                feats_dim=400,
                 model_name='model',
                 **kwargs,):
         super(E2EVadModel, self).__init__()
@@ -31,7 +31,7 @@ class E2EVadModel(nn.Module):
                        in_cache3: torch.Tensor,
                        ):
 
-        scores, cache0, cache1, cache2, cache3 = self.encoder(feats,
+        scores, (cache0, cache1, cache2, cache3) = self.encoder(feats,
                                                               in_cache0,
                                                               in_cache1,
                                                               in_cache2,

+ 6 - 7
funasr/export/models/encoder/fsmn_encoder.py

@@ -149,8 +149,7 @@ fsmn_layers:            no. of sequential fsmn layers
 
 class FSMN(nn.Module):
     def __init__(
-            self,
-        model,
+            self, model,
     ):
         super(FSMN, self).__init__()
         
@@ -177,10 +176,10 @@ class FSMN(nn.Module):
         self.out_linear1 = model.out_linear1
         self.out_linear2 = model.out_linear2
         self.softmax = model.softmax
-
-        for i, d in enumerate(self.model.fsmn):
+        self.fsmn = model.fsmn
+        for i, d in enumerate(model.fsmn):
             if isinstance(d, BasicBlock):
-                self.model.fsmn[i] = BasicBlock_export(d)
+                self.fsmn[i] = BasicBlock_export(d)
 
     def fuse_modules(self):
         pass
@@ -202,7 +201,7 @@ class FSMN(nn.Module):
         x = self.relu(x)
         # x4 = self.fsmn(x3, in_cache)  # self.in_cache will update automatically in self.fsmn
         out_caches = list()
-        for i, d in enumerate(self.model.fsmn):
+        for i, d in enumerate(self.fsmn):
             in_cache = args[i]
             x, out_cache = d(x, in_cache)
             out_caches.append(out_cache)
@@ -210,7 +209,7 @@ class FSMN(nn.Module):
         x = self.out_linear2(x)
         x = self.softmax(x)
 
-        return x, *out_caches
+        return x, out_caches
 
 
 '''