|
|
@@ -13,8 +13,6 @@ from typing import Union
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
import yaml
|
|
|
-from typeguard import check_argument_types
|
|
|
-from typeguard import check_return_type
|
|
|
|
|
|
from funasr.datasets.collate_fn import DiarCollateFn
|
|
|
from funasr.datasets.preprocessor import CommonPreprocessor
|
|
|
@@ -341,7 +339,6 @@ class DiarTask(AbsTask):
|
|
|
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
|
|
|
Tuple[List[str], Dict[str, torch.Tensor]],
|
|
|
]:
|
|
|
- assert check_argument_types()
|
|
|
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
|
|
|
return DiarCollateFn(float_pad_value=0.0, int_pad_value=-1)
|
|
|
|
|
|
@@ -349,7 +346,6 @@ class DiarTask(AbsTask):
|
|
|
def build_preprocess_fn(
|
|
|
cls, args: argparse.Namespace, train: bool
|
|
|
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
|
|
|
- assert check_argument_types()
|
|
|
if args.use_preprocessor:
|
|
|
retval = CommonPreprocessor(
|
|
|
train=train,
|
|
|
@@ -379,7 +375,6 @@ class DiarTask(AbsTask):
|
|
|
)
|
|
|
else:
|
|
|
retval = None
|
|
|
- assert check_return_type(retval)
|
|
|
return retval
|
|
|
|
|
|
@classmethod
|
|
|
@@ -398,7 +393,6 @@ class DiarTask(AbsTask):
|
|
|
cls, train: bool = True, inference: bool = False
|
|
|
) -> Tuple[str, ...]:
|
|
|
retval = ()
|
|
|
- assert check_return_type(retval)
|
|
|
return retval
|
|
|
|
|
|
@classmethod
|
|
|
@@ -438,7 +432,6 @@ class DiarTask(AbsTask):
|
|
|
|
|
|
@classmethod
|
|
|
def build_model(cls, args: argparse.Namespace):
|
|
|
- assert check_argument_types()
|
|
|
if isinstance(args.token_list, str):
|
|
|
with open(args.token_list, encoding="utf-8") as f:
|
|
|
token_list = [line.rstrip() for line in f]
|
|
|
@@ -546,7 +539,6 @@ class DiarTask(AbsTask):
|
|
|
initialize(model, args.init)
|
|
|
logging.info(f"Init model parameters with {args.init}.")
|
|
|
|
|
|
- assert check_return_type(model)
|
|
|
return model
|
|
|
|
|
|
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
|
|
|
@@ -569,7 +561,6 @@ class DiarTask(AbsTask):
|
|
|
device: Device type, "cpu", "cuda", or "cuda:N".
|
|
|
|
|
|
"""
|
|
|
- assert check_argument_types()
|
|
|
if config_file is None:
|
|
|
assert model_file is not None, (
|
|
|
"The argument 'model_file' must be provided "
|