|
|
@@ -14,14 +14,8 @@ import torch
|
|
|
from torch.nn import functional as F
|
|
|
from typeguard import check_argument_types
|
|
|
|
|
|
-from funasr.modules.nets_utils import to_device
|
|
|
from funasr.modules.nets_utils import make_pad_mask
|
|
|
-from funasr.models.decoder.abs_decoder import AbsDecoder
|
|
|
-from funasr.models.encoder.abs_encoder import AbsEncoder
|
|
|
-from funasr.models.frontend.abs_frontend import AbsFrontend
|
|
|
-from funasr.models.specaug.abs_specaug import AbsSpecAug
|
|
|
from funasr.models.base_model import FunASRModel
|
|
|
-from funasr.layers.abs_normalize import AbsNormalize
|
|
|
from funasr.torch_utils.device_funcs import force_gatherable
|
|
|
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
|
|
|
from funasr.utils.misc import int2vec
|
|
|
@@ -43,9 +37,9 @@ class DiarSondModel(FunASRModel):
|
|
|
def __init__(
|
|
|
self,
|
|
|
vocab_size: int,
|
|
|
- frontend: Optional[AbsFrontend],
|
|
|
- specaug: Optional[AbsSpecAug],
|
|
|
- normalize: Optional[AbsNormalize],
|
|
|
+ frontend: Optional[torch.nn.Module],
|
|
|
+ specaug: Optional[torch.nn.Module],
|
|
|
+ normalize: Optional[torch.nn.Module],
|
|
|
encoder: torch.nn.Module,
|
|
|
speaker_encoder: Optional[torch.nn.Module],
|
|
|
ci_scorer: torch.nn.Module,
|
|
|
@@ -348,7 +342,7 @@ class DiarSondModel(FunASRModel):
|
|
|
cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1])
|
|
|
cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1])
|
|
|
|
|
|
- if isinstance(self.ci_scorer, AbsEncoder):
|
|
|
+ if isinstance(self.ci_scorer, torch.nn.Module):
|
|
|
ci_simi = self.ci_scorer(ge_in, ge_len)[0]
|
|
|
ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute([0, 2, 1])
|
|
|
else:
|