|
|
@@ -33,6 +33,8 @@ from funasr.utils.types import str2triple_str
|
|
|
from funasr.utils.types import str_or_none
|
|
|
from scipy.ndimage import median_filter
|
|
|
from funasr.utils.misc import statistic_model_parameters
|
|
|
+from funasr.datasets.iterable_dataset import load_bytes
|
|
|
+
|
|
|
|
|
|
class Speech2Diarization:
|
|
|
"""Speech2Xvector class
|
|
|
@@ -257,6 +259,9 @@ def inference_modelscope(
|
|
|
assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
|
|
|
sv_train_config = param_dict["sv_train_config"]
|
|
|
sv_model_file = param_dict["sv_model_file"]
|
|
|
+ if "model_dir" in param_dict:
|
|
|
+ sv_train_config = os.path.join(param_dict["model_dir"], sv_train_config)
|
|
|
+ sv_model_file = os.path.join(param_dict["model_dir"], sv_model_file)
|
|
|
from funasr.bin.sv_inference import Speech2Xvector
|
|
|
speech2xvector_kwargs = dict(
|
|
|
sv_train_config=sv_train_config,
|
|
|
@@ -320,7 +325,9 @@ def inference_modelscope(
|
|
|
def prepare_dataset():
|
|
|
for idx, example in enumerate(raw_inputs):
|
|
|
# read waveform file
|
|
|
- example = [soundfile.read(x)[0] if isinstance(example[0], str) else x
|
|
|
+ example = [load_bytes(x) if isinstance(x, bytes) else x
|
|
|
+ for x in example]
|
|
|
+ example = [soundfile.read(x)[0] if isinstance(x, str) else x
|
|
|
for x in example]
|
|
|
# convert torch tensor to numpy array
|
|
|
example = [x.numpy() if isinstance(example[0], torch.Tensor) else x
|