compute_wer_details.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. from enum import Enum
  4. import re, sys, unicodedata
  5. import codecs
  6. import argparse
  7. from tqdm import tqdm
  8. import os
  9. import pdb
  10. remove_tag = False
  11. spacelist = [" ", "\t", "\r", "\n"]
  12. puncts = [
  13. "!",
  14. ",",
  15. "?",
  16. "、",
  17. "。",
  18. "!",
  19. ",",
  20. ";",
  21. "?",
  22. ":",
  23. "「",
  24. "」",
  25. "︰",
  26. "『",
  27. "』",
  28. "《",
  29. "》",
  30. ]
  31. class Code(Enum):
  32. match = 1
  33. substitution = 2
  34. insertion = 3
  35. deletion = 4
  36. class WordError(object):
  37. def __init__(self):
  38. self.errors = {
  39. Code.substitution: 0,
  40. Code.insertion: 0,
  41. Code.deletion: 0,
  42. }
  43. self.ref_words = 0
  44. def get_wer(self):
  45. assert self.ref_words != 0
  46. errors = (
  47. self.errors[Code.substitution]
  48. + self.errors[Code.insertion]
  49. + self.errors[Code.deletion]
  50. )
  51. return 100.0 * errors / self.ref_words
  52. def get_result_string(self):
  53. return (
  54. f"error_rate={self.get_wer():.4f}, "
  55. f"ref_words={self.ref_words}, "
  56. f"subs={self.errors[Code.substitution]}, "
  57. f"ins={self.errors[Code.insertion]}, "
  58. f"dels={self.errors[Code.deletion]}"
  59. )
  60. def characterize(string):
  61. res = []
  62. i = 0
  63. while i < len(string):
  64. char = string[i]
  65. if char in puncts:
  66. i += 1
  67. continue
  68. cat1 = unicodedata.category(char)
  69. # https://unicodebook.readthedocs.io/unicode.html#unicode-categories
  70. if cat1 == "Zs" or cat1 == "Cn" or char in spacelist: # space or not assigned
  71. i += 1
  72. continue
  73. if cat1 == "Lo": # letter-other
  74. res.append(char)
  75. i += 1
  76. else:
  77. # some input looks like: <unk><noise>, we want to separate it to two words.
  78. sep = " "
  79. if char == "<":
  80. sep = ">"
  81. j = i + 1
  82. while j < len(string):
  83. c = string[j]
  84. if ord(c) >= 128 or (c in spacelist) or (c == sep):
  85. break
  86. j += 1
  87. if j < len(string) and string[j] == ">":
  88. j += 1
  89. res.append(string[i:j])
  90. i = j
  91. return res
  92. def stripoff_tags(x):
  93. if not x:
  94. return ""
  95. chars = []
  96. i = 0
  97. T = len(x)
  98. while i < T:
  99. if x[i] == "<":
  100. while i < T and x[i] != ">":
  101. i += 1
  102. i += 1
  103. else:
  104. chars.append(x[i])
  105. i += 1
  106. return "".join(chars)
  107. def normalize(sentence, ignore_words, cs, split=None):
  108. """sentence, ignore_words are both in unicode"""
  109. new_sentence = []
  110. for token in sentence:
  111. x = token
  112. if not cs:
  113. x = x.upper()
  114. if x in ignore_words:
  115. continue
  116. if remove_tag:
  117. x = stripoff_tags(x)
  118. if not x:
  119. continue
  120. if split and x in split:
  121. new_sentence += split[x]
  122. else:
  123. new_sentence.append(x)
  124. return new_sentence
