postprocess_utils.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import string
  3. import logging
  4. from typing import Any, List, Union
  5. def isChinese(ch: str):
  6. if '\u4e00' <= ch <= '\u9fff' or '\u0030' <= ch <= '\u0039' or ch == '@':
  7. return True
  8. return False
  9. def isAllChinese(word: Union[List[Any], str]):
  10. word_lists = []
  11. for i in word:
  12. cur = i.replace(' ', '')
  13. cur = cur.replace('</s>', '')
  14. cur = cur.replace('<s>', '')
  15. cur = cur.replace('<unk>', '')
  16. cur = cur.replace('<OOV>', '')
  17. word_lists.append(cur)
  18. if len(word_lists) == 0:
  19. return False
  20. for ch in word_lists:
  21. if isChinese(ch) is False:
  22. return False
  23. return True
  24. def isAllAlpha(word: Union[List[Any], str]):
  25. word_lists = []
  26. for i in word:
  27. cur = i.replace(' ', '')
  28. cur = cur.replace('</s>', '')
  29. cur = cur.replace('<s>', '')
  30. cur = cur.replace('<unk>', '')
  31. cur = cur.replace('<OOV>', '')
  32. word_lists.append(cur)
  33. if len(word_lists) == 0:
  34. return False
  35. for ch in word_lists:
  36. if ch.isalpha() is False and ch != "'":
  37. return False
  38. elif ch.isalpha() is True and isChinese(ch) is True:
  39. return False
  40. return True
  41. # def abbr_dispose(words: List[Any]) -> List[Any]:
  42. def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
  43. words_size = len(words)
  44. word_lists = []
  45. abbr_begin = []
  46. abbr_end = []
  47. last_num = -1
  48. ts_lists = []
  49. ts_nums = []
  50. ts_index = 0
  51. for num in range(words_size):
  52. if num <= last_num:
  53. continue
  54. if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
  55. if num + 1 < words_size and words[
  56. num + 1] == ' ' and num + 2 < words_size and len(
  57. words[num +
  58. 2]) == 1 and words[num +
  59. 2].encode('utf-8').isalpha():
  60. # found the begin of abbr
  61. abbr_begin.append(num)
  62. num += 2
  63. abbr_end.append(num)
  64. # to find the end of abbr
  65. while True:
  66. num += 1
  67. if num < words_size and words[num] == ' ':
  68. num += 1
  69. if num < words_size and len(
  70. words[num]) == 1 and words[num].encode(
  71. 'utf-8').isalpha():
  72. abbr_end.pop()
  73. abbr_end.append(num)
  74. last_num = num
  75. else:
  76. break
  77. else:
  78. break
  79. for num in range(words_size):
  80. if words[num] == ' ':
  81. ts_nums.append(ts_index)
  82. else:
  83. ts_nums.append(ts_index)
  84. ts_index += 1
  85. last_num = -1
  86. for num in range(words_size):
  87. if num <= last_num:
  88. continue
  89. if num in abbr_begin:
  90. if time_stamp is not None:
  91. begin = time_stamp[ts_nums[num]][0]
  92. abbr_word = words[num].upper()
  93. num += 1
  94. while num < words_size:
  95. if num in abbr_end:
  96. abbr_word += words[num].upper()
  97. last_num = num
  98. break
  99. else:
  100. if words[num].encode('utf-8').isalpha():
  101. abbr_word += words[num].upper()
  102. num += 1
  103. word_lists.append(abbr_word)
  104. if time_stamp is not None:
  105. end = time_stamp[ts_nums[num]][1]
  106. ts_lists.append([begin, end])
  107. else:
  108. word_lists.append(words[num])
  109. if time_stamp is not None and words[num] != ' ':
  110. begin = time_stamp[ts_nums[num]][0]
  111. end = time_stamp[ts_nums[num]][1]
  112. ts_lists.append([begin, end])
  113. begin = end
  114. if time_stamp is not None:
  115. return word_lists, ts_lists
  116. else:
  117. return word_lists
  118. def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
  119. middle_lists = []
  120. word_lists = []
  121. word_item = ''
  122. ts_lists = []
  123. # wash words lists
  124. for i in words:
  125. word = ''
  126. if isinstance(i, str):
  127. word = i
  128. else:
  129. word = i.decode('utf-8')
  130. if word in ['<s>', '</s>', '<unk>', '<OOV>']:
  131. continue
  132. else:
  133. middle_lists.append(word)
  134. # all chinese characters
  135. if isAllChinese(middle_lists):
  136. for i, ch in enumerate(middle_lists):
  137. word_lists.append(ch.replace(' ', ''))
  138. if time_stamp is not None:
  139. ts_lists = time_stamp
  140. # all alpha characters
  141. elif isAllAlpha(middle_lists):
  142. ts_flag = True
  143. for i, ch in enumerate(middle_lists):
  144. if ts_flag and time_stamp is not None:
  145. begin = time_stamp[i][0]
  146. end = time_stamp[i][1]
  147. word = ''
  148. if '@@' in ch:
  149. word = ch.replace('@@', '')
  150. word_item += word
  151. if time_stamp is not None:
  152. ts_flag = False
  153. end = time_stamp[i][1]
  154. else:
  155. word_item += ch
  156. word_lists.append(word_item)
  157. word_lists.append(' ')
  158. word_item = ''
  159. if time_stamp is not None:
  160. ts_flag = True
  161. end = time_stamp[i][1]
  162. ts_lists.append([begin, end])
  163. begin = end
  164. # mix characters
  165. else:
  166. alpha_blank = False
  167. ts_flag = True
  168. begin = -1
  169. end = -1
  170. for i, ch in enumerate(middle_lists):
  171. if ts_flag and time_stamp is not None:
  172. begin = time_stamp[i][0]
  173. end = time_stamp[i][1]
  174. word = ''
  175. if isAllChinese(ch):
  176. if alpha_blank is True:
  177. word_lists.pop()
  178. word_lists.append(ch)
  179. alpha_blank = False
  180. if time_stamp is not None:
  181. ts_flag = True
  182. ts_lists.append([begin, end])
  183. begin = end
  184. elif '@@' in ch:
  185. word = ch.replace('@@', '')
  186. word_item += word
  187. alpha_blank = False
  188. if time_stamp is not None:
  189. ts_flag = False
  190. end = time_stamp[i][1]
  191. elif isAllAlpha(ch):
  192. word_item += ch
  193. word_lists.append(word_item)
  194. word_lists.append(' ')
  195. word_item = ''
  196. alpha_blank = True
  197. if time_stamp is not None:
  198. ts_flag = True
  199. end = time_stamp[i][1]
  200. ts_lists.append([begin, end])
  201. begin = end
  202. else:
  203. word_lists.append(ch)
  204. if time_stamp is not None:
  205. word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
  206. real_word_lists = []
  207. for ch in word_lists:
  208. if ch != ' ':
  209. real_word_lists.append(ch)
  210. sentence = ' '.join(real_word_lists).strip()
  211. return sentence, ts_lists, real_word_lists
  212. else:
  213. word_lists = abbr_dispose(word_lists)
  214. real_word_lists = []
  215. for ch in word_lists:
  216. if ch != ' ':
  217. real_word_lists.append(ch)
  218. sentence = ''.join(word_lists).strip()
  219. return sentence, real_word_lists
  220. def sentence_postprocess_sentencepiece(words):
  221. middle_lists = []
  222. word_lists = []
  223. word_item = ''
  224. # wash words lists
  225. for i in words:
  226. word = ''
  227. if isinstance(i, str):
  228. word = i
  229. else:
  230. word = i.decode('utf-8')
  231. if word in ['<s>', '</s>', '<unk>', '<OOV>']:
  232. continue
  233. else:
  234. middle_lists.append(word)
  235. # all alpha characters
  236. for i, ch in enumerate(middle_lists):
  237. word = ''
  238. if '\u2581' in ch and i == 0:
  239. word_item = ''
  240. word = ch.replace('\u2581', '')
  241. word_item += word
  242. elif '\u2581' in ch and i != 0:
  243. word_lists.append(word_item)
  244. word_lists.append(' ')
  245. word_item = ''
  246. word = ch.replace('\u2581', '')
  247. word_item += word
  248. else:
  249. word_item += ch
  250. if word_item is not None:
  251. word_lists.append(word_item)
  252. #word_lists = abbr_dispose(word_lists)
  253. real_word_lists = []
  254. for ch in word_lists:
  255. if ch != ' ':
  256. if ch == "i":
  257. ch = ch.replace("i", "I")
  258. elif ch == "i'm":
  259. ch = ch.replace("i'm", "I'm")
  260. elif ch == "i've":
  261. ch = ch.replace("i've", "I've")
  262. elif ch == "i'll":
  263. ch = ch.replace("i'll", "I'll")
  264. real_word_lists.append(ch)
  265. sentence = ''.join(word_lists)
  266. return sentence, real_word_lists