|
@@ -3,14 +3,15 @@ import logging
|
|
|
from funasr.models.target_delay_transformer import TargetDelayTransformer
|
|
from funasr.models.target_delay_transformer import TargetDelayTransformer
|
|
|
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
|
|
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
|
|
|
from funasr.torch_utils.initialize import initialize
|
|
from funasr.torch_utils.initialize import initialize
|
|
|
-from funasr.train.abs_model import AbsPunctuation
|
|
|
|
|
from funasr.train.abs_model import PunctuationModel
|
|
from funasr.train.abs_model import PunctuationModel
|
|
|
from funasr.train.class_choices import ClassChoices
|
|
from funasr.train.class_choices import ClassChoices
|
|
|
|
|
|
|
|
punc_choices = ClassChoices(
|
|
punc_choices = ClassChoices(
|
|
|
"punctuation",
|
|
"punctuation",
|
|
|
- classes=dict(target_delay=TargetDelayTransformer, vad_realtime=VadRealtimeTransformer),
|
|
|
|
|
- type_check=AbsPunctuation,
|
|
|
|
|
|
|
+ classes=dict(
|
|
|
|
|
+ target_delay=TargetDelayTransformer,
|
|
|
|
|
+ vad_realtime=VadRealtimeTransformer
|
|
|
|
|
+ ),
|
|
|
default="target_delay",
|
|
default="target_delay",
|
|
|
)
|
|
)
|
|
|
model_choices = ClassChoices(
|
|
model_choices = ClassChoices(
|