class Calculator:
    """Levenshtein-alignment accumulator for WER scoring.

    ``calculate`` aligns one (reference, hypothesis) token pair and updates
    per-token statistics in ``self.data``; ``overall`` and ``cluster``
    aggregate those statistics.  The DP matrix ``self.space`` is kept
    between calls and only ever grown, never shrunk, to avoid reallocating
    it for every utterance.
    """

    def __init__(self):
        # Per-token tallies: token -> {"all","cor","sub","ins","del"} counts.
        self.data = {}
        # Persistent DP matrix: space[i][j] = {"dist": int, "error": str}.
        self.space = []
        # Unit edit costs of the distance recursion.
        self.cost = {}
        self.cost["cor"] = 0
        self.cost["sub"] = 1
        self.cost["del"] = 1
        self.cost["ins"] = 1

    def calculate(self, lab, rec):
        """Align reference ``lab`` against hypothesis ``rec``.

        NOTE: both lists are mutated in place (a "" sentinel is inserted at
        index 0), so callers pass copies (see ``main``).  A hypothesis token
        matches a reference token when equal after removing any "<BIAS>"
        hotword marker.  Returns a dict with the aligned "lab"/"rec"/"code"
        sequences plus "all"/"cor"/"sub"/"ins"/"del" counts for this pair.
        """
        # Initialization
        lab.insert(0, "")
        rec.insert(0, "")
        # Grow the persistent DP matrix to fit this pair, resetting reused cells.
        while len(self.space) < len(lab):
            self.space.append([])
        for row in self.space:
            for element in row:
                element["dist"] = 0
                element["error"] = "non"
            while len(row) < len(rec):
                row.append({"dist": 0, "error": "non"})
        # First column = deleting i reference tokens; first row = inserting j.
        for i in range(len(lab)):
            self.space[i][0]["dist"] = i
            self.space[i][0]["error"] = "del"
        for j in range(len(rec)):
            self.space[0][j]["dist"] = j
            self.space[0][j]["error"] = "ins"
        self.space[0][0]["error"] = "non"
        # Ensure every token (including the hypothesis-only ones) has a
        # statistics entry; the "" sentinel is excluded by the len check.
        for token in lab:
            if token not in self.data and len(token) > 0:
                self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
        for token in rec:
            if token not in self.data and len(token) > 0:
                self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
        # Computing edit distance
        for i, lab_token in enumerate(lab):
            for j, rec_token in enumerate(rec):
                if i == 0 or j == 0:
                    continue
                min_dist = sys.maxsize
                min_error = "none"
                dist = self.space[i - 1][j]["dist"] + self.cost["del"]
                error = "del"
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                dist = self.space[i][j - 1]["dist"] + self.cost["ins"]
                error = "ins"
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                # "<BIAS>" tags hotword hypothesis tokens; strip it so a
                # biased output can still count as a correct match.
                if lab_token == rec_token.replace("<BIAS>", ""):
                    dist = self.space[i - 1][j - 1]["dist"] + self.cost["cor"]
                    error = "cor"
                else:
                    dist = self.space[i - 1][j - 1]["dist"] + self.cost["sub"]
                    error = "sub"
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                self.space[i][j]["dist"] = min_dist
                self.space[i][j]["error"] = min_error
        # Tracing back
        result = {
            "lab": [],
            "rec": [],
            "code": [],
            "all": 0,
            "cor": 0,
            "sub": 0,
            "ins": 0,
            "del": 0,
        }
        i = len(lab) - 1
        j = len(rec) - 1
        while True:
            if self.space[i][j]["error"] == "cor":  # correct
                # Stats are only counted for real tokens, never the sentinel.
                if len(lab[i]) > 0:
                    self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
                    self.data[lab[i]]["cor"] = self.data[lab[i]]["cor"] + 1
                    result["all"] = result["all"] + 1
                    result["cor"] = result["cor"] + 1
                result["lab"].insert(0, lab[i])
                result["rec"].insert(0, rec[j])
                result["code"].insert(0, Code.match)
                i = i - 1
                j = j - 1
            elif self.space[i][j]["error"] == "sub":  # substitution
                if len(lab[i]) > 0:
                    self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
                    self.data[lab[i]]["sub"] = self.data[lab[i]]["sub"] + 1
                    result["all"] = result["all"] + 1
                    result["sub"] = result["sub"] + 1
                result["lab"].insert(0, lab[i])
                result["rec"].insert(0, rec[j])
                result["code"].insert(0, Code.substitution)
                i = i - 1
                j = j - 1
            elif self.space[i][j]["error"] == "del":  # deletion
                if len(lab[i]) > 0:
                    self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
                    self.data[lab[i]]["del"] = self.data[lab[i]]["del"] + 1
                    result["all"] = result["all"] + 1
                    result["del"] = result["del"] + 1
                result["lab"].insert(0, lab[i])
                result["rec"].insert(0, "")
                result["code"].insert(0, Code.deletion)
                i = i - 1
            elif self.space[i][j]["error"] == "ins":  # insertion
                if len(rec[j]) > 0:
                    self.data[rec[j]]["ins"] = self.data[rec[j]]["ins"] + 1
                    result["ins"] = result["ins"] + 1
                result["lab"].insert(0, "")
                result["rec"].insert(0, rec[j])
                result["code"].insert(0, Code.insertion)
                j = j - 1
            elif self.space[i][j]["error"] == "non":  # starting point
                break
            else:  # shouldn't reach here
                print(
                    "this should not happen , i = {i} , j = {j} , error = {error}".format(
                        i=i, j=j, error=self.space[i][j]["error"]
                    )
                )
        return result

    def overall(self):
        """Sum the per-token statistics over every token seen so far."""
        result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
        for token in self.data:
            result["all"] = result["all"] + self.data[token]["all"]
            result["cor"] = result["cor"] + self.data[token]["cor"]
            result["sub"] = result["sub"] + self.data[token]["sub"]
            result["ins"] = result["ins"] + self.data[token]["ins"]
            result["del"] = result["del"] + self.data[token]["del"]
        return result

    def cluster(self, data):
        """Sum the statistics over the subset of tokens listed in *data*."""
        result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
        for token in data:
            if token in self.data:
                result["all"] = result["all"] + self.data[token]["all"]
                result["cor"] = result["cor"] + self.data[token]["cor"]
                result["sub"] = result["sub"] + self.data[token]["sub"]
                result["ins"] = result["ins"] + self.data[token]["ins"]
                result["del"] = result["del"] + self.data[token]["del"]
        return result

    def keys(self):
        """Return every token that has a statistics entry."""
        return list(self.data.keys())
  271. def width(string):
  272. return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
  273. def default_cluster(word):
  274. unicode_names = [unicodedata.name(char) for char in word]
  275. for i in reversed(range(len(unicode_names))):
  276. if unicode_names[i].startswith("DIGIT"): # 1
  277. unicode_names[i] = "Number" # 'DIGIT'
  278. elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[
  279. i
  280. ].startswith("CJK COMPATIBILITY IDEOGRAPH"):
  281. # 明 / 郎
  282. unicode_names[i] = "Mandarin" # 'CJK IDEOGRAPH'
  283. elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[
  284. i
  285. ].startswith("LATIN SMALL LETTER"):
  286. # A / a
  287. unicode_names[i] = "English" # 'LATIN LETTER'
  288. elif unicode_names[i].startswith("HIRAGANA LETTER"): # は こ め
  289. unicode_names[i] = "Japanese" # 'GANA LETTER'
  290. elif (
  291. unicode_names[i].startswith("AMPERSAND")
  292. or unicode_names[i].startswith("APOSTROPHE")
  293. or unicode_names[i].startswith("COMMERCIAL AT")
  294. or unicode_names[i].startswith("DEGREE CELSIUS")
  295. or unicode_names[i].startswith("EQUALS SIGN")
  296. or unicode_names[i].startswith("FULL STOP")
  297. or unicode_names[i].startswith("HYPHEN-MINUS")
  298. or unicode_names[i].startswith("LOW LINE")
  299. or unicode_names[i].startswith("NUMBER SIGN")
  300. or unicode_names[i].startswith("PLUS SIGN")
  301. or unicode_names[i].startswith("SEMICOLON")
  302. ):
  303. # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
  304. del unicode_names[i]
  305. else:
  306. return "Other"
  307. if len(unicode_names) == 0:
  308. return "Other"
  309. if len(unicode_names) == 1:
  310. return unicode_names[0]
  311. for i in range(len(unicode_names) - 1):
  312. if unicode_names[i] != unicode_names[i + 1]:
  313. return "Other"
  314. return unicode_names[0]
  315. def get_args():
  316. parser = argparse.ArgumentParser(description="wer cal")
  317. parser.add_argument("--ref", type=str, help="Text input path")
  318. parser.add_argument("--ref_ocr", type=str, help="Text input path")
  319. parser.add_argument("--rec_name", type=str, action="append", default=[])
  320. parser.add_argument("--rec_file", type=str, action="append", default=[])
  321. parser.add_argument("--verbose", type=int, default=1, help="show")
  322. parser.add_argument("--char", type=bool, default=True, help="show")
  323. args = parser.parse_args()
  324. return args
