| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
- import argparse
- import json
- import os
- import numpy as np
- import torchaudio
- import torchaudio.compliance.kaldi as kaldi
- import yaml
def get_parser():
    """Build the command-line parser for the global CMVN statistics script.

    Returns:
        argparse.ArgumentParser: a parser exposing ``--dim`` (int, default 80),
        ``--wav_path`` (str, required), ``--config_file`` (str, optional) and
        ``--idx`` (int, required).
    """
    parser = argparse.ArgumentParser(
        description="compute global cmvn",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--dim",
        default=80,
        type=int,
        help="feature dimension",
    )
    # Required options deliberately carry no default: a default can never take
    # effect on a required option, and ArgumentDefaultsHelpFormatter would
    # advertise it in --help as if it could.
    parser.add_argument(
        "--wav_path",
        required=True,
        type=str,
        help="the path of wav scps",
    )
    parser.add_argument(
        "--config_file",
        type=str,
        help="the config file for computing cmvn",
    )
    parser.add_argument(
        "--idx",
        required=True,
        type=int,
        help="index",
    )
    return parser
def compute_fbank(wav_file,
                  num_mel_bins=80,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0,
                  resample_rate=16000,
                  speed=1.0,
                  window_type="hamming"):
    """Load an audio file and return its Kaldi-style fbank features.

    Args:
        wav_file: path handed to ``torchaudio.load``.
        num_mel_bins: forwarded to ``kaldi.fbank`` as ``num_mel_bins``.
        frame_length: forwarded to ``kaldi.fbank`` as ``frame_length``.
        frame_shift: forwarded to ``kaldi.fbank`` as ``frame_shift``.
        dither: forwarded to ``kaldi.fbank`` as ``dither``.
        resample_rate: target sample rate; the audio is resampled when the
            file's own rate differs, and this value is passed to
            ``kaldi.fbank`` as ``sample_frequency``.
        speed: when not 1.0, a sox ``speed`` effect followed by a ``rate``
            effect back to ``resample_rate`` is applied.
        window_type: forwarded to ``kaldi.fbank`` as ``window_type``.

    Returns:
        The ``kaldi.fbank`` output converted with ``.numpy()``.
        NOTE(review): presumably shaped (frames, num_mel_bins) -- confirm.
    """
    signal, native_rate = torchaudio.load(wav_file)

    # Bring the audio to the target rate only when it is not already there.
    if native_rate != resample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=native_rate,
                                                   new_freq=resample_rate)
        signal = resampler(signal)

    # Optional speed perturbation through sox; the trailing 'rate' effect
    # keeps the output at resample_rate.
    if speed != 1.0:
        sox_effects = [['speed', str(speed)], ['rate', str(resample_rate)]]
        signal, _ = torchaudio.sox_effects.apply_effects_tensor(
            signal, resample_rate, sox_effects
        )

    # Scale by 2**15 -- presumably maps normalized float samples onto the
    # 16-bit PCM range that Kaldi-style features expect; confirm.
    signal = signal * (1 << 15)

    fbank_options = {
        "num_mel_bins": num_mel_bins,
        "frame_length": frame_length,
        "frame_shift": frame_shift,
        "dither": dither,
        "energy_floor": 0.0,
        "window_type": window_type,
        "sample_frequency": resample_rate,
    }
    features = kaldi.fbank(signal, **fbank_options)
    return features.numpy()
def main():
    """Accumulate fbank statistics for one wav.scp shard and save them as JSON.

    Reads ``wav.<idx>.scp`` from the ``--wav_path`` directory, computes fbank
    features for every listed file using the ``frontend_conf`` section of
    ``--config_file``, and writes the per-dimension feature sum, the
    per-dimension sum of squares and the total frame count to
    ``cmvn.<idx>.json`` in the same directory.

    Raises:
        SystemExit: via ``parser.error`` when ``--config_file`` is missing.
        ValueError: when the configured ``n_mels`` differs from ``--dim``, or
            a non-blank scp line does not contain both a key and a path.
    """
    parser = get_parser()
    args = parser.parse_args()

    # The parser declares --config_file optional, but the frontend settings
    # are read from it unconditionally; fail with a usage message instead of
    # a TypeError from open(None).
    if args.config_file is None:
        parser.error("--config_file is required to compute cmvn")

    wav_scp_file = os.path.join(args.wav_path, "wav.{}.scp".format(args.idx))
    cmvn_file = os.path.join(args.wav_path, "cmvn.{}.json".format(args.idx))

    with open(args.config_file) as f:
        # safe_load returns None for an empty document.
        configs = yaml.safe_load(f) or {}
    # `or {}` also covers an explicit `frontend_conf: null` entry.
    frontend_configs = configs.get("frontend_conf") or {}
    num_mel_bins = frontend_configs.get("n_mels", 80)
    frame_length = frontend_configs.get("frame_length", 25)
    frame_shift = frontend_configs.get("frame_shift", 10)
    window_type = frontend_configs.get("window", "hamming")
    resample_rate = frontend_configs.get("fs", 16000)

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if num_mel_bins != args.dim:
        raise ValueError(
            "n_mels in config ({}) does not match --dim ({})".format(
                num_mel_bins, args.dim))

    # Despite the names, these hold raw sums (features and squared features);
    # the consumer is expected to divide by total_frames.
    mean_stats = np.zeros(args.dim)
    var_stats = np.zeros(args.dim)
    total_frames = 0

    with open(wav_scp_file) as f:
        for line_no, line in enumerate(f, start=1):
            # Split on the first whitespace run only, so a path containing
            # spaces stays intact.
            fields = line.strip().split(maxsplit=1)
            if not fields:
                continue  # tolerate blank lines
            if len(fields) != 2:
                raise ValueError(
                    "{}:{}: expected '<key> <wav_path>', got {!r}".format(
                        wav_scp_file, line_no, line.rstrip("\n")))
            wav_file = fields[1]
            fbank = compute_fbank(wav_file,
                                  num_mel_bins=args.dim,
                                  frame_length=frame_length,
                                  frame_shift=frame_shift,
                                  resample_rate=resample_rate,
                                  window_type=window_type)
            # Accumulate in float64 so long sums of squared features do not
            # lose precision if the features arrive as float32.
            fbank = fbank.astype(np.float64)
            mean_stats += np.sum(fbank, axis=0)
            var_stats += np.sum(np.square(fbank), axis=0)
            total_frames += fbank.shape[0]

    cmvn_info = {
        'mean_stats': mean_stats.tolist(),
        'var_stats': var_stats.tolist(),
        'total_frames': total_frames,
    }
    with open(cmvn_file, 'w') as fout:
        json.dump(cmvn_info, fout)
# Run the statistics pass only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|