# test_cer.py
  1. import os
  2. import time
  3. import sys
  4. import librosa
  5. from funasr.utils.types import str2bool
  6. import argparse
  7. parser = argparse.ArgumentParser()
  8. parser.add_argument('--model_dir', type=str, required=True)
  9. parser.add_argument('--backend', type=str, default='onnx', help='["onnx", "torch"]')
  10. parser.add_argument('--wav_file', type=str, default=None, help='amp fallback number')
  11. parser.add_argument('--quantize', type=str2bool, default=False, help='quantized model')
  12. parser.add_argument('--intra_op_num_threads', type=int, default=1, help='intra_op_num_threads for onnx')
  13. parser.add_argument('--output_dir', type=str, default=None, help='amp fallback number')
  14. args = parser.parse_args()
  15. from funasr.runtime.python.libtorch.funasr_torch import Paraformer
  16. if args.backend == "onnx":
  17. from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
  18. model = Paraformer(args.model_dir, batch_size=1, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)
  19. wav_file_f = open(args.wav_file, 'r')
  20. wav_files = wav_file_f.readlines()
  21. output_dir = args.output_dir
  22. if not os.path.exists(output_dir):
  23. os.makedirs(output_dir)
  24. if os.name == 'nt': # Windows
  25. newline = '\r\n'
  26. else: # Linux Mac
  27. newline = '\n'
  28. text_f = open(os.path.join(output_dir, "text"), "w", newline=newline)
  29. token_f = open(os.path.join(output_dir, "token"), "w", newline=newline)
  30. for i, wav_path_i in enumerate(wav_files):
  31. wav_name, wav_path = wav_path_i.strip().split()
  32. result = model(wav_path)
  33. text_i = "{} {}\n".format(wav_name, result[0]['preds'][0])
  34. token_i = "{} {}\n".format(wav_name, result[0]['preds'][1])
  35. text_f.write(text_i)
  36. text_f.flush()
  37. token_f.write(token_i)
  38. token_f.flush()
  39. text_f.close()
  40. token_f.close()