|
|
@@ -11,6 +11,11 @@ import random
|
|
|
from typing import List, Dict
|
|
|
from copy import deepcopy
|
|
|
import json
|
|
|
+logging.basicConfig(
|
|
|
+ level="INFO",
|
|
|
+ format=f"[{os.uname()[1].split('.')[0]}]"
|
|
|
+ f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
|
|
+)
|
|
|
|
|
|
|
|
|
class MyRunner(MultiProcessRunnerV3):
|
|
|
@@ -28,24 +33,20 @@ class MyRunner(MultiProcessRunnerV3):
|
|
|
parser.add_argument("--embedding_dim", type=int, default=None)
|
|
|
parser.add_argument("--average_emb_num", type=int, default=0)
|
|
|
parser.add_argument("--subset", type=int, default=0)
|
|
|
- parser.add_argument("--data_dict", type=str, default=None)
|
|
|
+ parser.add_argument("--data_json", type=str, default=None)
|
|
|
+ parser.add_argument("--seed", type=int, default=1234)
|
|
|
+ parser.add_argument("--log_interval", type=int, default=100)
|
|
|
args = parser.parse_args()
|
|
|
+ random.seed(args.seed)
|
|
|
+ np.random.seed(args.seed)
|
|
|
|
|
|
- if not os.path.exists(args.out_dir):
|
|
|
- os.makedirs(args.out_dir)
|
|
|
-
|
|
|
- args.chunk_size = int(args.chunk_size / args.frame_shift)
|
|
|
- args.chunk_shift = int(args.chunk_shift / args.frame_shift)
|
|
|
-
|
|
|
- if not os.path.exists(args.data_dict):
|
|
|
+ logging.info("Loading data...")
|
|
|
+ if not os.path.exists(args.data_json):
|
|
|
label_list = load_scp_as_list(args.label_scp)
|
|
|
wav_scp = load_scp_as_dict(args.wav_scp)
|
|
|
utt2spk = load_scp_as_dict(args.utt2spk)
|
|
|
utt2xvec = load_scp_as_dict(args.utt2xvec)
|
|
|
spk2meeting = load_scp_as_dict(args.spk2meeting)
|
|
|
- if args.embedding_dim is None:
|
|
|
- args.embedding_dim = kaldiio.load_mat(random.choice(utt2xvec)).shape[1]
|
|
|
- logging.info("Embedding dim is detected as {}.".format(args.embedding_dim))
|
|
|
|
|
|
meeting2spks = OrderedDict()
|
|
|
for spk, meeting in spk2meeting.items():
|
|
|
@@ -59,23 +60,37 @@ class MyRunner(MultiProcessRunnerV3):
|
|
|
spk2utts[spk] = []
|
|
|
spk2utts[spk].append(utt)
|
|
|
|
|
|
- os.makedirs(os.path.dirname(args.data_dict), exist_ok=True)
|
|
|
+ os.makedirs(os.path.dirname(args.data_json), exist_ok=True)
|
|
|
+ logging.info("Dump data...")
|
|
|
json.dump({
|
|
|
"label_list": label_list, "wav_scp": wav_scp, "utt2xvec": utt2xvec,
|
|
|
"spk2utts": spk2utts, "meeting2spks": meeting2spks
|
|
|
- }, open(args.data_dict, "wt", encoding="utf-8"), ensure_ascii=False, indent=4)
|
|
|
+ }, open(args.data_json, "wt", encoding="utf-8"), ensure_ascii=False, indent=4)
|
|
|
else:
|
|
|
- data_dict = json.load(open(args.data_dict, "rt", encoding="utf-8"))
|
|
|
+ data_dict = json.load(open(args.data_json, "rt", encoding="utf-8"))
|
|
|
label_list = data_dict["label_list"]
|
|
|
wav_scp = data_dict["wav_scp"]
|
|
|
utt2xvec = data_dict["utt2xvec"]
|
|
|
spk2utts = data_dict["spk2utts"]
|
|
|
meeting2spks = data_dict["meeting2spks"]
|
|
|
|
|
|
+ if not os.path.exists(args.out_dir):
|
|
|
+ os.makedirs(args.out_dir)
|
|
|
+
|
|
|
+ args.chunk_size = int(args.chunk_size / args.frame_shift)
|
|
|
+ args.chunk_shift = int(args.chunk_shift / args.frame_shift)
|
|
|
+
|
|
|
+ if args.embedding_dim is None:
|
|
|
+ args.embedding_dim = kaldiio.load_mat(next(iter(utt2xvec.values()))).shape[1]
|
|
|
+ logging.info("Embedding dim is detected as {}.".format(args.embedding_dim))
|
|
|
+
|
|
|
+ logging.info("Number utt: {}, Number speaker: {}, Number meetings: {}".format(
|
|
|
+ len(wav_scp), len(spk2utts), len(meeting2spks)
|
|
|
+ ))
|
|
|
return label_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args
|
|
|
|
|
|
def post(self, results_list, args):
|
|
|
- pass
|
|
|
+ logging.info("[main]: Got {} chunks.".format(sum(results_list)))
|
|
|
|
|
|
|
|
|
def simu_wav_chunk(spk, spk2utts, wav_scp, sample_length):
|
|
|
@@ -89,7 +104,7 @@ def simu_wav_chunk(spk, spk2utts, wav_scp, sample_length):
|
|
|
cur_length += len(wav)
|
|
|
concat_wav = np.concatenate(wav_list, axis=0)
|
|
|
start = random.randint(0, len(concat_wav) - sample_length)
|
|
|
- return concat_wav[start:]
|
|
|
+ return concat_wav[start: start+sample_length]
|
|
|
|
|
|
|
|
|
def calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num):
|
|
|
@@ -103,9 +118,9 @@ def calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num)
|
|
|
xvec_list = [kaldiio.load_mat(utt2xvec[utt]) for utt in utt_list]
|
|
|
else:
|
|
|
xvec_list = [kaldiio.load_mat(utt2xvec[utt]) for utt in random.sample(utt_list, average_emb_num)]
|
|
|
- # TODO: rerun the simulation
|
|
|
- xvec_list = [x / np.linalg.norm(x, axis=-1) for x in xvec_list]
|
|
|
- xvec = np.mean(np.concatenate(xvec_list, axis=0), axis=0)
|
|
|
+ xvec = np.concatenate(xvec_list, axis=0)
|
|
|
+ xvec = xvec / np.linalg.norm(xvec, axis=-1, keepdims=True)
|
|
|
+ xvec = np.mean(xvec, axis=0)
|
|
|
|
|
|
return xvec
|
|
|
|
|
|
@@ -124,7 +139,7 @@ def simu_chunk(
|
|
|
):
|
|
|
frame_length, max_spk_num = frame_label.shape
|
|
|
sample_length = sample_label.shape[0]
|
|
|
- positive_speaker_num = np.max(frame_label.sum(axis=1), axis=0)
|
|
|
+ positive_speaker_num = int(np.sum(frame_label.sum(axis=0) > 0))
|
|
|
pos_speaker_list = deepcopy(meeting2spks[random.choice(meeting_list)])
|
|
|
|
|
|
# get positive speakers
|
|
|
@@ -134,7 +149,7 @@ def simu_chunk(
|
|
|
while len(pos_speaker_list) < positive_speaker_num:
|
|
|
_spk = random.choice(all_speaker_list)
|
|
|
if _spk not in pos_speaker_list:
|
|
|
- pos_speaker_list.extend(_spk)
|
|
|
+ pos_speaker_list.append(_spk)
|
|
|
|
|
|
# get negative speakers
|
|
|
negative_speaker_num = random.randint(0, max_spk_num - positive_speaker_num)
|
|
|
@@ -142,12 +157,12 @@ def simu_chunk(
|
|
|
while len(neg_speaker_list) < negative_speaker_num:
|
|
|
_spk = random.choice(all_speaker_list)
|
|
|
if _spk not in pos_speaker_list and _spk not in neg_speaker_list:
|
|
|
- neg_speaker_list.extend(_spk)
|
|
|
+ neg_speaker_list.append(_spk)
|
|
|
neg_speaker_list.extend(["None"] * (max_spk_num - positive_speaker_num - negative_speaker_num))
|
|
|
|
|
|
random.shuffle(pos_speaker_list)
|
|
|
random.shuffle(neg_speaker_list)
|
|
|
- seperated_wav = np.zeros(frame_label.shape, dtype=np.float32)
|
|
|
+ seperated_wav = np.zeros(sample_label.shape, dtype=np.float32)
|
|
|
this_spk_list = []
|
|
|
for idx, frame_num in enumerate(frame_label.sum(axis=0)):
|
|
|
if frame_num > 0:
|
|
|
@@ -166,12 +181,13 @@ def simu_chunk(
|
|
|
shuffle_idx = list(range(max_spk_num))
|
|
|
random.shuffle(shuffle_idx)
|
|
|
this_spk_list = [this_spk_list[x] for x in shuffle_idx]
|
|
|
- seperated_wav = seperated_wav.transpose([0, 1])[shuffle_idx].transpose([0, 1])
|
|
|
- frame_label = frame_label.transpose([0, 1])[shuffle_idx].transpose([0, 1])
|
|
|
+ seperated_wav = seperated_wav.transpose()[shuffle_idx].transpose()
|
|
|
+ frame_label = frame_label.transpose()[shuffle_idx].transpose()
|
|
|
|
|
|
- # calculate profile and pse_label
|
|
|
+ # calculate profile
|
|
|
profile = [calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num)
|
|
|
for spk in this_spk_list]
|
|
|
+ profile = np.vstack(profile)
|
|
|
# pse_weights = 2 ** np.arange(max_spk_num)
|
|
|
# pse_label = np.sum(frame_label * pse_weights[np.newaxis, :], axis=1)
|
|
|
# pse_label = pse_label.astype(str).tolist()
|
|
|
@@ -181,11 +197,13 @@ def simu_chunk(
|
|
|
|
|
|
def process(task_args):
|
|
|
task_idx, task_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args = task_args
|
|
|
+ logging.info("{:02d}/{:02d}: Start simulation...".format(task_idx+1, args.nj))
|
|
|
+
|
|
|
out_path = os.path.join(args.out_dir, "wav_mix.{}".format(task_idx+1))
|
|
|
wav_mix_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
|
|
|
|
|
|
- out_path = os.path.join(args.out_dir, "wav_sep.{}".format(task_idx + 1))
|
|
|
- wav_sep_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
|
|
|
+ # out_path = os.path.join(args.out_dir, "wav_sep.{}".format(task_idx + 1))
|
|
|
+ # wav_sep_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
|
|
|
|
|
|
out_path = os.path.join(args.out_dir, "profile.{}".format(task_idx + 1))
|
|
|
profile_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
|
|
|
@@ -195,16 +213,23 @@ def process(task_args):
|
|
|
|
|
|
speaker_list, meeting_list = list(spk2utts.keys()), list(meeting2spks.keys())
|
|
|
|
|
|
- idx = 0
|
|
|
+ labels_list = []
|
|
|
+ total_chunks = 0
|
|
|
for org_mid, label_path in task_list:
|
|
|
- rand_shift = random.randint(0, int(args.chunk_shift / args.frame_shift))
|
|
|
whole_label = kaldiio.load_mat(label_path)
|
|
|
- whole_label = whole_label[rand_shift:]
|
|
|
- num_chunk = (whole_label.shape[0] - args.chunk_size) // args.chunk_shift + 1
|
|
|
+ # random offset to keep diversity
|
|
|
+ rand_shift = random.randint(0, args.chunk_shift)
|
|
|
+ num_chunk = (whole_label.shape[0] - rand_shift - args.chunk_size) // args.chunk_shift + 1
|
|
|
+ labels_list.append((org_mid, whole_label, rand_shift, num_chunk))
|
|
|
+ total_chunks += num_chunk
|
|
|
+
|
|
|
+ idx = 0
|
|
|
+ simu_chunk_count = 0
|
|
|
+ for org_mid, whole_label, rand_shift, num_chunk in labels_list:
|
|
|
for i in range(num_chunk):
|
|
|
idx = idx + 1
|
|
|
- st = int((i*args.chunk_shift) / args.frame_shift)
|
|
|
- ed = int((i*args.chunk_shift+args.chunk_size) / args.frame_shift)
|
|
|
+ st = i * args.chunk_shift + rand_shift
|
|
|
+ ed = i * args.chunk_shift + args.chunk_size + rand_shift
|
|
|
utt_id = "subset{}_part{}_{}_{:06d}_{:06d}".format(
|
|
|
args.subset + 1, task_idx + 1, org_mid, st, ed
|
|
|
)
|
|
|
@@ -215,15 +240,20 @@ def process(task_args):
|
|
|
speaker_list, meeting_list, args.embedding_dim, args.average_emb_num
|
|
|
)
|
|
|
wav_mix_writer(utt_id, mix_wav)
|
|
|
- wav_sep_writer(utt_id, seg_wav)
|
|
|
+ # wav_sep_writer(utt_id, seg_wav)
|
|
|
profile_writer(utt_id, profile)
|
|
|
label_writer(utt_id, frame_label)
|
|
|
|
|
|
+ simu_chunk_count += 1
|
|
|
+ if simu_chunk_count % args.log_interval == 0:
|
|
|
+ logging.info("{:02d}/{:02d}: Complete {}/{} simulation, {}.".format(
|
|
|
+ task_idx + 1, args.nj, simu_chunk_count, total_chunks, utt_id))
|
|
|
wav_mix_writer.close()
|
|
|
- wav_sep_writer.close()
|
|
|
+ # wav_sep_writer.close()
|
|
|
profile_writer.close()
|
|
|
label_writer.close()
|
|
|
- return None
|
|
|
+ logging.info("[{}/{}]: Simulate {} chunks.".format(task_idx+1, args.nj, simu_chunk_count))
|
|
|
+ return simu_chunk_count
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|