嘉渊 2 года назад
Родитель
Коммит
4b30f336ee

+ 16 - 29
funasr/bin/diar_infer.py

@@ -1,41 +1,28 @@
-# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
 #!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
 #  MIT License  (https://opensource.org/licenses/MIT)
 
 
-import argparse
 import logging
 import logging
 import os
 import os
-import sys
+from collections import OrderedDict
 from pathlib import Path
 from pathlib import Path
 from typing import Any
 from typing import Any
-from typing import List
 from typing import Optional
 from typing import Optional
-from typing import Sequence
-from typing import Tuple
 from typing import Union
 from typing import Union
 
 
-from collections import OrderedDict
 import numpy as np
 import numpy as np
-import soundfile
 import torch
 import torch
+from scipy.ndimage import median_filter
 from torch.nn import functional as F
 from torch.nn import functional as F
 from typeguard import check_argument_types
 from typeguard import check_argument_types
-from typeguard import check_return_type
 
 
-from funasr.utils.cli_utils import get_commandline_args
+from funasr.models.frontend.wav_frontend import WavFrontendMel23
 from funasr.tasks.diar import DiarTask
 from funasr.tasks.diar import DiarTask
-from funasr.tasks.diar import EENDOLADiarTask
+from funasr.build_utils.build_model_from_file import build_model_from_file
 from funasr.torch_utils.device_funcs import to_device
 from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from scipy.ndimage import median_filter
 from funasr.utils.misc import statistic_model_parameters
 from funasr.utils.misc import statistic_model_parameters
-from funasr.datasets.iterable_dataset import load_bytes
-from funasr.models.frontend.wav_frontend import WavFrontendMel23
+
 
 
 class Speech2DiarizationEEND:
 class Speech2DiarizationEEND:
     """Speech2Diarlization class
     """Speech2Diarlization class
@@ -61,10 +48,12 @@ class Speech2DiarizationEEND:
         assert check_argument_types()
         assert check_argument_types()
 
 
         # 1. Build Diarization model
         # 1. Build Diarization model
-        diar_model, diar_train_args = EENDOLADiarTask.build_model_from_file(
+        diar_model, diar_train_args = build_model_from_file(
             config_file=diar_train_config,
             config_file=diar_train_config,
             model_file=diar_model_file,
             model_file=diar_model_file,
-            device=device
+            device=device,
+            task_name="diar",
+            mode="eend-ola",
         )
         )
         frontend = None
         frontend = None
         if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None:
         if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None:
@@ -177,10 +166,12 @@ class Speech2DiarizationSOND:
         assert check_argument_types()
         assert check_argument_types()
 
 
         # TODO: 1. Build Diarization model
         # TODO: 1. Build Diarization model
-        diar_model, diar_train_args = DiarTask.build_model_from_file(
+        diar_model, diar_train_args = build_model_from_file(
             config_file=diar_train_config,
             config_file=diar_train_config,
             model_file=diar_model_file,
             model_file=diar_model_file,
-            device=device
+            device=device,
+            task_name="diar",
+            mode="sond",
         )
         )
         logging.info("diar_model: {}".format(diar_model))
         logging.info("diar_model: {}".format(diar_model))
         logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model)))
         logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model)))
@@ -248,7 +239,7 @@ class Speech2DiarizationSOND:
         ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
         ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
         logits_idx = F.upsample(
         logits_idx = F.upsample(
             logits_idx.unsqueeze(1).float(),
             logits_idx.unsqueeze(1).float(),
-            size=(ut, ),
+            size=(ut,),
             mode="nearest",
             mode="nearest",
         ).squeeze(1).long()
         ).squeeze(1).long()
         logits_idx = logits_idx[0].tolist()
         logits_idx = logits_idx[0].tolist()
@@ -268,7 +259,7 @@ class Speech2DiarizationSOND:
             if spk not in results:
             if spk not in results:
                 results[spk] = []
                 results[spk] = []
             if dur > self.dur_threshold:
             if dur > self.dur_threshold:
-                results[spk].append((st, st+dur))
+                results[spk].append((st, st + dur))
 
 
         # sort segments in start time ascending
         # sort segments in start time ascending
         for spk in results:
         for spk in results:
@@ -344,7 +335,3 @@ class Speech2DiarizationSOND:
             kwargs.update(**d.download_and_unpack(model_tag))
             kwargs.update(**d.download_and_unpack(model_tag))
 
 
         return Speech2DiarizationSOND(**kwargs)
         return Speech2DiarizationSOND(**kwargs)
-
-
-
-

+ 23 - 44
funasr/bin/diar_inference_launch.py

@@ -1,5 +1,5 @@
+# !/usr/bin/env python3
 # -*- encoding: utf-8 -*-
 # -*- encoding: utf-8 -*-
-#!/usr/bin/env python3
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
 #  MIT License  (https://opensource.org/licenses/MIT)
 
 
@@ -8,47 +8,28 @@ import argparse
 import logging
 import logging
 import os
 import os
 import sys
 import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-import argparse
