# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import string
from typing import Any, List, Optional, Tuple, Union
  5. def isChinese(ch: str):
  6. if '\u4e00' <= ch <= '\u9fff' or '\u0030' <= ch <= '\u0039':
  7. return True
  8. return False
  9. def isAllChinese(word: Union[List[Any], str]):
  10. word_lists = []
  11. for i in word:
  12. cur = i.replace(' ', '')
  13. cur = cur.replace('</s>', '')
  14. cur = cur.replace('<s>', '')
  15. word_lists.append(cur)
  16. if len(word_lists) == 0:
  17. return False
  18. for ch in word_lists:
  19. if isChinese(ch) is False:
  20. return False
  21. return True
  22. def isAllAlpha(word: Union[List[Any], str]):
  23. word_lists = []
  24. for i in word:
  25. cur = i.replace(' ', '')
  26. cur = cur.replace('</s>', '')
  27. cur = cur.replace('<s>', '')
  28. word_lists.append(cur)
  29. if len(word_lists) == 0:
  30. return False
  31. for ch in word_lists:
  32. if ch.isalpha() is False and ch != "'":
  33. return False
  34. elif ch.isalpha() is True and isChinese(ch) is True:
  35. return False
  36. return True
# def abbr_dispose(words: List[Any]) -> List[Any]:
  38. def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
  39. words_size = len(words)
  40. word_lists = []
  41. abbr_begin = []
  42. abbr_end = []
  43. last_num = -1
  44. ts_lists = []
  45. ts_nums = []
  46. ts_index = 0
  47. for num in range(words_size):
  48. if num <= last_num:
  49. continue
  50. if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
  51. if num + 1 < words_size and words[
  52. num + 1] == ' ' and num + 2 < words_size and len(
  53. words[num +
  54. 2]) == 1 and words[num +
  55. 2].encode('utf-8').isalpha():
  56. # found the begin of abbr
  57. abbr_begin.append(num)
  58. num += 2
  59. abbr_end.append(num)
  60. # to find the end of abbr
  61. while True:
  62. num += 1
  63. if num < words_size and words[num] == ' ':
  64. num += 1
  65. if num < words_size and len(
  66. words[num]) == 1 and words[num].encode(
  67. 'utf-8').isalpha():
  68. abbr_end.pop()
  69. abbr_end.append(num)
  70. last_num = num
  71. else:
  72. break
  73. else:
  74. break
  75. for num in range(words_size):
  76. if words[num] == ' ':
  77. ts_nums.append(ts_index)
  78. else:
  79. ts_nums.append(ts_index)
  80. ts_index += 1
  81. last_num = -1
  82. for num in range(words_size):
  83. if num <= last_num:
  84. continue
  85. if num in abbr_begin:
  86. if time_stamp is not None:
  87. begin = time_stamp[ts_nums[num]][0]
  88. word_lists.append(words[num].upper())
  89. num += 1
  90. while num < words_size:
  91. if num in abbr_end:
  92. word_lists.append(words[num].upper())
  93. last_num = num
  94. break
  95. else:
  96. if words[num].encode('utf-8').isalpha():
  97. word_lists.append(words[num].upper())
  98. num += 1
  99. if time_stamp is not None:
  100. end = time_stamp[ts_nums[num]][1]
  101. ts_lists.append([begin, end])
  102. else:
  103. word_lists.append(words[num])
  104. if time_stamp is not None and words[num] != ' ':
  105. begin = time_stamp[ts_nums[num]][0]
  106. end = time_stamp[ts_nums[num]][1]
  107. ts_lists.append([begin, end])
  108. begin = end
  109. if time_stamp is not None:
  110. return word_lists, ts_lists
  111. else:
  112. return word_lists
  113. def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
  114. middle_lists = []
  115. word_lists = []
  116. word_item = ''
  117. ts_lists = []
  118. # wash words lists
  119. for i in words:
  120. word = ''
  121. if isinstance(i, str):
  122. word = i
  123. else:
  124. word = i.decode('utf-8')
  125. if word in ['<s>', '</s>', '<unk>']:
  126. continue
  127. else:
  128. middle_lists.append(word)
  129. # all chinese characters
  130. if isAllChinese(middle_lists):
  131. for i, ch in enumerate(middle_lists):
  132. word_lists.append(ch.replace(' ', ''))
  133. if time_stamp is not None:
  134. ts_lists = time_stamp
  135. # all alpha characters
  136. elif isAllAlpha(middle_lists):
  137. ts_flag = True
  138. for i, ch in enumerate(middle_lists):
  139. if ts_flag and time_stamp is not None:
  140. begin = time_stamp[i][0]
  141. end = time_stamp[i][1]
  142. word = ''
  143. if '@@' in ch:
  144. word = ch.replace('@@', '')
  145. word_item += word
  146. if time_stamp is not None:
  147. ts_flag = False
  148. end = time_stamp[i][1]
  149. else:
  150. word_item += ch
  151. word_lists.append(word_item)
  152. word_lists.append(' ')
  153. word_item = ''
  154. if time_stamp is not None:
  155. ts_flag = True
  156. end = time_stamp[i][1]
  157. ts_lists.append([begin, end])
  158. begin = end
  159. # mix characters
  160. else:
  161. alpha_blank = False
  162. ts_flag = True
  163. begin = -1
  164. end = -1
  165. for i, ch in enumerate(middle_lists):
  166. if ts_flag and time_stamp is not None:
  167. begin = time_stamp[i][0]
  168. end = time_stamp[i][1]
  169. word = ''
  170. if isAllChinese(ch):
  171. if alpha_blank is True:
  172. word_lists.pop()
  173. word_lists.append(ch)
  174. alpha_blank = False
  175. if time_stamp is not None:
  176. ts_flag = True
  177. ts_lists.append([begin, end])
  178. begin = end
  179. elif '@@' in ch:
  180. word = ch.replace('@@', '')
  181. word_item += word
  182. alpha_blank = False
  183. if time_stamp is not None:
  184. ts_flag = False
  185. end = time_stamp[i][1]
  186. elif isAllAlpha(ch):
  187. word_item += ch
  188. word_lists.append(word_item)
  189. word_lists.append(' ')
  190. word_item = ''
  191. alpha_blank = True
  192. if time_stamp is not None:
  193. ts_flag = True
  194. end = time_stamp[i][1]
  195. ts_lists.append([begin, end])
  196. begin = end
  197. else:
  198. raise ValueError('invalid character: {}'.format(ch))
  199. if time_stamp is not None:
  200. word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
  201. real_word_lists = []
  202. for ch in word_lists:
  203. if ch != ' ':
  204. real_word_lists.append(ch)
  205. sentence = ' '.join(real_word_lists).strip()
  206. return sentence, ts_lists, real_word_lists
  207. else:
  208. word_lists = abbr_dispose(word_lists)
  209. real_word_lists = []
  210. for ch in word_lists:
  211. if ch != ' ':
  212. real_word_lists.append(ch)
  213. sentence = ''.join(word_lists).strip()
  214. return sentence, real_word_lists