|
|
@@ -0,0 +1,137 @@
|
|
|
+import soundfile
|
|
|
+import kaldiio
|
|
|
+from tqdm import tqdm
|
|
|
+import json
|
|
|
+import os
|
|
|
+from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
|
|
|
+import numpy as np
|
|
|
+import argparse
|
|
|
+import random
|
|
|
+
|
|
|
+short_spk_list = []
|
|
|
+def calc_rand_ivc(spk, spk2utt, utt2ivc, utt2frames, total_len=3000):
|
|
|
+ all_utts = spk2utt[spk]
|
|
|
+ idx_list = list(range(len(all_utts)))
|
|
|
+ random.shuffle(idx_list)
|
|
|
+ count = 0
|
|
|
+ utt_list = []
|
|
|
+ for i in idx_list:
|
|
|
+ utt_id = all_utts[i]
|
|
|
+ utt_list.append(utt_id)
|
|
|
+ count += int(utt2frames[utt_id])
|
|
|
+ if count >= total_len:
|
|
|
+ break
|
|
|
+ if count < 300 and spk not in short_spk_list:
|
|
|
+ print("Speaker {} has only {} frames, but expect {} frames at least, use them all.".format(spk, count, 300))
|
|
|
+ short_spk_list.append(spk)
|
|
|
+
|
|
|
+ ivc_list = [kaldiio.load_mat(utt2ivc[utt])[np.newaxis, :] for utt in utt_list]
|
|
|
+ ivc = np.concatenate(ivc_list, axis=0)
|
|
|
+ ivc = np.mean(ivc, axis=0, keepdims=False)
|
|
|
+ return ivc
|
|
|
+
|
|
|
+
|
|
|
+def process(meeting_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args):
|
|
|
+ out_prefix = args.out
|
|
|
+
|
|
|
+ ivc_dim = 512
|
|
|
+ win_len, win_shift = 400, 160
|
|
|
+ label_weights = 2 ** np.array(list(range(args.n_spk)))
|
|
|
+ wav_writer = kaldiio.WriteHelper("ark,scp:{}_wav.ark,{}_wav.scp".format(out_prefix, out_prefix))
|
|
|
+ ivc_writer = kaldiio.WriteHelper("ark,scp:{}_profile.ark,{}_profile.scp".format(out_prefix, out_prefix))
|
|
|
+ label_writer = kaldiio.WriteHelper("ark,scp:{}_label.ark,{}_label.scp".format(out_prefix, out_prefix))
|
|
|
+
|
|
|
+
|
|
|
+ frames_list = []
|
|
|
+ chunk_size = int(args.chunk_size * args.sr)
|
|
|
+ chunk_shift = int(args.chunk_shift * args.sr)
|
|
|
+ for mid, meeting_wav_path in tqdm(meeting_scp, total=len(meeting_scp), ascii=True, disable=args.no_pbar):
|
|
|
+ meeting_wav, sr = soundfile.read(meeting_wav_path, dtype='float32')
|
|
|
+ num_chunk = (len(meeting_wav) - chunk_size) // chunk_shift + 1
|
|
|
+ meeting_labels = np.load(labels_scp[mid])
|
|
|
+ for i in range(num_chunk):
|
|
|
+ st, ed = i*chunk_shift, i*chunk_shift+chunk_size
|
|
|
+ seg_id = "{}-{:03d}-{:06d}-{:06d}".format(mid, i, int(st/args.sr*100), int(ed/args.sr*100))
|
|
|
+ wav_writer(seg_id, meeting_wav[st: ed])
|
|
|
+
|
|
|
+ xvec_list = []
|
|
|
+ for spk in meeting2spk_list[mid]:
|
|
|
+ spk_xvec = calc_rand_ivc(spk, spk2utt, utt2xvec, utt2frames, 1000)[np.newaxis, :]
|
|
|
+ xvec_list.append(spk_xvec)
|
|
|
+ for _ in range(args.n_spk - len(xvec_list)):
|
|
|
+ xvec_list.append(np.zeros((ivc_dim,), dtype=np.float32))
|
|
|
+ xvec = np.row_stack(xvec_list)
|
|
|
+ ivc_writer(seg_id, xvec)
|
|
|
+
|
|
|
+ wav_label = meeting_labels[st:ed, :]
|
|
|
+ frame_num = (ed-st) // win_shift
|
|
|
+ # wav_label = np.pad(wav_label, ((win_len/2, win_len/2), (0, 0)), "constant")
|
|
|
+ feat_label = np.zeros((frame_num, wav_label.shape[1]), dtype=int)
|
|
|
+ for i in range(frame_num):
|
|
|
+ frame_label = wav_label[i*win_shift: (i+1)*win_shift, :]
|
|
|
+ feat_label[i, :] = (np.sum(frame_label, axis=0) > 0).astype(int)
|
|
|
+ label_writer(seg_id, feat_label)
|
|
|
+
|
|
|
+ frames_list.append((mid, feat_label.shape[0]))
|
|
|
+ return frames_list
|
|
|
+
|
|
|
+
|
|
|
+def calc_spk_list(rttm_path):
|
|
|
+ spk_list = []
|
|
|
+ for one_line in open(rttm_path, "rt"):
|
|
|
+ parts = one_line.strip().split(" ")
|
|
|
+ mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), int(parts[7])
|
|
|
+ spk = "{}_S{:03d}".format(mid, spk)
|
|
|
+ if spk not in spk_list:
|
|
|
+ spk_list.append(spk)
|
|
|
+
|
|
|
+ return spk_list
|
|
|
+
|
|
|
+
|
|
|
+def main():
|
|
|
+ parser = argparse.ArgumentParser()
|
|
|
+ parser.add_argument("--dir", required=True, type=str, default=None,
|
|
|
+ help="feats.scp")
|
|
|
+ parser.add_argument("--out", required=True, type=str, default=None,
|
|
|
+ help="The prefix of dumpped files.")
|
|
|
+ parser.add_argument("--n_spk", type=int, default=4)
|
|
|
+ parser.add_argument("--use_lfr", default=False, action="store_true")
|
|
|
+ parser.add_argument("--no_pbar", default=False, action="store_true")
|
|
|
+ parser.add_argument("--sr", type=int, default=16000)
|
|
|
+ parser.add_argument("--chunk_size", type=int, default=16)
|
|
|
+ parser.add_argument("--chunk_shift", type=int, default=4)
|
|
|
+ args = parser.parse_args()
|
|
|
+
|
|
|
+ if not os.path.exists(os.path.dirname(args.out)):
|
|
|
+ os.makedirs(os.path.dirname(args.out))
|
|
|
+
|
|
|
+ meetings_scp = load_scp_as_list(os.path.join(args.dir, "meetings_rmsil.scp"))
|
|
|
+ labels_scp = load_scp_as_dict(os.path.join(args.dir, "labels.scp"))
|
|
|
+ rttm_scp = load_scp_as_list(os.path.join(args.dir, "rttm.scp"))
|
|
|
+ utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk"))
|
|
|
+ utt2xvec = load_scp_as_dict(os.path.join(args.dir, "utt2xvec"))
|
|
|
+ utt2wav = load_scp_as_dict(os.path.join(args.dir, "wav.scp"))
|
|
|
+ utt2frames = {}
|
|
|
+ for uttid, wav_path in utt2wav.items():
|
|
|
+ wav, sr = soundfile.read(wav_path, dtype="int16")
|
|
|
+ utt2frames[uttid] = int(len(wav) / sr * 100)
|
|
|
+
|
|
|
+ meeting2spk_list = {}
|
|
|
+ for mid, rttm_path in rttm_scp:
|
|
|
+ meeting2spk_list[mid] = calc_spk_list(rttm_path)
|
|
|
+
|
|
|
+ spk2utt = {}
|
|
|
+ for utt, spk in utt2spk.items():
|
|
|
+ if utt in utt2xvec and utt in utt2frames and int(utt2frames[utt]) > 25:
|
|
|
+ if spk not in spk2utt:
|
|
|
+ spk2utt[spk] = []
|
|
|
+ spk2utt[spk].append(utt)
|
|
|
+
|
|
|
+ # random.shuffle(feat_scp)
|
|
|
+ meeting_lens = process(meetings_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args)
|
|
|
+ total_frames = sum([x[1] for x in meeting_lens])
|
|
|
+ print("Total chunks: {:6d}, total frames: {:10d}".format(len(meeting_lens), total_frames))
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ main()
|