|
@@ -16,6 +16,8 @@ from funasr.frontends.utils.log_mel import LogMel
|
|
|
from funasr.frontends.utils.stft import Stft
|
|
from funasr.frontends.utils.stft import Stft
|
|
|
from funasr.frontends.utils.frontend import Frontend
|
|
from funasr.frontends.utils.frontend import Frontend
|
|
|
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
|
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
|
|
|
|
+from funasr.register import tables
|
|
|
|
|
+
|
|
|
|
|
|
|
|
@tables.register("frontend_classes", "DefaultFrontend")
|
|
@tables.register("frontend_classes", "DefaultFrontend")
|
|
|
class DefaultFrontend(nn.Module):
|
|
class DefaultFrontend(nn.Module):
|
|
@@ -40,6 +42,7 @@ class DefaultFrontend(nn.Module):
|
|
|
frontend_conf: Optional[dict] = None,
|
|
frontend_conf: Optional[dict] = None,
|
|
|
apply_stft: bool = True,
|
|
apply_stft: bool = True,
|
|
|
use_channel: int = None,
|
|
use_channel: int = None,
|
|
|
|
|
+ **kwargs,
|
|
|
):
|
|
):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
if isinstance(fs, str):
|
|
if isinstance(fs, str):
|