소스 검색

rename ArkDataloader

aky15 2 년 전
부모
커밋
35da78d63a
1개의 변경된 파일4개의 추가작업 그리고 19개의 파일을 삭제
  1. 4 19
      funasr/tasks/abs_task.py

+ 4 - 19
funasr/tasks/abs_task.py

@@ -1376,25 +1376,10 @@ class AbsTask(ABC):
 
             # 7. Build iterator factories
             if args.dataset_type == "large":
-                from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
-                train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
-                                                   frontend_conf=args.frontend_conf if hasattr(args,
-                                                                                               "frontend_conf") else None,
-                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
-                                                                                               "seg_dict_file") else None,
-                                                   punc_dict_file=args.punc_list if hasattr(args,
-                                                                                            "punc_list") else None,
-                                                   bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
-                                                   mode="train")
-                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
-                                                   frontend_conf=args.frontend_conf if hasattr(args,
-                                                                                               "frontend_conf") else None,
-                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
-                                                                                               "seg_dict_file") else None,
-                                                   punc_dict_file=args.punc_list if hasattr(args,
-                                                                                            "punc_list") else None,
-                                                   bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
-                                                   mode="eval")
+                from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader
+                train_iter_factory = LargeDataLoader(args, mode="train")
+                valid_iter_factory = LargeDataLoader(args, mode="eval")
+
             elif args.dataset_type == "small":
                 train_iter_factory = cls.build_iter_factory(
                     args=args,