|
|
@@ -17,6 +17,7 @@ from typing import Union
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
from typeguard import check_argument_types
|
|
|
+from scipy.signal import medfilt
|
|
|
|
|
|
from funasr.models.frontend.wav_frontend import WavFrontendMel23
|
|
|
from funasr.tasks.diar import EENDOLADiarTask
|
|
|
@@ -234,9 +235,22 @@ def inference_modelscope(
|
|
|
# batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
|
|
|
|
|
|
results = speech2diar(**batch)
|
|
|
+
|
|
|
+ # post process
|
|
|
+ a = medfilt(results[0], (11, 1))
|
|
|
+ rst = []
|
|
|
+ for spkid, frames in enumerate(a.T):
|
|
|
+ frames = np.pad(frames, (1, 1), 'constant')
|
|
|
+ changes, = np.where(np.diff(frames, axis=0) != 0)
|
|
|
+ fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
|
|
|
+ for s, e in zip(changes[::2], changes[1::2]):
|
|
|
+            st = s / 10.
|
|
|
+            dur = (e - s) / 10.  # RTTM field 5 is duration (tdur), not end time
|
|
|
+            rst.append(fmt.format(keys[0], st, dur, "{}_{}".format(keys[0], str(spkid))))
|
|
|
+
|
|
|
# Only supporting batch_size==1
|
|
|
- key, value = keys[0], output_results_str(results, keys[0])
|
|
|
- item = {"key": key, "value": value}
|
|
|
+ value = "\n".join(rst)
|
|
|
+ item = {"key": keys[0], "value": value}
|
|
|
result_list.append(item)
|
|
|
if output_path is not None:
|
|
|
output_writer.write(value)
|