# error_rate_zh
  1. #!/usr/bin/env python3
  2. # coding=utf8
  3. # Copyright 2021 Jiayu DU
  4. import sys
  5. import argparse
  6. import json
  7. import logging
  8. logging.basicConfig(stream=sys.stderr, level=logging.INFO, format='[%(levelname)s] %(message)s')
  9. DEBUG = None
  10. def GetEditType(ref_token, hyp_token):
  11. if ref_token == None and hyp_token != None:
  12. return 'I'
  13. elif ref_token != None and hyp_token == None:
  14. return 'D'
  15. elif ref_token == hyp_token:
  16. return 'C'
  17. elif ref_token != hyp_token:
  18. return 'S'
  19. else:
  20. raise RuntimeError
  21. class AlignmentArc:
  22. def __init__(self, src, dst, ref, hyp):
  23. self.src = src
  24. self.dst = dst
  25. self.ref = ref
  26. self.hyp = hyp
  27. self.edit_type = GetEditType(ref, hyp)
  28. def similarity_score_function(ref_token, hyp_token):
  29. return 0 if (ref_token == hyp_token) else -1.0
  30. def insertion_score_function(token):
  31. return -1.0
  32. def deletion_score_function(token):
  33. return -1.0
  34. def EditDistance(
  35. ref,
  36. hyp,
  37. similarity_score_function = similarity_score_function,
  38. insertion_score_function = insertion_score_function,
  39. deletion_score_function = deletion_score_function):
  40. assert(len(ref) != 0)
  41. class DPState:
  42. def __init__(self):
  43. self.score = -float('inf')
  44. # backpointer
  45. self.prev_r = None
  46. self.prev_h = None
  47. def print_search_grid(S, R, H, fstream):
  48. print(file=fstream)
  49. for r in range(R):
  50. for h in range(H):
  51. print(F'[{r},{h}]:{S[r][h].score:4.3f}:({S[r][h].prev_r},{S[r][h].prev_h}) ', end='', file=fstream)
  52. print(file=fstream)
  53. R = len(ref) + 1
  54. H = len(hyp) + 1
  55. # Construct DP search space, a (R x H) grid
  56. S = [ [] for r in range(R) ]
  57. for r in range(R):
  58. S[r] = [ DPState() for x in range(H) ]
  59. # initialize DP search grid origin, S(r = 0, h = 0)
  60. S[0][0].score = 0.0
  61. S[0][0].prev_r = None
  62. S[0][0].prev_h = None
  63. # initialize REF axis
  64. for r in range(1, R):
  65. S[r][0].score = S[r-1][0].score + deletion_score_function(ref[r-1])
  66. S[r][0].prev_r = r-1
  67. S[r][0].prev_h = 0
  68. # initialize HYP axis
  69. for h in range(1, H):
  70. S[0][h].score = S[0][h-1].score + insertion_score_function(hyp[h-1])
  71. S[0][h].prev_r = 0
  72. S[0][h].prev_h = h-1
  73. best_score = S[0][0].score
  74. best_state = (0, 0)
  75. for r in range(1, R):
  76. for h in range(1, H):
  77. sub_or_cor_score = similarity_score_function(ref[r-1], hyp[h-1])
  78. new_score = S[r-1][h-1].score + sub_or_cor_score
  79. if new_score >= S[r][h].score:
  80. S[r][h].score = new_score
  81. S[r][h].prev_r = r-1
  82. S[r][h].prev_h = h-1
  83. del_score = deletion_score_function(ref[r-1])
  84. new_score = S[r-1][h].score + del_score
  85. if new_score >= S[r][h].score:
  86. S[r][h].score = new_score
  87. S[r][h].prev_r = r - 1
  88. S[r][h].prev_h = h
  89. ins_score = insertion_score_function(hyp[h-1])
  90. new_score = S[r][h-1].score + ins_score
  91. if new_score >= S[r][h].score:
  92. S[r][h].score = new_score
  93. S[r][h].prev_r = r
  94. S[r][h].prev_h = h-1
  95. best_score = S[R-1][H-1].score
  96. best_state = (R-1, H-1)
  97. if DEBUG:
  98. print_search_grid(S, R, H, sys.stderr)
  99. # Backtracing best alignment path, i.e. a list of arcs
  100. # arc = (src, dst, ref, hyp, edit_type)
  101. # src/dst = (r, h), where r/h refers to search grid state-id along Ref/Hyp axis
  102. best_path = []
  103. r, h = best_state[0], best_state[1]
  104. prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
  105. score = S[r][h].score
  106. # loop invariant:
  107. # 1. (prev_r, prev_h) -> (r, h) is a "forward arc" on best alignment path
  108. # 2. score is the value of point(r, h) on DP search grid
  109. while prev_r != None or prev_h != None:
  110. src = (prev_r, prev_h)
  111. dst = (r, h)
  112. if (r == prev_r + 1 and h == prev_h + 1): # Substitution or correct
  113. arc = AlignmentArc(src, dst, ref[prev_r], hyp[prev_h])
  114. elif (r == prev_r + 1 and h == prev_h): # Deletion
  115. arc = AlignmentArc(src, dst, ref[prev_r], None)
  116. elif (r == prev_r and h == prev_h + 1): # Insertion
  117. arc = AlignmentArc(src, dst, None, hyp[prev_h])
  118. else:
  119. raise RuntimeError
  120. best_path.append(arc)
  121. r, h = prev_r, prev_h
  122. prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
  123. score = S[r][h].score
  124. best_path.reverse()
  125. return (best_path, best_score)
  126. def PrettyPrintAlignment(alignment, stream = sys.stderr):
  127. def get_token_str(token):
  128. if token == None:
  129. return "*"
  130. return token
  131. def is_double_width_char(ch):
  132. if (ch >= '\u4e00') and (ch <= '\u9fa5'): # codepoint ranges for Chinese chars
  133. return True
  134. # TODO: support other double-width-char language such as Japanese, Korean
  135. else:
  136. return False
  137. def display_width(token_str):
  138. m = 0
  139. for c in token_str:
  140. if is_double_width_char(c):
  141. m += 2
  142. else:
  143. m += 1
  144. return m
  145. R = ' REF : '
  146. H = ' HYP : '
  147. E = ' EDIT : '
  148. for arc in alignment:
  149. r = get_token_str(arc.ref)
  150. h = get_token_str(arc.hyp)
  151. e = arc.edit_type if arc.edit_type != 'C' else ''
  152. nr, nh, ne = display_width(r), display_width(h), display_width(e)
  153. n = max(nr, nh, ne) + 1
  154. R += r + ' ' * (n-nr)
  155. H += h + ' ' * (n-nh)
  156. E += e + ' ' * (n-ne)
  157. print(R, file=stream)
  158. print(H, file=stream)
  159. print(E, file=stream)
  160. def CountEdits(alignment):
  161. c, s, i, d = 0, 0, 0, 0
  162. for arc in alignment:
  163. if arc.edit_type == 'C':
  164. c += 1
  165. elif arc.edit_type == 'S':
  166. s += 1
  167. elif arc.edit_type == 'I':
  168. i += 1
  169. elif arc.edit_type == 'D':
  170. d += 1
  171. else:
  172. raise RuntimeError
  173. return (c, s, i, d)
  174. def ComputeTokenErrorRate(c, s, i, d):
  175. return 100.0 * (s + d + i) / (s + d + c)
  176. def ComputeSentenceErrorRate(num_err_utts, num_utts):
  177. assert(num_utts != 0)
  178. return 100.0 * num_err_utts / num_utts
  179. class EvaluationResult:
  180. def __init__(self):
  181. self.num_ref_utts = 0
  182. self.num_hyp_utts = 0
  183. self.num_eval_utts = 0 # seen in both ref & hyp
  184. self.num_hyp_without_ref = 0
  185. self.C = 0
  186. self.S = 0
  187. self.I = 0
  188. self.D = 0
  189. self.token_error_rate = 0.0
  190. self.num_utts_with_error = 0
  191. self.sentence_error_rate = 0.0
  192. def to_json(self):
  193. return json.dumps(self.__dict__)
  194. def to_kaldi(self):
  195. info = (
  196. F'%WER {self.token_error_rate:.2f} [ {self.S + self.D + self.I} / {self.C + self.S + self.D}, {self.I} ins, {self.D} del, {self.S} sub ]\n'
  197. F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n'
  198. )
  199. return info
  200. def to_sclite(self):
  201. return "TODO"
  202. def to_espnet(self):
  203. return "TODO"
  204. def to_summary(self):
  205. #return json.dumps(self.__dict__, indent=4)
  206. summary = (
  207. '==================== Overall Statistics ====================\n'
  208. F'num_ref_utts: {self.num_ref_utts}\n'
  209. F'num_hyp_utts: {self.num_hyp_utts}\n'
  210. F'num_hyp_without_ref: {self.num_hyp_without_ref}\n'
  211. F'num_eval_utts: {self.num_eval_utts}\n'
  212. F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n'
  213. F'token_error_rate: {self.token_error_rate:.2f}%\n'
  214. F'token_stats:\n'
  215. F' - tokens:{self.C + self.S + self.D:>7}\n'
  216. F' - edits: {self.S + self.I + self.D:>7}\n'
  217. F' - cor: {self.C:>7}\n'
  218. F' - sub: {self.S:>7}\n'
  219. F' - ins: {self.I:>7}\n'
  220. F' - del: {self.D:>7}\n'
  221. '============================================================\n'
  222. )
  223. return summary
  224. class Utterance:
  225. def __init__(self, uid, text):
  226. self.uid = uid
  227. self.text = text
  228. def LoadUtterances(filepath, format):
  229. utts = {}
  230. if format == 'text': # utt_id word1 word2 ...
  231. with open(filepath, 'r', encoding='utf8') as f:
  232. for line in f:
  233. line = line.strip()
  234. if line:
  235. cols = line.split(maxsplit=1)
  236. assert(len(cols) == 2 or len(cols) == 1)
  237. uid = cols[0]
  238. text = cols[1] if len(cols) == 2 else ''
  239. if utts.get(uid) != None:
  240. raise RuntimeError(F'Found duplicated utterence id {uid}')
  241. utts[uid] = Utterance(uid, text)
  242. else:
  243. raise RuntimeError(F'Unsupported text format {format}')
  244. return utts
  245. def tokenize_text(text, tokenizer):
  246. if tokenizer == 'whitespace':
  247. return text.split()
  248. elif tokenizer == 'char':
  249. return [ ch for ch in ''.join(text.split()) ]
  250. else:
  251. raise RuntimeError(F'ERROR: Unsupported tokenizer {tokenizer}')
  252. if __name__ == '__main__':
  253. parser = argparse.ArgumentParser()
  254. # optional
  255. parser.add_argument('--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER')
  256. parser.add_argument('--ref-format', choices=['text'], default='text', help='reference format, first col is utt_id, the rest is text')
  257. parser.add_argument('--hyp-format', choices=['text'], default='text', help='hypothesis format, first col is utt_id, the rest is text')
  258. # required
  259. parser.add_argument('--ref', type=str, required=True, help='input reference file')
  260. parser.add_argument('--hyp', type=str, required=True, help='input hypothesis file')
  261. parser.add_argument('result_file', type=str)
  262. args = parser.parse_args()
  263. logging.info(args)
  264. ref_utts = LoadUtterances(args.ref, args.ref_format)
  265. hyp_utts = LoadUtterances(args.hyp, args.hyp_format)
  266. r = EvaluationResult()
  267. # check valid utterances in hyp that have matched non-empty reference
  268. eval_utts = []
  269. r.num_hyp_without_ref = 0
  270. for uid in sorted(hyp_utts.keys()):
  271. if uid in ref_utts.keys(): # TODO: efficiency
  272. if ref_utts[uid].text.strip(): # non-empty reference
  273. eval_utts.append(uid)
  274. else:
  275. logging.warn(F'Found {uid} with empty reference, skipping...')
  276. else:
  277. logging.warn(F'Found {uid} without reference, skipping...')
  278. r.num_hyp_without_ref += 1
  279. r.num_hyp_utts = len(hyp_utts)
  280. r.num_ref_utts = len(ref_utts)
  281. r.num_eval_utts = len(eval_utts)
  282. with open(args.result_file, 'w+', encoding='utf8') as fo:
  283. for uid in eval_utts:
  284. ref = ref_utts[uid]
  285. hyp = hyp_utts[uid]
  286. alignment, score = EditDistance(
  287. tokenize_text(ref.text, args.tokenizer),
  288. tokenize_text(hyp.text, args.tokenizer)
  289. )
  290. c, s, i, d = CountEdits(alignment)
  291. utt_ter = ComputeTokenErrorRate(c, s, i, d)
  292. # utt-level evaluation result
  293. print(F'{{"uid":{uid}, "score":{score}, "ter":{utt_ter:.2f}, "cor":{c}, "sub":{s}, "ins":{i}, "del":{d}}}', file=fo)
  294. PrettyPrintAlignment(alignment, fo)
  295. r.C += c
  296. r.S += s
  297. r.I += i
  298. r.D += d
  299. if utt_ter > 0:
  300. r.num_utts_with_error += 1
  301. # corpus level evaluation result
  302. r.sentence_error_rate = ComputeSentenceErrorRate(r.num_utts_with_error, r.num_eval_utts)
  303. r.token_error_rate = ComputeTokenErrorRate(r.C, r.S, r.I, r.D)
  304. print(r.to_summary(), file=fo)
  305. print(r.to_json())
  306. print(r.to_kaldi())