unit_test.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from funasr.bin.diar_inference_launch import inference_launch
  2. import os
  3. def test_fbank_cpu_infer():
  4. diar_config_path = "sond_fbank.yaml"
  5. diar_model_path = "sond.pb"
  6. output_dir = "./outputs"
  7. data_path_and_name_and_type = [
  8. ("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
  9. ("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
  10. ]
  11. pipeline = inference_launch(
  12. mode="sond",
  13. diar_train_config=diar_config_path,
  14. diar_model_file=diar_model_path,
  15. output_dir=output_dir,
  16. num_workers=0,
  17. log_level="INFO",
  18. )
  19. results = pipeline(data_path_and_name_and_type)
  20. print(results)
  21. def test_fbank_gpu_infer():
  22. diar_config_path = "sond_fbank.yaml"
  23. diar_model_path = "sond.pb"
  24. output_dir = "./outputs"
  25. data_path_and_name_and_type = [
  26. ("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
  27. ("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
  28. ]
  29. pipeline = inference_launch(
  30. mode="sond",
  31. diar_train_config=diar_config_path,
  32. diar_model_file=diar_model_path,
  33. output_dir=output_dir,
  34. ngpu=1,
  35. num_workers=1,
  36. log_level="INFO",
  37. )
  38. results = pipeline(data_path_and_name_and_type)
  39. print(results)
  40. def test_wav_gpu_infer():
  41. diar_config_path = "config.yaml"
  42. diar_model_path = "sond.pb"
  43. output_dir = "./outputs"
  44. data_path_and_name_and_type = [
  45. ("data/unit_test/test_wav.scp", "speech", "sound"),
  46. ("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
  47. ]
  48. pipeline = inference_launch(
  49. mode="sond",
  50. diar_train_config=diar_config_path,
  51. diar_model_file=diar_model_path,
  52. output_dir=output_dir,
  53. ngpu=1,
  54. num_workers=1,
  55. log_level="WARNING",
  56. )
  57. results = pipeline(data_path_and_name_and_type)
  58. print(results)
  59. def test_without_profile_gpu_infer():
  60. diar_config_path = "config.yaml"
  61. diar_model_path = "sond.pb"
  62. output_dir = "./outputs"
  63. raw_inputs = [[
  64. "data/unit_test/raw_inputs/record.wav",
  65. "data/unit_test/raw_inputs/spk1.wav",
  66. "data/unit_test/raw_inputs/spk2.wav",
  67. "data/unit_test/raw_inputs/spk3.wav",
  68. "data/unit_test/raw_inputs/spk4.wav"
  69. ]]
  70. pipeline = inference_launch(
  71. mode="sond_demo",
  72. diar_train_config=diar_config_path,
  73. diar_model_file=diar_model_path,
  74. output_dir=output_dir,
  75. ngpu=1,
  76. num_workers=1,
  77. log_level="WARNING",
  78. param_dict={},
  79. )
  80. results = pipeline(raw_inputs=raw_inputs)
  81. print(results)
  82. if __name__ == '__main__':
  83. os.environ["CUDA_VISIBLE_DEVICES"] = "7"
  84. test_fbank_cpu_infer()
  85. # test_fbank_gpu_infer()
  86. # test_wav_gpu_infer()
  87. # test_without_profile_gpu_infer()