def main(args):
    """Score one or more recognition outputs against a reference transcript.

    For every utterance the per-utterance OCR text (``args.ref_ocr``) serves
    as the hotword list: the script reports overall WER, U-WER/B-WER
    (unbiased/biased word error rates), hotword tp/tn/fp/fn with recall,
    and a column-aligned reference/hypothesis dump.  Everything is printed
    to stdout.

    ``args`` must provide: ref, ref_ocr, rec_file (list), rec_name (list,
    paired with rec_file by position), verbose (int), char (bool).
    """
    cluster_file = ""  # unused in this function; kept as-is
    ignore_words = set()
    tochar = args.char  # True: split lines into characters, else whitespace words
    verbose = args.verbose
    padding_symbol = " "  # filler used to column-align the lab/rec dump
    case_sensitive = False
    max_words_per_line = sys.maxsize
    split = None
    if not case_sensitive:
        # Case-insensitive scoring: compare everything upper-cased.
        ig = set([w.upper() for w in ignore_words])
        ignore_words = ig
    default_clusters = {}
    default_words = {}
    ref_file = args.ref
    ref_ocr = args.ref_ocr
    rec_files = args.rec_file
    rec_names = args.rec_name
    assert len(rec_files) == len(rec_names)
    # load ocr: fid -> normalized OCR token list ("$" is treated as a separator)
    ref_ocr_dict = {}
    with codecs.open(ref_ocr, "r", "utf-8") as fh:
        for line in fh:
            if "$" in line:
                line = line.replace("$", " ")
            if tochar:
                array = characterize(line)
            else:
                array = line.strip().split()
            if len(array) == 0:
                continue
            fid = array[0]
            ref_ocr_dict[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
    if split and not case_sensitive:
        # Upper-case the split table so lookups stay consistent.
        newsplit = dict()
        for w in split:
            words = split[w]
            for i in range(len(words)):
                words[i] = words[i].upper()
            newsplit[w.upper()] = words
        split = newsplit
    # Load every hypothesis file: rec_name -> {fid: normalized token list}.
    rec_sets = {}
    calculators_dict = dict()
    ub_wer_dict = dict()
    hotwords_related_dict = dict()  # bookkeeping for hotword recall metrics
    for i, hyp_file in enumerate(rec_files):
        rec_sets[rec_names[i]] = dict()
        with codecs.open(hyp_file, "r", "utf-8") as fh:
            for line in fh:
                if tochar:
                    array = characterize(line)
                else:
                    array = line.strip().split()
                if len(array) == 0:
                    continue
                fid = array[0]
                rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split)
        calculators_dict[rec_names[i]] = Calculator()
        ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()}
        hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
        # tp: hotword in the label and in the rec
        # tn: hotword not in the label and not in the rec
        # fp: hotword not in the label but in the rec
        # fn: hotword in the label but not in the rec
    # record wrong label but in ocr
    wrong_rec_but_in_ocr_dict = {}
    for rec_name in rec_names:
        wrong_rec_but_in_ocr_dict[rec_name] = 0
    # Line count only drives the tqdm progress bar total.
    _file_total_len = 0
    with os.popen("cat {} | wc -l".format(ref_file)) as pipe:
        _file_total_len = int(pipe.read().strip())
    # compute error rate on the interaction of reference file and hyp file
    for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len):
        if tochar:
            array = characterize(line)
        else:
            array = line.rstrip('\n').split()
        if len(array) == 0: continue
        fid = array[0]
        lab = normalize(array[1:], ignore_words, case_sensitive, split)
        if verbose:
            print('\nutt: %s' % fid)
        # NOTE(review): raises KeyError when an utterance id is missing from
        # the OCR file — assumed to cover every reference utterance; confirm.
        ocr_text = ref_ocr_dict[fid]
        ocr_set = set(ocr_text)
        print('ocr: {}'.format(" ".join(ocr_text)))
        list_match = []  # label tokens that also appear in the OCR text
        list_not_mathch = []
        tmp_error = 0
        tmp_match = 0
        for index in range(len(lab)):
            # text_list.append(uttlist[index+1])
            if lab[index] not in ocr_set:
                tmp_error += 1
                list_not_mathch.append(lab[index])
            else:
                tmp_match += 1
                list_match.append(lab[index])
        print('label in ocr: {}'.format(" ".join(list_match)))
        # for each reco file
        base_wrong_ocr_wer = None
        ocr_wrong_ocr_wer = None
        for rec_name in rec_names:
            rec_set = rec_sets[rec_name]
            if fid not in rec_set:
                continue
            rec = rec_set[fid]
            # print(rec)
            # Register each token's script cluster once.
            for word in rec + lab:
                if word not in default_words:
                    default_cluster_name = default_cluster(word)
                    if default_cluster_name not in default_clusters:
                        default_clusters[default_cluster_name] = {}
                    if word not in default_clusters[default_cluster_name]:
                        default_clusters[default_cluster_name][word] = 1
                    default_words[word] = default_cluster_name
            # calculate() mutates its arguments, hence the copies.
            result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy())
            if verbose:
                if result['all'] != 0:
                    wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
                else:
                    wer = 0.0
                print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ')
                print('N=%d C=%d S=%d D=%d I=%d' %
                      (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
            # print(result['rec'])
            # Reference tokens the system got wrong even though the OCR text
            # contained them (missed hotword opportunities).
            wrong_rec_but_in_ocr = []
            for idx in range(len(result['lab'])):
                if result['lab'][idx] != "":
                    if result['lab'][idx] != result['rec'][idx].replace("<BIAS>", ""):
                        if result['lab'][idx] in list_match:
                            wrong_rec_but_in_ocr.append(result['lab'][idx])
                            wrong_rec_but_in_ocr_dict[rec_name] += 1
            print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr)))
            if rec_name == "base":
                base_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
            # NOTE(review): assumes a system named "base" appears before any
            # "ocr"/"hot" system; otherwise base_wrong_ocr_wer is still None
            # and the comparisons below fail — confirm against usage.
            if "ocr" in rec_name or "hot" in rec_name:
                ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
                if ocr_wrong_ocr_wer < base_wrong_ocr_wer:
                    print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
                elif ocr_wrong_ocr_wer > base_wrong_ocr_wer:
                    print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
            # recall = 0
            # false_alarm = 0
            # for idx in range(len(result['lab'])):
            #     if "<BIAS>" in result['rec'][idx]:
            #         if result['rec'][idx].replace("<BIAS>", "") in list_match:
            #             recall += 1
            #         else:
            #             false_alarm += 1
            # print("bias hotwords recall: {}, fa: {}, list_match {}, recall: {:.2f}, fa: {:.2f}".format(
            #     recall, false_alarm, len(list_match), recall / len(list_match) if len(list_match) != 0 else 0, false_alarm / len(list_match) if len(list_match) != 0 else 0
            # ))
            # tp: hotword in the label and in the rec
            # tn: hotword not in the label and not in the rec
            # fp: hotword not in the label but in the rec
            # fn: hotword in the label but not in the rec
            _rec_list = [word.replace("<BIAS>", "") for word in rec]
            _label_list = [word for word in lab]
            _tp = _tn = _fp = _fn = 0
            # Hotwords supported by the label vs. spurious OCR hotwords.
            hot_true_list = [hotword for hotword in ocr_text if hotword in _label_list]
            hot_bad_list = [hotword for hotword in ocr_text if hotword not in _label_list]
            for badhotword in hot_bad_list:
                count = len([word for word in _rec_list if word == badhotword])
                # print(f"bad {badhotword} count: {count}")
                # for word in _rec_list:
                #     if badhotword == word:
                #         count += 1
                if count == 0:
                    hotwords_related_dict[rec_name]['tn'] += 1
                    _tn += 1
                    # fp: 0
                else:
                    hotwords_related_dict[rec_name]['fp'] += count
                    _fp += count
                    # tn: 0
            # if badhotword in _rec_list:
            #     hotwords_related_dict[rec_name]['fp'] += 1
            # else:
            #     hotwords_related_dict[rec_name]['tn'] += 1
            for hotword in hot_true_list:
                true_count = len([word for word in _label_list if hotword == word])
                rec_count = len([word for word in _rec_list if hotword == word])
                # print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}")
                if rec_count == true_count:
                    hotwords_related_dict[rec_name]['tp'] += true_count
                    _tp += true_count
                elif rec_count > true_count:
                    hotwords_related_dict[rec_name]['tp'] += true_count
                    # fp: not in the label, but present in the rec
                    hotwords_related_dict[rec_name]['fp'] += rec_count - true_count
                    _tp += true_count
                    _fp += rec_count - true_count
                else:
                    hotwords_related_dict[rec_name]['tp'] += rec_count
                    # fn: hotword in the label but missing from the rec
                    hotwords_related_dict[rec_name]['fn'] += true_count - rec_count
                    _tp += rec_count
                    _fn += true_count - rec_count
            print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
                _tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0
            ))
            # if hotword in _rec_list:
            #     hotwords_related_dict[rec_name]['tp'] += 1
            # else:
            #     hotwords_related_dict[rec_name]['fn'] += 1
            # Compute U-WER / B-WER / WER from the alignment codes: biased
            # (B) tokens are those whose reference word is a true hotword.
            for code, rec_word, lab_word in zip(result["code"], result["rec"], result["lab"]):
                if code == Code.match:
                    ub_wer_dict[rec_name]["wer"].ref_words += 1
                    if lab_word in hot_true_list:
                        # tmp_ref.append(ref_tokens[ref_idx])
                        ub_wer_dict[rec_name]["b_wer"].ref_words += 1
                    else:
                        ub_wer_dict[rec_name]["u_wer"].ref_words += 1
                elif code == Code.substitution:
                    ub_wer_dict[rec_name]["wer"].ref_words += 1
                    ub_wer_dict[rec_name]["wer"].errors[Code.substitution] += 1
                    if lab_word in hot_true_list:
                        # tmp_ref.append(ref_tokens[ref_idx])
                        ub_wer_dict[rec_name]["b_wer"].ref_words += 1
                        ub_wer_dict[rec_name]["b_wer"].errors[Code.substitution] += 1
                    else:
                        ub_wer_dict[rec_name]["u_wer"].ref_words += 1
                        ub_wer_dict[rec_name]["u_wer"].errors[Code.substitution] += 1
                elif code == Code.deletion:
                    ub_wer_dict[rec_name]["wer"].ref_words += 1
                    ub_wer_dict[rec_name]["wer"].errors[Code.deletion] += 1
                    if lab_word in hot_true_list:
                        # tmp_ref.append(ref_tokens[ref_idx])
                        ub_wer_dict[rec_name]["b_wer"].ref_words += 1
                        ub_wer_dict[rec_name]["b_wer"].errors[Code.deletion] += 1
                    else:
                        ub_wer_dict[rec_name]["u_wer"].ref_words += 1
                        ub_wer_dict[rec_name]["u_wer"].errors[Code.deletion] += 1
                elif code == Code.insertion:
                    # Insertions have no reference word, so ref_words is untouched.
                    ub_wer_dict[rec_name]["wer"].errors[Code.insertion] += 1
                    if rec_word in hot_true_list:
                        ub_wer_dict[rec_name]["b_wer"].errors[Code.insertion] += 1
                    else:
                        ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1
            # Column-align the reference/hypothesis dump (wide chars count 2).
            space = {}
            space['lab'] = []
            space['rec'] = []
            for idx in range(len(result['lab'])):
                len_lab = width(result['lab'][idx])
                len_rec = width(result['rec'][idx])
                length = max(len_lab, len_rec)
                space['lab'].append(length - len_lab)
                space['rec'].append(length - len_rec)
            upper_lab = len(result['lab'])
            upper_rec = len(result['rec'])
            lab1, rec1 = 0, 0
            while lab1 < upper_lab or rec1 < upper_rec:
                if verbose > 1:
                    print('lab(%s):' % fid.encode('utf-8'), end=' ')
                else:
                    print('lab:', end=' ')
                lab2 = min(upper_lab, lab1 + max_words_per_line)
                for idx in range(lab1, lab2):
                    token = result['lab'][idx]
                    print('{token}'.format(token=token), end='')
                    for n in range(space['lab'][idx]):
                        print(padding_symbol, end='')
                    print(' ', end='')
                print()
                if verbose > 1:
                    print('rec(%s):' % fid.encode('utf-8'), end=' ')
                else:
                    print('rec:', end=' ')
                rec2 = min(upper_rec, rec1 + max_words_per_line)
                for idx in range(rec1, rec2):
                    token = result['rec'][idx]
                    print('{token}'.format(token=token), end='')
                    for n in range(space['rec'][idx]):
                        print(padding_symbol, end='')
                    print(' ', end='')
                print()
                # print('\n', end='\n')
                lab1 = lab2
                rec1 = rec2
            print('\n', end='\n')
            # break
        if verbose:
            print('===========================================================================')
            print()
    # Final summary across all utterances, one block per system.
    print(wrong_rec_but_in_ocr_dict)
    for rec_name in rec_names:
        result = calculators_dict[rec_name].overall()
        if result['all'] != 0:
            wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
        else:
            wer = 0.0
        print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ')
        print('N=%d C=%d S=%d D=%d I=%d' %
              (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
        print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}")
        print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}")
        print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}")
        print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format(
            hotwords_related_dict[rec_name]['tp'],
            hotwords_related_dict[rec_name]['tn'],
            hotwords_related_dict[rec_name]['fp'],
            hotwords_related_dict[rec_name]['fn'],
            sum([v for k, v in hotwords_related_dict[rec_name].items()]),
            hotwords_related_dict[rec_name]['tp'] / (
                hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn']
            ) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0
        ))
        # tp: hotword in the label and in the rec
        # tn: hotword not in the label and not in the rec
        # fp: hotword not in the label but in the rec
        # fn: hotword in the label but not in the rec
        if not verbose:
            print()
        print()
  640. if __name__ == "__main__":
  641. args = get_args()
  642. # print("")
  643. print(args)
  644. main(args)