speech_asr 3 лет назад
Родитель
Сommit
7c6ed3830a
1 измененных файлов с 16 добавлено и 2 удалено
  1. 16 2
      funasr/bin/eend_ola_inference.py

+ 16 - 2
funasr/bin/eend_ola_inference.py

@@ -17,6 +17,7 @@ from typing import Union
 import numpy as np
 import torch
 from typeguard import check_argument_types
+from scipy.signal import medfilt
 
 from funasr.models.frontend.wav_frontend import WavFrontendMel23
 from funasr.tasks.diar import EENDOLADiarTask
@@ -234,9 +235,22 @@ def inference_modelscope(
             # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
 
             results = speech2diar(**batch)
+
+            # post process
+            a = medfilt(results[0], (11, 1))
+            rst = []
+            for spkid, frames in enumerate(a.T):
+                frames = np.pad(frames, (1, 1), 'constant')
+                changes, = np.where(np.diff(frames, axis=0) != 0)
+                fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
+                for s, e in zip(changes[::2], changes[1::2]):
+                    st = s / 10.
+                    ed = e / 10.
+                    rst.append(fmt.format(keys[0], st, ed, "{}_{}".format(keys[0],str(spkid))))
+
             # Only supporting batch_size==1
-            key, value = keys[0], output_results_str(results, keys[0])
-            item = {"key": key, "value": value}
+            value = "\n".join(rst)
+            item = {"key": keys[0], "value": value}
             result_list.append(item)
             if output_path is not None:
                 output_writer.write(value)