-import logging
-import os
-import sys
-from pathlib import Path
-from typing import Any
 from typing import List
 from typing import List
 from typing import Optional
 from typing import Optional
 from typing import Sequence
 from typing import Sequence
 from typing import Tuple
 from typing import Tuple
 from typing import Union
 from typing import Union
 
 
-from collections import OrderedDict
 import numpy as np
 import numpy as np
 import soundfile
 import soundfile
 import torch
 import torch
-from torch.nn import functional as F
-from typeguard import check_argument_types
-from typeguard import check_return_type
 from scipy.signal import medfilt
 from scipy.signal import medfilt
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.diar import DiarTask
-from funasr.tasks.diar import EENDOLADiarTask
-from funasr.torch_utils.device_funcs import to_device
+from typeguard import check_argument_types
+
+from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND
+from funasr.datasets.iterable_dataset import load_bytes
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import config_argparse
 from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
 from funasr.utils.types import str2bool
 from funasr.utils.types import str2bool
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
 from funasr.utils.types import str_or_none
-from scipy.ndimage import median_filter
-from funasr.utils.misc import statistic_model_parameters
-from funasr.datasets.iterable_dataset import load_bytes
-from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND
+
 
 
 def inference_sond(
 def inference_sond(
         diar_train_config: str,
         diar_train_config: str,
@@ -94,7 +75,8 @@ def inference_sond(
     set_all_random_seed(seed)
     set_all_random_seed(seed)
 
 
     # 2a. Build speech2xvec [Optional]
     # 2a. Build speech2xvec [Optional]
-    if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]:
+    if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict[
+        "extract_profile"]:
         assert "sv_train_config" in param_dict, "sv_train_config must be provided param_dict."
         assert "sv_train_config" in param_dict, "sv_train_config must be provided param_dict."
         assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
         assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
         sv_train_config = param_dict["sv_train_config"]
         sv_train_config = param_dict["sv_train_config"]
@@ -139,7 +121,7 @@ def inference_sond(
         rst = []
         rst = []
         mid = uttid.rsplit("-", 1)[0]
         mid = uttid.rsplit("-", 1)[0]
         for key in results:
         for key in results:
-            results[key] = [(x[0]/100, x[1]/100) for x in results[key]]
+            results[key] = [(x[0] / 100, x[1] / 100) for x in results[key]]
         if out_format == "vad":
         if out_format == "vad":
             for spk, segs in results.items():
             for spk, segs in results.items():
                 rst.append("{} {}".format(spk, segs))
                 rst.append("{} {}".format(spk, segs))
@@ -176,7 +158,7 @@ def inference_sond(
                         example = [x.numpy() if isinstance(example[0], torch.Tensor) else x
                         example = [x.numpy() if isinstance(example[0], torch.Tensor) else x
                                    for x in example]
                                    for x in example]
                         speech = example[0]
                         speech = example[0]
-                        logging.info("Extracting profiles for {} waveforms".format(len(example)-1))
+                        logging.info("Extracting profiles for {} waveforms".format(len(example) - 1))
                         profile = [speech2xvector.calculate_embedding(x) for x in example[1:]]
                         profile = [speech2xvector.calculate_embedding(x) for x in example[1:]]
                         profile = torch.cat(profile, dim=0)
                         profile = torch.cat(profile, dim=0)
                         yield ["test{}".format(idx)], {"speech": [speech], "profile": [profile]}
                         yield ["test{}".format(idx)], {"speech": [speech], "profile": [profile]}
@@ -186,16 +168,15 @@ def inference_sond(
                 raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ")
                 raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ")
         else:
         else:
             # 3. Build data-iterator
             # 3. Build data-iterator
-            loader = DiarTask.build_streaming_iterator(
-                data_path_and_name_and_type,
+            loader = build_streaming_iterator(
+                task_name="diar",
+                preprocess_args=None,
+                data_path_and_name_and_type=data_path_and_name_and_type,
                 dtype=dtype,
                 dtype=dtype,
                 batch_size=batch_size,
                 batch_size=batch_size,
                 key_file=key_file,
                 key_file=key_file,
                 num_workers=num_workers,
                 num_workers=num_workers,
-                preprocess_fn=None,
-                collate_fn=None,
-                allow_variable_data_keys=allow_variable_data_keys,
-                inference=True,
+                use_collate_fn=False,
             )
             )
 
 
         # 7. Start for-loop
         # 7. Start for-loop
@@ -235,6 +216,7 @@ def inference_sond(
 
 
     return _forward
     return _forward
 
 
+
 def inference_eend(
 def inference_eend(
         diar_train_config: str,
         diar_train_config: str,
         diar_model_file: str,
         diar_model_file: str,
@@ -306,16 +288,14 @@ def inference_eend(
             if isinstance(raw_inputs, torch.Tensor):
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"]
             data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"]
-        loader = EENDOLADiarTask.build_streaming_iterator(
-            data_path_and_name_and_type,
+        loader = build_streaming_iterator(
+            task_name="diar",
+            preprocess_args=None,
+            data_path_and_name_and_type=data_path_and_name_and_type,
             dtype=dtype,
             dtype=dtype,
             batch_size=batch_size,
             batch_size=batch_size,
             key_file=key_file,
             key_file=key_file,
             num_workers=num_workers,
             num_workers=num_workers,
-            preprocess_fn=EENDOLADiarTask.build_preprocess_fn(speech2diar.diar_train_args, False),
-            collate_fn=EENDOLADiarTask.build_collate_fn(speech2diar.diar_train_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
         )
         )
 
 
         # 3. Start for-loop
         # 3. Start for-loop
@@ -362,8 +342,6 @@ def inference_eend(
     return _forward
     return _forward
 
 
 
 
-
-
 def inference_launch(mode, **kwargs):
 def inference_launch(mode, **kwargs):
     if mode == "sond":
     if mode == "sond":
         return inference_sond(mode=mode, **kwargs)
         return inference_sond(mode=mode, **kwargs)
@@ -386,6 +364,7 @@ def inference_launch(mode, **kwargs):
         logging.info("Unknown decoding mode: {}".format(mode))
         logging.info("Unknown decoding mode: {}".format(mode))
         return None
         return None
 
 
+
 def get_parser():
 def get_parser():
     parser = config_argparse.ArgumentParser(
     parser = config_argparse.ArgumentParser(
         description="Speaker Verification",
         description="Speaker Verification",

+ 37 - 2
funasr/build_utils/build_model_from_file.py

@@ -72,6 +72,8 @@ def build_model_from_file(
             model.load_state_dict(model_dict)
             model.load_state_dict(model_dict)
         else:
         else:
             model_dict = torch.load(model_file, map_location=device)
             model_dict = torch.load(model_file, map_location=device)
+    if task_name == "diar" and mode == "sond":
+        model_dict = fileter_model_dict(model_dict, model.state_dict())
     model.load_state_dict(model_dict)
     model.load_state_dict(model_dict)
     if model_name_pth is not None and not os.path.exists(model_name_pth):
     if model_name_pth is not None and not os.path.exists(model_name_pth):
         torch.save(model_dict, model_name_pth)
         torch.save(model_dict, model_name_pth)
@@ -85,7 +87,7 @@ def convert_tf2torch(
         ckpt,
         ckpt,
         mode,
         mode,
 ):
 ):
-    assert mode == "paraformer" or mode == "uniasr"
+    assert mode == "paraformer" or mode == "uniasr" or mode == "sond"
     logging.info("start convert tf model to torch model")
     logging.info("start convert tf model to torch model")
     from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
     from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
     var_dict_tf = load_tf_dict(ckpt)
     var_dict_tf = load_tf_dict(ckpt)
@@ -113,7 +115,7 @@ def convert_tf2torch(
         # stride_conv
         # stride_conv
         var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
         var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
         var_dict_torch_update.update(var_dict_torch_update_local)
         var_dict_torch_update.update(var_dict_torch_update_local)
-    else:
+    elif mode == "paraformer":
         # encoder
         # encoder
         var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
         var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
         var_dict_torch_update.update(var_dict_torch_update_local)
         var_dict_torch_update.update(var_dict_torch_update_local)
@@ -126,5 +128,38 @@ def convert_tf2torch(
         # bias_encoder
         # bias_encoder
         var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
         var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
         var_dict_torch_update.update(var_dict_torch_update_local)
         var_dict_torch_update.update(var_dict_torch_update_local)
+    else:
+        if model.encoder is not None:
+            var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
+        # speaker encoder
+        if model.speaker_encoder is not None:
+            var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
+        # cd scorer
+        if model.cd_scorer is not None:
+            var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
+        # ci scorer
+        if model.ci_scorer is not None:
+            var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
+        # decoder
+        if model.decoder is not None:
+            var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
 
 
     return var_dict_torch_update
     return var_dict_torch_update
+
+def fileter_model_dict(src_dict: dict, dest_dict: dict):
+    from collections import OrderedDict
+    new_dict = OrderedDict()
+    for key, value in src_dict.items():
+        if key in dest_dict:
+            new_dict[key] = value
+        else:
+            logging.info("{} is no longer needed in this model.".format(key))
+    for key, value in dest_dict.items():
+        if key not in new_dict:
+            logging.warning("{} is missed in checkpoint.".format(key))
+    return new_dict

+ 4 - 1
funasr/build_utils/build_streaming_iterator.py

@@ -17,6 +17,7 @@ def build_streaming_iterator(
         mc: bool = False,
         mc: bool = False,
         dtype: str = np.float32,
         dtype: str = np.float32,
         num_workers: int = 1,
         num_workers: int = 1,
+        use_collate_fn: bool = True,
         ngpu: int = 0,
         ngpu: int = 0,
         train: bool=False,
         train: bool=False,
 ) -> DataLoader:
 ) -> DataLoader:
@@ -30,7 +31,9 @@ def build_streaming_iterator(
         preprocess_fn = None
         preprocess_fn = None
 
 
     # collate
     # collate
-    if task_name in ["punc", "lm"]:
+    if not use_collate_fn:
+        collate_fn = None
+    elif task_name in ["punc", "lm"]:
         collate_fn = CommonCollateFn(int_pad_value=0)
         collate_fn = CommonCollateFn(int_pad_value=0)
     else:
     else:
         collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
         collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)