| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171 |
import argparse
import json
import logging
import os

import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from kaldiio import WriteHelper
def compute_fbank(wav_file,
                  num_mel_bins=80,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0,
                  resample_rate=16000,
                  speed=1.0,
                  window_type="hamming"):
    """Load an audio file and return its Kaldi-style filter-bank features.

    The audio is read with ``torchaudio.load``, resampled to
    ``resample_rate`` when the file's own rate differs, optionally
    speed-perturbed through sox effects, multiplied by ``2 ** 15`` and then
    handed to ``torchaudio.compliance.kaldi.fbank``.

    Args:
        wav_file: Path (or anything ``torchaudio.load`` accepts) of the audio.
        num_mel_bins: Forwarded to ``kaldi.fbank`` as ``num_mel_bins``.
        frame_length: Forwarded to ``kaldi.fbank`` as ``frame_length``.
        frame_shift: Forwarded to ``kaldi.fbank`` as ``frame_shift``.
        dither: Forwarded to ``kaldi.fbank`` as ``dither``.
        resample_rate: Target sampling rate; also passed to ``kaldi.fbank``
            as ``sample_frequency``.
        speed: Speed-perturbation factor; ``1.0`` skips the sox step.
        window_type: Forwarded to ``kaldi.fbank`` as ``window_type``.

    Returns:
        The result of ``kaldi.fbank`` converted with ``.numpy()``.
    """
    signal, native_rate = torchaudio.load(wav_file)

    if native_rate != resample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=native_rate,
                                                   new_freq=resample_rate)
        signal = resampler(signal)

    if speed != 1.0:
        # "speed" changes the effective rate, so "rate" brings the signal
        # back to the target sampling rate afterwards.
        sox_chain = [['speed', str(speed)], ['rate', str(resample_rate)]]
        signal, _ = torchaudio.sox_effects.apply_effects_tensor(
            signal, resample_rate, sox_chain)

    # Scale by 2**15 -- presumably to move normalised float samples into the
    # 16-bit integer range Kaldi works with (TODO confirm).
    signal = signal * (1 << 15)

    fbank_options = dict(
        num_mel_bins=num_mel_bins,
        frame_length=frame_length,
        frame_shift=frame_shift,
        dither=dither,
        energy_floor=0.0,
        window_type=window_type,
        sample_frequency=resample_rate,
    )
    features = kaldi.fbank(signal, **fbank_options)
    return features.numpy()
