# compute_audio_cmvn.py
  1. import os
  2. import json
  3. import numpy as np
  4. import torch
  5. import hydra
  6. import logging
  7. from omegaconf import DictConfig, OmegaConf
  8. from funasr.register import tables
  9. from funasr.download.download_from_hub import download_model
  10. from funasr.train_utils.set_all_random_seed import set_all_random_seed
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    """Hydra entry point.

    Validates that a ``model`` key is present, optionally downloads the model
    config from a model hub when no local ``model_conf`` is given, then
    forwards the merged configuration to :func:`main` as keyword arguments.
    """
    # Drop into an interactive debugger when launched with ++debug=true.
    if kwargs.get("debug", False):
        import pdb; pdb.set_trace()

    assert "model" in kwargs
    if "model_conf" not in kwargs:
        # No local model_conf: pull the full config from the hub
        # ("ms" = ModelScope by default) and merge it into kwargs.
        logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)

    main(**kwargs)
  20. def main(**kwargs):
  21. print(kwargs)
  22. # set random seed
  23. tables.print()
  24. set_all_random_seed(kwargs.get("seed", 0))
  25. torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
  26. torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
  27. torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
  28. tokenizer = kwargs.get("tokenizer", None)
  29. # build frontend if frontend is none None
  30. frontend = kwargs.get("frontend", None)
  31. if frontend is not None:
  32. frontend_class = tables.frontend_classes.get(frontend)
  33. frontend = frontend_class(**kwargs["frontend_conf"])
  34. kwargs["frontend"] = frontend
  35. kwargs["input_size"] = frontend.output_size()
  36. # dataset
  37. dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
  38. dataset_train = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=None, is_training=False, **kwargs.get("dataset_conf"))
  39. # dataloader
  40. batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
  41. batch_sampler_train = None
  42. if batch_sampler is not None:
  43. batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
  44. dataset_conf = kwargs.get("dataset_conf")
  45. dataset_conf["batch_type"] = "example"
  46. dataset_conf["batch_size"] = 1
  47. batch_sampler_train = batch_sampler_class(dataset_train, is_training=False, **dataset_conf)
  48. dataloader_train = torch.utils.data.DataLoader(dataset_train,
  49. collate_fn=dataset_train.collator,
  50. batch_sampler=batch_sampler_train,
  51. num_workers=int(kwargs.get("dataset_conf").get("num_workers", 4)),
  52. pin_memory=True)
  53. iter_stop = int(kwargs.get("scale", 1.0)*len(dataloader_train))
  54. total_frames = 0
  55. for batch_idx, batch in enumerate(dataloader_train):
  56. if batch_idx >= iter_stop:
  57. break
  58. fbank = batch["speech"].numpy()[0, :, :]
  59. if total_frames == 0:
  60. mean_stats = np.sum(fbank, axis=0)
  61. var_stats = np.sum(np.square(fbank), axis=0)
  62. else:
  63. mean_stats += np.sum(fbank, axis=0)
  64. var_stats += np.sum(np.square(fbank), axis=0)
  65. total_frames += fbank.shape[0]
  66. cmvn_info = {
  67. 'mean_stats': list(mean_stats.tolist()),
  68. 'var_stats': list(var_stats.tolist()),
  69. 'total_frames': total_frames
  70. }
  71. cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
  72. # import pdb;pdb.set_trace()
  73. with open(cmvn_file, 'w') as fout:
  74. fout.write(json.dumps(cmvn_info))
  75. mean = -1.0 * mean_stats / total_frames
  76. var = 1.0 / np.sqrt(var_stats / total_frames - mean * mean)
  77. dims = mean.shape[0]
  78. am_mvn = os.path.dirname(cmvn_file) + "/am.mvn"
  79. with open(am_mvn, 'w') as fout:
  80. fout.write("<Nnet>" + "\n" + "<Splice> " + str(dims) + " " + str(dims) + '\n' + "[ 0 ]" + "\n" + "<AddShift> " + str(dims) + " " + str(dims) + "\n")
  81. mean_str = str(list(mean)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
  82. fout.write("<LearnRateCoef> 0 " + mean_str + '\n')
  83. fout.write("<Rescale> " + str(dims) + " " + str(dims) + '\n')
  84. var_str = str(list(var)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
  85. fout.write("<LearnRateCoef> 0 " + var_str + '\n')
  86. fout.write("</Nnet>" + '\n')
  87. """
  88. python funasr/bin/compute_audio_cmvn.py \
  89. --config-path "/Users/zhifu/funasr1.0/examples/aishell/paraformer/conf" \
  90. --config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
  91. ++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
  92. ++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
  93. ++dataset_conf.num_workers=0
  94. """
# Script entry point: delegate to the hydra-decorated launcher.
if __name__ == "__main__":
    main_hydra()