浏览代码

Dev aky2 (#549)

* support resume model from pai

* add padding for streaming rnnt conv input

* fix large dataset training bug

* bug fix

---------

Co-authored-by: aky15 <ankeyu.aky@11.17.44.249>
aky15 2 年之前
父节点
当前提交
8c24f52fd2

+ 2 - 2
funasr/datasets/large_datasets/build_dataloader.py

@@ -73,8 +73,8 @@ class LargeDataLoader(AbsIterFactory):
             seg_dict = load_seg_dict(args.seg_dict_file)
         if hasattr(args, "punc_dict_file") and args.punc_dict_file is not None:
             punc_dict = read_symbol_table(args.punc_dict_file)
-        if hasattr(args, "bpemodel_file") and args.bpemodel_file is not None:
-            bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel_file)
+        if hasattr(args, "bpemodel") and args.bpemodel is not None:
+            bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel)
         self.dataset_conf = args.dataset_conf
         self.frontend_conf = args.frontend_conf
         logging.info("dataloader config: {}".format(self.dataset_conf))

+ 1 - 3
funasr/datasets/large_datasets/utils/tokenize.py

@@ -46,10 +46,8 @@ def tokenize(data,
     text = data["text"]
     token = []
     vad = -2
-
     if bpe_tokenizer is not None:
-        text = bpe_tokenizer.text2tokens(text)
-
+        text = bpe_tokenizer.text2tokens(" ".join(text))
     if seg_dict is not None:
         assert isinstance(seg_dict, dict)
         text = seg_tokenize(text, seg_dict)

+ 3 - 3
funasr/modules/subsampling.py

@@ -506,9 +506,9 @@ class StreamingConvInput(torch.nn.Module):
                 )
 
                 self.conv = torch.nn.Sequential(
-                    torch.nn.Conv2d(1, conv_size, 3, 2),
+                    torch.nn.Conv2d(1, conv_size, 3, 2, [1,0]),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
+                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2, [(kernel_2-1)//2, 0]),
                     torch.nn.ReLU(),
                 )
 
@@ -597,7 +597,7 @@ class StreamingConvInput(torch.nn.Module):
             mask: Mask of output sequences. (B, sub(T))
         """
         if self.subsampling_factor > 1:
-            return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
+            return mask[:, ::2][:, ::self.stride_2]
         else:
             return mask