compute_eer.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. import numpy as np
  2. from sklearn.metrics import roc_curve
  3. import argparse
  4. def _compute_eer(label, pred, positive_label=1):
  5. """
  6. Python compute equal error rate (eer)
  7. ONLY tested on binary classification
  8. :param label: ground-truth label, should be a 1-d list or np.array, each element represents the ground-truth label of one sample
  9. :param pred: model prediction, should be a 1-d list or np.array, each element represents the model prediction of one sample
  10. :param positive_label: the class that is viewed as positive class when computing EER
  11. :return: equal error rate (EER)
  12. """
  13. # all fpr, tpr, fnr, fnr, threshold are lists (in the format of np.array)
  14. fpr, tpr, threshold = roc_curve(label, pred, pos_label=positive_label)
  15. fnr = 1 - tpr
  16. # the threshold of fnr == fpr
  17. eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))]
  18. # theoretically eer from fpr and eer from fnr should be identical but they can be slightly differ in reality
  19. eer_1 = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
  20. eer_2 = fnr[np.nanargmin(np.absolute((fnr - fpr)))]
  21. # return the mean of eer from fpr and from fnr
  22. eer = (eer_1 + eer_2) / 2
  23. return eer, eer_threshold
  24. def compute_eer(trials_path, scores_path):
  25. labels = []
  26. for one_line in open(trials_path, "r"):
  27. labels.append(one_line.strip().rsplit(" ", 1)[-1] == "target")
  28. labels = np.array(labels, dtype=int)
  29. scores = []
  30. for one_line in open(scores_path, "r"):
  31. scores.append(float(one_line.strip().rsplit(" ", 1)[-1]))
  32. scores = np.array(scores, dtype=float)
  33. eer, threshold = _compute_eer(labels, scores)
  34. return eer, threshold
  35. def main():
  36. parser = argparse.ArgumentParser()
  37. parser.add_argument("trials", help="trial list")
  38. parser.add_argument("scores", help="score file, normalized to [0, 1]")
  39. args = parser.parse_args()
  40. eer, threshold = compute_eer(args.trials, args.scores)
  41. print("EER is {:.4f} at threshold {:.4f}".format(eer * 100.0, threshold))
  42. if __name__ == '__main__':
  43. main()