| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859 |
- import numpy as np
- from sklearn.metrics import roc_curve
- import argparse
def _compute_eer(label, pred, positive_label=1):
    """
    Compute the equal error rate (EER) and the threshold at which it occurs.

    ONLY tested on binary classification.

    :param label: ground-truth labels, should be a 1-d list or np.array, each element is the ground-truth label of one sample
    :param pred: model predictions (scores), should be a 1-d list or np.array, each element is the model prediction of one sample
    :param positive_label: the class that is viewed as the positive class when computing EER
    :return: tuple ``(eer, eer_threshold)``: the equal error rate and the ROC threshold at the
        operating point where the false-negative and false-positive rates are closest
    """
    # fpr, tpr and threshold are parallel np.arrays, one entry per ROC operating point
    fpr, tpr, threshold = roc_curve(label, pred, pos_label=positive_label)
    fnr = 1 - tpr
    # index of the operating point where |fnr - fpr| is smallest (NaNs ignored);
    # computed once and reused for the threshold and both error rates
    eer_index = np.nanargmin(np.absolute(fnr - fpr))
    eer_threshold = threshold[eer_index]
    # theoretically fpr and fnr are identical at the EER point, but on a sampled
    # curve they can differ slightly, so report the mean of the two
    eer = (fpr[eer_index] + fnr[eer_index]) / 2
    return eer, eer_threshold
def compute_eer(trials_path, scores_path):
    """
    Read a trial list and a score file and compute the equal error rate.

    For each non-blank line of the trial file, the last whitespace-separated
    field is the label; the trial is positive when that field is exactly
    ``"target"``. For each non-blank line of the score file, the last
    whitespace-separated field is parsed as a float score. Labels and scores
    are paired by position.

    :param trials_path: path to the trial list
    :param scores_path: path to the score file
    :return: tuple ``(eer, threshold)`` as returned by ``_compute_eer``
    :raises ValueError: if a score field is not a valid float, or if the two
        files do not contain the same number of non-blank lines
    """
    # `with` guarantees both files are closed; the original left them open.
    # `split()` splits on any whitespace, so tab-separated files work too, and
    # blank lines are skipped instead of producing bogus labels / float("") errors.
    with open(trials_path, "r") as trials_file:
        labels = np.array(
            [line.split()[-1] == "target" for line in trials_file if line.strip()],
            dtype=int,
        )
    with open(scores_path, "r") as scores_file:
        scores = np.array(
            [float(line.split()[-1]) for line in scores_file if line.strip()],
            dtype=float,
        )

    # Labels and scores are matched positionally, so differing counts mean the
    # inputs do not belong together; fail early with a clear message.
    if len(labels) != len(scores):
        raise ValueError(
            "number of trials ({}) does not match number of scores ({})".format(
                len(labels), len(scores)
            )
        )

    return _compute_eer(labels, scores)
def main():
    """Command-line entry point: parse the trial and score paths, compute the EER and print it."""
    cli = argparse.ArgumentParser()
    # both arguments are required positionals, registered in this order
    positional_args = (
        ("trials", "trial list"),
        ("scores", "score file, normalized to [0, 1]"),
    )
    for arg_name, arg_help in positional_args:
        cli.add_argument(arg_name, help=arg_help)
    opts = cli.parse_args()

    eer, threshold = compute_eer(opts.trials, opts.scores)
    # the EER is printed as a percentage; the threshold is printed unchanged
    report = "EER is {:.4f} at threshold {:.4f}".format(eer * 100.0, threshold)
    print(report)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|