def get_parser():
    """Build the command-line parser for the feature-extraction script.

    Required options (``--wav-lists``, ``--text-files``, ``--ark-index`` and
    ``--output-dir``) use ``default=argparse.SUPPRESS``: a default can never
    take effect for a required option, and with
    ``ArgumentDefaultsHelpFormatter`` a placeholder default would be printed
    in ``--help`` as if it were meaningful.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description="computer features",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Required inputs / outputs (no usable default, see docstring).
    parser.add_argument(
        "--wav-lists",
        "-w",
        default=argparse.SUPPRESS,
        required=True,
        type=str,
        help="input wav lists",
    )
    parser.add_argument(
        "--text-files",
        "-t",
        default=argparse.SUPPRESS,
        required=True,
        type=str,
        help="input text files",
    )

    # Optional feature-extraction settings.
    parser.add_argument(
        "--dims",
        "-d",
        default=80,
        type=int,
        help="feature dims",
    )
    parser.add_argument(
        "--max-lengths",
        "-m",
        default=1500,
        type=int,
        help="max frame numbers",
    )
    parser.add_argument(
        "--sample-frequency",
        "-s",
        default=16000,
        type=int,
        help="sample frequency",
    )
    parser.add_argument(
        "--speed-perturb",
        "-p",
        default="1.0",
        type=str,
        help="speed perturb",
    )

    parser.add_argument(
        "--ark-index",
        "-a",
        default=argparse.SUPPRESS,
        required=True,
        type=int,
        help="ark index",
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        default=argparse.SUPPRESS,
        required=True,
        type=str,
        help="output dir",
    )
    parser.add_argument(
        "--window-type",
        default="hamming",
        required=False,
        type=str,
        help="window type",
    )
    return parser
def main():
    """Extract fbank features for every utterance and write them to disk.

    Reads ``--wav-lists`` (lines of ``<id> <wav path>``) and ``--text-files``
    (lines of ``<id> <token> ...``) in lockstep, once per speed factor in
    ``--speed-perturb``, and writes under ``--output-dir``:

    * ``ark/feats.<index>.ark`` / ``ark/feats.<index>.scp``: the features,
    * ``ark/len.<index>``: ``<id> <frames>,<dims>`` per utterance,
    * ``txt/text.<index>.txt``: ``<id> <tokens>`` per utterance,
    * ``txt/len.<index>``: ``<id> <token count>`` per utterance.

    Utterances with ``--max-lengths`` frames or more are skipped. For any
    speed factor other than 1.0 the utterance id gets a ``_sp<speed>`` suffix.
    """
    args = get_parser().parse_args()

    # Paths are built by plain concatenation so the ark path recorded inside
    # the scp file keeps exactly the form of the given --output-dir.
    ark_dir = args.output_dir + "/ark"
    txt_dir = args.output_dir + "/txt"
    os.makedirs(ark_dir, exist_ok=True)
    os.makedirs(txt_dir, exist_ok=True)

    index = str(args.ark_index)
    ark_file = ark_dir + "/feats." + index + ".ark"
    scp_file = ark_dir + "/feats." + index + ".scp"
    text_file = txt_dir + "/text." + index + ".txt"
    feats_shape_file = ark_dir + "/len." + index
    text_shape_file = txt_dir + "/len." + index

    # Strip whitespace so "1.0, 0.9" does not leak a space into utterance
    # ids; empty entries (e.g. from a trailing comma) are ignored.
    speed_perturb_list = [
        item.strip() for item in args.speed_perturb.split(',') if item.strip()
    ]

    # All writers live in one ``with`` so they are flushed and closed even if
    # feature extraction fails half-way. Outputs are UTF-8 like the inputs.
    with WriteHelper('ark,scp:{},{}'.format(ark_file, scp_file)) as ark_writer, \
            open(text_file, 'w', encoding='utf-8') as text_writer, \
            open(feats_shape_file, 'w', encoding='utf-8') as feats_shape_writer, \
            open(text_shape_file, 'w', encoding='utf-8') as text_shape_writer:
        for speed in speed_perturb_list:
            speed_factor = float(speed)
            # Compare numerically so "1", "1.0" and "1.00" all mean
            # "unperturbed", matching the check inside compute_fbank.
            id_suffix = "" if speed_factor == 1.0 else "_sp" + speed

            with open(args.wav_lists, 'r', encoding='utf-8') as wavfile, \
                    open(args.text_files, 'r', encoding='utf-8') as textfile:
                for wav, text in zip(wavfile, textfile):
                    s_w = wav.strip().split()
                    wav_id = s_w[0]
                    wav_file = s_w[1]
                    s_t = text.strip().split()
                    text_id = s_t[0]
                    txt = s_t[1:]

                    # Lines are paired by position only; flag a mismatch
                    # instead of silently attaching the wrong transcript.
                    if text_id != wav_id:
                        logging.warning(
                            "id mismatch: wav list has %s, text file has %s",
                            wav_id, text_id)

                    fbank = compute_fbank(wav_file,
                                          num_mel_bins=args.dims,
                                          resample_rate=args.sample_frequency,
                                          speed=speed_factor,
                                          window_type=args.window_type)
                    feats_lens, feats_dims = fbank.shape[0], fbank.shape[1]
                    if feats_lens >= args.max_lengths:
                        continue

                    wav_id_sp = wav_id + id_suffix
                    feats_shape_writer.write(
                        wav_id_sp + " " + str(feats_lens) + "," + str(feats_dims) + '\n')
                    text_shape_writer.write(wav_id_sp + " " + str(len(txt)) + '\n')
                    text_writer.write(wav_id_sp + " " + " ".join(txt) + '\n')
                    ark_writer(wav_id_sp, fbank)
-
# Run the extraction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|