# compute_min_dcf.py
#!/usr/bin/env python3
# Copyright 2018 David Snyder
# Apache 2.0
# This script computes the minimum detection cost function, which is a common
# error metric used in speaker recognition. Compared to equal error-rate,
# which assigns equal weight to false negatives and false positives, this
# error-rate is usually used to assess performance in settings where achieving
# a low false positive rate is more important than achieving a low false
# negative rate. See the NIST 2016 Speaker Recognition Evaluation Plan at
# https://www.nist.gov/sites/default/files/documents/2016/10/07/sre16_eval_plan_v1.3.pdf
# for more details about the metric.
from __future__ import print_function

import argparse
import os
import sys
from operator import itemgetter
  15. def GetArgs():
  16. parser = argparse.ArgumentParser(description="Compute the minimum "
  17. "detection cost function along with the threshold at which it occurs. "
  18. "Usage: sid/compute_min_dcf.py [options...] <scores-file> "
  19. "<trials-file> "
  20. "E.g., sid/compute_min_dcf.py --p-target 0.01 --c-miss 1 --c-fa 1 "
  21. "exp/scores/trials data/test/trials",
  22. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  23. parser.add_argument('--p-target', type=float, dest="p_target",
  24. default=0.01,
  25. help='The prior probability of the target speaker in a trial.')
  26. parser.add_argument('--c-miss', type=float, dest="c_miss", default=1,
  27. help='Cost of a missed detection. This is usually not changed.')
  28. parser.add_argument('--c-fa', type=float, dest="c_fa", default=1,
  29. help='Cost of a spurious detection. This is usually not changed.')
  30. parser.add_argument("scores_filename",
  31. help="Input scores file, with columns of the form "
  32. "<utt1> <utt2> <score>")
  33. parser.add_argument("trials_filename",
  34. help="Input trials file, with columns of the form "
  35. "<utt1> <utt2> <target/nontarget>")
  36. sys.stderr.write(' '.join(sys.argv) + "\n")
  37. args = parser.parse_args()
  38. args = CheckArgs(args)
  39. return args
  40. def CheckArgs(args):
  41. if args.c_fa <= 0:
  42. raise Exception("--c-fa must be greater than 0")
  43. if args.c_miss <= 0:
  44. raise Exception("--c-miss must be greater than 0")
  45. if args.p_target <= 0 or args.p_target >= 1:
  46. raise Exception("--p-target must be greater than 0 and less than 1")
  47. return args
  48. # Creates a list of false-negative rates, a list of false-positive rates
  49. # and a list of decision thresholds that give those error-rates.
  50. def ComputeErrorRates(scores, labels):
  51. # Sort the scores from smallest to largest, and also get the corresponding
  52. # indexes of the sorted scores. We will treat the sorted scores as the
  53. # thresholds at which the the error-rates are evaluated.
  54. sorted_indexes, thresholds = zip(*sorted(
  55. [(index, threshold) for index, threshold in enumerate(scores)],
  56. key=itemgetter(1)))
  57. labels = [labels[i] for i in sorted_indexes]
  58. fns = []
  59. tns = []
  60. # At the end of this loop, fns[i] is the number of errors made by
  61. # incorrectly rejecting scores less than thresholds[i]. And, tns[i]
  62. # is the total number of times that we have correctly rejected scores
  63. # less than thresholds[i].
  64. for i in range(0, len(labels)):
  65. if i == 0:
  66. fns.append(labels[i])
  67. tns.append(1 - labels[i])
  68. else:
  69. fns.append(fns[i-1] + labels[i])
  70. tns.append(tns[i-1] + 1 - labels[i])
  71. positives = sum(labels)
  72. negatives = len(labels) - positives
  73. # Now divide the false negatives by the total number of
  74. # positives to obtain the false negative rates across
  75. # all thresholds
  76. fnrs = [fn / float(positives) for fn in fns]
  77. # Divide the true negatives by the total number of
  78. # negatives to get the true negative rate. Subtract these
  79. # quantities from 1 to get the false positive rates.
  80. fprs = [1 - tn / float(negatives) for tn in tns]
  81. return fnrs, fprs, thresholds
  82. # Computes the minimum of the detection cost function. The comments refer to
  83. # equations in Section 3 of the NIST 2016 Speaker Recognition Evaluation Plan.
  84. def ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa):
  85. min_c_det = float("inf")
  86. min_c_det_threshold = thresholds[0]
  87. for i in range(0, len(fnrs)):
  88. # See Equation (2). it is a weighted sum of false negative
  89. # and false positive errors.
  90. c_det = c_miss * fnrs[i] * p_target + c_fa * fprs[i] * (1 - p_target)
  91. if c_det < min_c_det:
  92. min_c_det = c_det
  93. min_c_det_threshold = thresholds[i]
  94. # See Equations (3) and (4). Now we normalize the cost.
  95. c_def = min(c_miss * p_target, c_fa * (1 - p_target))
  96. min_dcf = min_c_det / c_def
  97. return min_dcf, min_c_det_threshold
  98. def compute_min_dcf(scores_filename, trials_filename, c_miss=1, c_fa=1, p_target=0.01):
  99. scores_file = open(scores_filename, 'r').readlines()
  100. trials_file = open(trials_filename, 'r').readlines()
  101. c_miss = c_miss
  102. c_fa = c_fa
  103. p_target = p_target
  104. scores = []
  105. labels = []
  106. trials = {}
  107. for line in trials_file:
  108. utt1, utt2, target = line.rstrip().split()
  109. trial = utt1 + " " + utt2
  110. trials[trial] = target
  111. for line in scores_file:
  112. utt1, utt2, score = line.rstrip().split()
  113. trial = utt1 + " " + utt2
  114. if trial in trials:
  115. scores.append(float(score))
  116. if trials[trial] == "target":
  117. labels.append(1)
  118. else:
  119. labels.append(0)
  120. else:
  121. raise Exception("Missing entry for " + utt1 + " and " + utt2
  122. + " " + scores_filename)
  123. fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)
  124. mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, p_target,
  125. c_miss, c_fa)
  126. return mindcf, threshold
  127. def main():
  128. args = GetArgs()
  129. mindcf, threshold = compute_min_dcf(
  130. args.scores_filename, args.trials_filename,
  131. args.c_miss, args.c_fa, args.p_target
  132. )
  133. sys.stdout.write("minDCF is {0:.4f} at threshold {1:.4f} (p-target={2}, c-miss={3}, "
  134. "c-fa={4})\n".format(mindcf, threshold, args.p_target, args.c_miss, args.c_fa))
  135. if __name__ == "__main__":
  136. main()