utils.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import numpy as np
  2. def _levenshtein_distance(ref, hyp):
  3. """Levenshtein distance is a string metric for measuring the difference
  4. between two sequences. Informally, the levenshtein disctance is defined as
  5. the minimum number of single-character edits (substitutions, insertions or
  6. deletions) required to change one word into the other. We can naturally
  7. extend the edits to word level when calculate levenshtein disctance for
  8. two sentences.
  9. """
  10. m = len(ref)
  11. n = len(hyp)
  12. # special case
  13. if ref == hyp:
  14. return 0
  15. if m == 0:
  16. return n
  17. if n == 0:
  18. return m
  19. if m < n:
  20. ref, hyp = hyp, ref
  21. m, n = n, m
  22. # use O(min(m, n)) space
  23. distance = np.zeros((2, n + 1), dtype=np.int32)
  24. # initialize distance matrix
  25. for j in range(n + 1):
  26. distance[0][j] = j
  27. # calculate levenshtein distance
  28. for i in range(1, m + 1):
  29. prev_row_idx = (i - 1) % 2
  30. cur_row_idx = i % 2
  31. distance[cur_row_idx][0] = i
  32. for j in range(1, n + 1):
  33. if ref[i - 1] == hyp[j - 1]:
  34. distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
  35. else:
  36. s_num = distance[prev_row_idx][j - 1] + 1
  37. i_num = distance[cur_row_idx][j - 1] + 1
  38. d_num = distance[prev_row_idx][j] + 1
  39. distance[cur_row_idx][j] = min(s_num, i_num, d_num)
  40. return distance[m % 2][n]
  41. def cal_cer(references, predictions):
  42. errors = 0
  43. lengths = 0
  44. for ref, pred in zip(references, predictions):
  45. cur_ref = list(ref)
  46. cur_hyp = list(pred)
  47. cur_error = _levenshtein_distance(cur_ref, cur_hyp)
  48. errors += cur_error
  49. lengths += len(cur_ref)
  50. return float(errors) / lengths