|
|
@@ -16,8 +16,8 @@ from typing import Union
|
|
|
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
-from typeguard import check_argument_types
|
|
|
from scipy.signal import medfilt
|
|
|
+from typeguard import check_argument_types
|
|
|
|
|
|
from funasr.models.frontend.wav_frontend import WavFrontendMel23
|
|
|
from funasr.tasks.diar import EENDOLADiarTask
|
|
|
@@ -28,6 +28,7 @@ from funasr.utils.types import str2bool
|
|
|
from funasr.utils.types import str2triple_str
|
|
|
from funasr.utils.types import str_or_none
|
|
|
|
|
|
+
|
|
|
class Speech2Diarization:
|
|
|
"""Speech2Diarlization class
|
|
|
|
|
|
@@ -237,7 +238,7 @@ def inference_modelscope(
|
|
|
results = speech2diar(**batch)
|
|
|
|
|
|
# post process
|
|
|
- a = results[0].cpu().numpy()
|
|
|
+ a = results[0][0].cpu().numpy()
|
|
|
a = medfilt(a, (11, 1))
|
|
|
rst = []
|
|
|
for spkid, frames in enumerate(a.T):
|
|
|
@@ -246,8 +247,8 @@ def inference_modelscope(
|
|
|
fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
|
|
|
for s, e in zip(changes[::2], changes[1::2]):
|
|
|
st = s / 10.
|
|
|
- ed = e / 10.
|
|
|
- rst.append(fmt.format(keys[0], st, ed, "{}_{}".format(keys[0],str(spkid))))
|
|
|
+ dur = (e - s) / 10.
|
|
|
+ rst.append(fmt.format(keys[0], st, dur, "{}_{}".format(keys[0], str(spkid))))
|
|
|
|
|
|
# Only supporting batch_size==1
|
|
|
value = "\n".join(rst)
|