| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- from enum import Enum
- import re, sys, unicodedata
- import codecs
- import argparse
- from tqdm import tqdm
- import os
- import pdb
- remove_tag = False
- spacelist = [" ", "\t", "\r", "\n"]
- puncts = [
- "!",
- ",",
- "?",
- "、",
- "。",
- "!",
- ",",
- ";",
- "?",
- ":",
- "「",
- "」",
- "︰",
- "『",
- "』",
- "《",
- "》",
- ]
- class Code(Enum):
- match = 1
- substitution = 2
- insertion = 3
- deletion = 4
- class WordError(object):
- def __init__(self):
- self.errors = {
- Code.substitution: 0,
- Code.insertion: 0,
- Code.deletion: 0,
- }
- self.ref_words = 0
- def get_wer(self):
- assert self.ref_words != 0
- errors = (
- self.errors[Code.substitution]
- + self.errors[Code.insertion]
- + self.errors[Code.deletion]
- )
- return 100.0 * errors / self.ref_words
- def get_result_string(self):
- return (
- f"error_rate={self.get_wer():.4f}, "
- f"ref_words={self.ref_words}, "
- f"subs={self.errors[Code.substitution]}, "
- f"ins={self.errors[Code.insertion]}, "
- f"dels={self.errors[Code.deletion]}"
- )
- def characterize(string):
- res = []
- i = 0
- while i < len(string):
- char = string[i]
- if char in puncts:
- i += 1
- continue
- cat1 = unicodedata.category(char)
- # https://unicodebook.readthedocs.io/unicode.html#unicode-categories
- if cat1 == "Zs" or cat1 == "Cn" or char in spacelist: # space or not assigned
- i += 1
- continue
- if cat1 == "Lo": # letter-other
- res.append(char)
- i += 1
- else:
- # some input looks like: <unk><noise>, we want to separate it to two words.
- sep = " "
- if char == "<":
- sep = ">"
- j = i + 1
- while j < len(string):
- c = string[j]
- if ord(c) >= 128 or (c in spacelist) or (c == sep):
- break
- j += 1
- if j < len(string) and string[j] == ">":
- j += 1
- res.append(string[i:j])
- i = j
- return res
- def stripoff_tags(x):
- if not x:
- return ""
- chars = []
- i = 0
- T = len(x)
- while i < T:
- if x[i] == "<":
- while i < T and x[i] != ">":
- i += 1
- i += 1
- else:
- chars.append(x[i])
- i += 1
- return "".join(chars)
- def normalize(sentence, ignore_words, cs, split=None):
- """sentence, ignore_words are both in unicode"""
- new_sentence = []
- for token in sentence:
- x = token
- if not cs:
- x = x.upper()
- if x in ignore_words:
- continue
- if remove_tag:
- x = stripoff_tags(x)
- if not x:
- continue
- if split and x in split:
- new_sentence += split[x]
- else:
- new_sentence.append(x)
- return new_sentence
- class Calculator:
- def __init__(self):
- self.data = {}
- self.space = []
- self.cost = {}
- self.cost["cor"] = 0
- self.cost["sub"] = 1
- self.cost["del"] = 1
- self.cost["ins"] = 1
- def calculate(self, lab, rec):
- # Initialization
- lab.insert(0, "")
- rec.insert(0, "")
- while len(self.space) < len(lab):
- self.space.append([])
- for row in self.space:
- for element in row:
- element["dist"] = 0
- element["error"] = "non"
- while len(row) < len(rec):
- row.append({"dist": 0, "error": "non"})
- for i in range(len(lab)):
- self.space[i][0]["dist"] = i
- self.space[i][0]["error"] = "del"
- for j in range(len(rec)):
- self.space[0][j]["dist"] = j
- self.space[0][j]["error"] = "ins"
- self.space[0][0]["error"] = "non"
- for token in lab:
- if token not in self.data and len(token) > 0:
- self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
- for token in rec:
- if token not in self.data and len(token) > 0:
- self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
- # Computing edit distance
- for i, lab_token in enumerate(lab):
- for j, rec_token in enumerate(rec):
- if i == 0 or j == 0:
- continue
- min_dist = sys.maxsize
- min_error = "none"
- dist = self.space[i - 1][j]["dist"] + self.cost["del"]
- error = "del"
- if dist < min_dist:
- min_dist = dist
- min_error = error
- dist = self.space[i][j - 1]["dist"] + self.cost["ins"]
- error = "ins"
- if dist < min_dist:
- min_dist = dist
- min_error = error
- if lab_token == rec_token.replace("<BIAS>", ""):
- dist = self.space[i - 1][j - 1]["dist"] + self.cost["cor"]
- error = "cor"
- else:
- dist = self.space[i - 1][j - 1]["dist"] + self.cost["sub"]
- error = "sub"
- if dist < min_dist:
- min_dist = dist
- min_error = error
- self.space[i][j]["dist"] = min_dist
- self.space[i][j]["error"] = min_error
- # Tracing back
- result = {
- "lab": [],
- "rec": [],
- "code": [],
- "all": 0,
- "cor": 0,
- "sub": 0,
- "ins": 0,
- "del": 0,
- }
- i = len(lab) - 1
- j = len(rec) - 1
- while True:
- if self.space[i][j]["error"] == "cor": # correct
- if len(lab[i]) > 0:
- self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
- self.data[lab[i]]["cor"] = self.data[lab[i]]["cor"] + 1
- result["all"] = result["all"] + 1
- result["cor"] = result["cor"] + 1
- result["lab"].insert(0, lab[i])
- result["rec"].insert(0, rec[j])
- result["code"].insert(0, Code.match)
- i = i - 1
- j = j - 1
- elif self.space[i][j]["error"] == "sub": # substitution
- if len(lab[i]) > 0:
- self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
- self.data[lab[i]]["sub"] = self.data[lab[i]]["sub"] + 1
- result["all"] = result["all"] + 1
- result["sub"] = result["sub"] + 1
- result["lab"].insert(0, lab[i])
- result["rec"].insert(0, rec[j])
- result["code"].insert(0, Code.substitution)
- i = i - 1
- j = j - 1
- elif self.space[i][j]["error"] == "del": # deletion
- if len(lab[i]) > 0:
- self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
- self.data[lab[i]]["del"] = self.data[lab[i]]["del"] + 1
- result["all"] = result["all"] + 1
- result["del"] = result["del"] + 1
- result["lab"].insert(0, lab[i])
- result["rec"].insert(0, "")
- result["code"].insert(0, Code.deletion)
- i = i - 1
- elif self.space[i][j]["error"] == "ins": # insertion
- if len(rec[j]) > 0:
- self.data[rec[j]]["ins"] = self.data[rec[j]]["ins"] + 1
- result["ins"] = result["ins"] + 1
- result["lab"].insert(0, "")
- result["rec"].insert(0, rec[j])
- result["code"].insert(0, Code.insertion)
- j = j - 1
- elif self.space[i][j]["error"] == "non": # starting point
- break
- else: # shouldn't reach here
- print(
- "this should not happen , i = {i} , j = {j} , error = {error}".format(
- i=i, j=j, error=self.space[i][j]["error"]
- )
- )
- return result
- def overall(self):
- result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
- for token in self.data:
- result["all"] = result["all"] + self.data[token]["all"]
- result["cor"] = result["cor"] + self.data[token]["cor"]
- result["sub"] = result["sub"] + self.data[token]["sub"]
- result["ins"] = result["ins"] + self.data[token]["ins"]
- result["del"] = result["del"] + self.data[token]["del"]
- return result
- def cluster(self, data):
- result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
- for token in data:
- if token in self.data:
- result["all"] = result["all"] + self.data[token]["all"]
- result["cor"] = result["cor"] + self.data[token]["cor"]
- result["sub"] = result["sub"] + self.data[token]["sub"]
- result["ins"] = result["ins"] + self.data[token]["ins"]
- result["del"] = result["del"] + self.data[token]["del"]
- return result
- def keys(self):
- return list(self.data.keys())
- def width(string):
- return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
- def default_cluster(word):
- unicode_names = [unicodedata.name(char) for char in word]
- for i in reversed(range(len(unicode_names))):
- if unicode_names[i].startswith("DIGIT"): # 1
- unicode_names[i] = "Number" # 'DIGIT'
- elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[
- i
- ].startswith("CJK COMPATIBILITY IDEOGRAPH"):
- # 明 / 郎
- unicode_names[i] = "Mandarin" # 'CJK IDEOGRAPH'
- elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[
- i
- ].startswith("LATIN SMALL LETTER"):
- # A / a
- unicode_names[i] = "English" # 'LATIN LETTER'
- elif unicode_names[i].startswith("HIRAGANA LETTER"): # は こ め
- unicode_names[i] = "Japanese" # 'GANA LETTER'
- elif (
- unicode_names[i].startswith("AMPERSAND")
- or unicode_names[i].startswith("APOSTROPHE")
- or unicode_names[i].startswith("COMMERCIAL AT")
- or unicode_names[i].startswith("DEGREE CELSIUS")
- or unicode_names[i].startswith("EQUALS SIGN")
- or unicode_names[i].startswith("FULL STOP")
- or unicode_names[i].startswith("HYPHEN-MINUS")
- or unicode_names[i].startswith("LOW LINE")
- or unicode_names[i].startswith("NUMBER SIGN")
- or unicode_names[i].startswith("PLUS SIGN")
- or unicode_names[i].startswith("SEMICOLON")
- ):
- # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
- del unicode_names[i]
- else:
- return "Other"
- if len(unicode_names) == 0:
- return "Other"
- if len(unicode_names) == 1:
- return unicode_names[0]
- for i in range(len(unicode_names) - 1):
- if unicode_names[i] != unicode_names[i + 1]:
- return "Other"
- return unicode_names[0]
- def get_args():
- parser = argparse.ArgumentParser(description="wer cal")
- parser.add_argument("--ref", type=str, help="Text input path")
- parser.add_argument("--ref_ocr", type=str, help="Text input path")
- parser.add_argument("--rec_name", type=str, action="append", default=[])
- parser.add_argument("--rec_file", type=str, action="append", default=[])
- parser.add_argument("--verbose", type=int, default=1, help="show")
- parser.add_argument("--char", type=bool, default=True, help="show")
- args = parser.parse_args()
- return args
- def main(args):
- cluster_file = ""
- ignore_words = set()
- tochar = args.char
- verbose = args.verbose
- padding_symbol = " "
- case_sensitive = False
- max_words_per_line = sys.maxsize
- split = None
- if not case_sensitive:
- ig = set([w.upper() for w in ignore_words])
- ignore_words = ig
- default_clusters = {}
- default_words = {}
- ref_file = args.ref
- ref_ocr = args.ref_ocr
- rec_files = args.rec_file
- rec_names = args.rec_name
- assert len(rec_files) == len(rec_names)
- # load ocr
- ref_ocr_dict = {}
- with codecs.open(ref_ocr, "r", "utf-8") as fh:
- for line in fh:
- if "$" in line:
- line = line.replace("$", " ")
- if tochar:
- array = characterize(line)
- else:
- array = line.strip().split()
- if len(array) == 0:
- continue
- fid = array[0]
- ref_ocr_dict[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
- if split and not case_sensitive:
- newsplit = dict()
- for w in split:
- words = split[w]
- for i in range(len(words)):
- words[i] = words[i].upper()
- newsplit[w.upper()] = words
- split = newsplit
- rec_sets = {}
- calculators_dict = dict()
- ub_wer_dict = dict()
- hotwords_related_dict = dict() # 记录recall相关的内容
- for i, hyp_file in enumerate(rec_files):
- rec_sets[rec_names[i]] = dict()
- with codecs.open(hyp_file, "r", "utf-8") as fh:
- for line in fh:
- if tochar:
- array = characterize(line)
- else:
- array = line.strip().split()
- if len(array) == 0:
- continue
- fid = array[0]
- rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split)
- calculators_dict[rec_names[i]] = Calculator()
- ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()}
- hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
- # tp: 热词在label里,同时在rec里
- # tn: 热词不在label里,同时不在rec里
- # fp: 热词不在label里,但是在rec里
- # fn: 热词在label里,但是不在rec里
- # record wrong label but in ocr
- wrong_rec_but_in_ocr_dict = {}
- for rec_name in rec_names:
- wrong_rec_but_in_ocr_dict[rec_name] = 0
- _file_total_len = 0
- with os.popen("cat {} | wc -l".format(ref_file)) as pipe:
- _file_total_len = int(pipe.read().strip())
- # compute error rate on the interaction of reference file and hyp file
- for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len):
- if tochar:
- array = characterize(line)
- else:
- array = line.rstrip('\n').split()
- if len(array) == 0: continue
- fid = array[0]
- lab = normalize(array[1:], ignore_words, case_sensitive, split)
- if verbose:
- print('\nutt: %s' % fid)
- ocr_text = ref_ocr_dict[fid]
- ocr_set = set(ocr_text)
- print('ocr: {}'.format(" ".join(ocr_text)))
- list_match = [] # 指label里面在ocr里面的内容
- list_not_mathch = []
- tmp_error = 0
- tmp_match = 0
- for index in range(len(lab)):
- # text_list.append(uttlist[index+1])
- if lab[index] not in ocr_set:
- tmp_error += 1
- list_not_mathch.append(lab[index])
- else:
- tmp_match += 1
- list_match.append(lab[index])
- print('label in ocr: {}'.format(" ".join(list_match)))
- # for each reco file
- base_wrong_ocr_wer = None
- ocr_wrong_ocr_wer = None
- for rec_name in rec_names:
- rec_set = rec_sets[rec_name]
- if fid not in rec_set:
- continue
- rec = rec_set[fid]
- # print(rec)
- for word in rec + lab:
- if word not in default_words:
- default_cluster_name = default_cluster(word)
- if default_cluster_name not in default_clusters:
- default_clusters[default_cluster_name] = {}
- if word not in default_clusters[default_cluster_name]:
- default_clusters[default_cluster_name][word] = 1
- default_words[word] = default_cluster_name
- result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy())
- if verbose:
- if result['all'] != 0:
- wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
- else:
- wer = 0.0
- print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ')
- print('N=%d C=%d S=%d D=%d I=%d' %
- (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
- # print(result['rec'])
- wrong_rec_but_in_ocr = []
- for idx in range(len(result['lab'])):
- if result['lab'][idx] != "":
- if result['lab'][idx] != result['rec'][idx].replace("<BIAS>", ""):
- if result['lab'][idx] in list_match:
- wrong_rec_but_in_ocr.append(result['lab'][idx])
- wrong_rec_but_in_ocr_dict[rec_name] += 1
- print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr)))
- if rec_name == "base":
- base_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
- if "ocr" in rec_name or "hot" in rec_name:
- ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
- if ocr_wrong_ocr_wer < base_wrong_ocr_wer:
- print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
- elif ocr_wrong_ocr_wer > base_wrong_ocr_wer:
- print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
- # recall = 0
- # false_alarm = 0
- # for idx in range(len(result['lab'])):
- # if "<BIAS>" in result['rec'][idx]:
- # if result['rec'][idx].replace("<BIAS>", "") in list_match:
- # recall += 1
- # else:
- # false_alarm += 1
- # print("bias hotwords recall: {}, fa: {}, list_match {}, recall: {:.2f}, fa: {:.2f}".format(
- # recall, false_alarm, len(list_match), recall / len(list_match) if len(list_match) != 0 else 0, false_alarm / len(list_match) if len(list_match) != 0 else 0
- # ))
- # tp: 热词在label里,同时在rec里
- # tn: 热词不在label里,同时不在rec里
- # fp: 热词不在label里,但是在rec里
- # fn: 热词在label里,但是不在rec里
- _rec_list = [word.replace("<BIAS>", "") for word in rec]
- _label_list = [word for word in lab]
- _tp = _tn = _fp = _fn = 0
- hot_true_list = [hotword for hotword in ocr_text if hotword in _label_list]
- hot_bad_list = [hotword for hotword in ocr_text if hotword not in _label_list]
- for badhotword in hot_bad_list:
- count = len([word for word in _rec_list if word == badhotword])
- # print(f"bad {badhotword} count: {count}")
- # for word in _rec_list:
- # if badhotword == word:
- # count += 1
- if count == 0:
- hotwords_related_dict[rec_name]['tn'] += 1
- _tn += 1
- # fp: 0
- else:
- hotwords_related_dict[rec_name]['fp'] += count
- _fp += count
- # tn: 0
- # if badhotword in _rec_list:
- # hotwords_related_dict[rec_name]['fp'] += 1
- # else:
- # hotwords_related_dict[rec_name]['tn'] += 1
- for hotword in hot_true_list:
- true_count = len([word for word in _label_list if hotword == word])
- rec_count = len([word for word in _rec_list if hotword == word])
- # print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}")
- if rec_count == true_count:
- hotwords_related_dict[rec_name]['tp'] += true_count
- _tp += true_count
- elif rec_count > true_count:
- hotwords_related_dict[rec_name]['tp'] += true_count
- # fp: 不在label里,但是在rec里
- hotwords_related_dict[rec_name]['fp'] += rec_count - true_count
- _tp += true_count
- _fp += rec_count - true_count
- else:
- hotwords_related_dict[rec_name]['tp'] += rec_count
- # fn: 热词在label里,但是不在rec里
- hotwords_related_dict[rec_name]['fn'] += true_count - rec_count
- _tp += rec_count
- _fn += true_count - rec_count
- print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
- _tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0
- ))
- # if hotword in _rec_list:
- # hotwords_related_dict[rec_name]['tp'] += 1
- # else:
- # hotwords_related_dict[rec_name]['fn'] += 1
- # 计算uwer, bwer, wer
- for code, rec_word, lab_word in zip(result["code"], result["rec"], result["lab"]):
- if code == Code.match:
- ub_wer_dict[rec_name]["wer"].ref_words += 1
- if lab_word in hot_true_list:
- # tmp_ref.append(ref_tokens[ref_idx])
- ub_wer_dict[rec_name]["b_wer"].ref_words += 1
- else:
- ub_wer_dict[rec_name]["u_wer"].ref_words += 1
- elif code == Code.substitution:
- ub_wer_dict[rec_name]["wer"].ref_words += 1
- ub_wer_dict[rec_name]["wer"].errors[Code.substitution] += 1
- if lab_word in hot_true_list:
- # tmp_ref.append(ref_tokens[ref_idx])
- ub_wer_dict[rec_name]["b_wer"].ref_words += 1
- ub_wer_dict[rec_name]["b_wer"].errors[Code.substitution] += 1
- else:
- ub_wer_dict[rec_name]["u_wer"].ref_words += 1
- ub_wer_dict[rec_name]["u_wer"].errors[Code.substitution] += 1
- elif code == Code.deletion:
- ub_wer_dict[rec_name]["wer"].ref_words += 1
- ub_wer_dict[rec_name]["wer"].errors[Code.deletion] += 1
- if lab_word in hot_true_list:
- # tmp_ref.append(ref_tokens[ref_idx])
- ub_wer_dict[rec_name]["b_wer"].ref_words += 1
- ub_wer_dict[rec_name]["b_wer"].errors[Code.deletion] += 1
- else:
- ub_wer_dict[rec_name]["u_wer"].ref_words += 1
- ub_wer_dict[rec_name]["u_wer"].errors[Code.deletion] += 1
- elif code == Code.insertion:
- ub_wer_dict[rec_name]["wer"].errors[Code.insertion] += 1
- if rec_word in hot_true_list:
- ub_wer_dict[rec_name]["b_wer"].errors[Code.insertion] += 1
- else:
- ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1
- space = {}
- space['lab'] = []
- space['rec'] = []
- for idx in range(len(result['lab'])):
- len_lab = width(result['lab'][idx])
- len_rec = width(result['rec'][idx])
- length = max(len_lab, len_rec)
- space['lab'].append(length - len_lab)
- space['rec'].append(length - len_rec)
- upper_lab = len(result['lab'])
- upper_rec = len(result['rec'])
- lab1, rec1 = 0, 0
- while lab1 < upper_lab or rec1 < upper_rec:
- if verbose > 1:
- print('lab(%s):' % fid.encode('utf-8'), end=' ')
- else:
- print('lab:', end=' ')
- lab2 = min(upper_lab, lab1 + max_words_per_line)
- for idx in range(lab1, lab2):
- token = result['lab'][idx]
- print('{token}'.format(token=token), end='')
- for n in range(space['lab'][idx]):
- print(padding_symbol, end='')
- print(' ', end='')
- print()
- if verbose > 1:
- print('rec(%s):' % fid.encode('utf-8'), end=' ')
- else:
- print('rec:', end=' ')
- rec2 = min(upper_rec, rec1 + max_words_per_line)
- for idx in range(rec1, rec2):
- token = result['rec'][idx]
- print('{token}'.format(token=token), end='')
- for n in range(space['rec'][idx]):
- print(padding_symbol, end='')
- print(' ', end='')
- print()
- # print('\n', end='\n')
- lab1 = lab2
- rec1 = rec2
- print('\n', end='\n')
- # break
- if verbose:
- print('===========================================================================')
- print()
- print(wrong_rec_but_in_ocr_dict)
- for rec_name in rec_names:
- result = calculators_dict[rec_name].overall()
- if result['all'] != 0:
- wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
- else:
- wer = 0.0
- print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ')
- print('N=%d C=%d S=%d D=%d I=%d' %
- (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
- print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}")
- print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}")
- print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}")
- print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format(
- hotwords_related_dict[rec_name]['tp'],
- hotwords_related_dict[rec_name]['tn'],
- hotwords_related_dict[rec_name]['fp'],
- hotwords_related_dict[rec_name]['fn'],
- sum([v for k, v in hotwords_related_dict[rec_name].items()]),
- hotwords_related_dict[rec_name]['tp'] / (
- hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn']
- ) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0
- ))
- # tp: 热词在label里,同时在rec里
- # tn: 热词不在label里,同时不在rec里
- # fp: 热词不在label里,但是在rec里
- # fn: 热词在label里,但是不在rec里
- if not verbose:
- print()
- print()
- if __name__ == "__main__":
- args = get_args()
-
- # print("")
- print(args)
- main(args)
|