# postprocess_utils.py

# Copyright (c) Alibaba, Inc. and its affiliates.
import string
from typing import Any, List, Union
  4. def isChinese(ch: str):
  5. if '\u4e00' <= ch <= '\u9fff':
  6. return True
  7. return False
  8. def isAllChinese(word: Union[List[Any], str]):
  9. word_lists = []
  10. for i in word:
  11. cur = i.replace(' ', '')
  12. cur = cur.replace('</s>', '')
  13. cur = cur.replace('<s>', '')
  14. word_lists.append(cur)
  15. if len(word_lists) == 0:
  16. return False
  17. for ch in word_lists:
  18. if isChinese(ch) is False:
  19. return False
  20. return True
  21. def isAllAlpha(word: Union[List[Any], str]):
  22. word_lists = []
  23. for i in word:
  24. cur = i.replace(' ', '')
  25. cur = cur.replace('</s>', '')
  26. cur = cur.replace('<s>', '')
  27. word_lists.append(cur)
  28. if len(word_lists) == 0:
  29. return False
  30. for ch in word_lists:
  31. if ch.isalpha() is False and ch != "'":
  32. return False
  33. elif ch.isalpha() is True and isChinese(ch) is True:
  34. return False
  35. return True
  36. def abbr_dispose(words: List[Any]) -> List[Any]:
  37. words_size = len(words)
  38. word_lists = []
  39. abbr_begin = []
  40. abbr_end = []
  41. last_num = -1
  42. for num in range(words_size):
  43. if num <= last_num:
  44. continue
  45. if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
  46. if num + 1 < words_size and words[
  47. num + 1] == ' ' and num + 2 < words_size and len(
  48. words[num +
  49. 2]) == 1 and words[num +
  50. 2].encode('utf-8').isalpha():
  51. # found the begin of abbr
  52. abbr_begin.append(num)
  53. num += 2
  54. abbr_end.append(num)
  55. # to find the end of abbr
  56. while True:
  57. num += 1
  58. if num < words_size and words[num] == ' ':
  59. num += 1
  60. if num < words_size and len(
  61. words[num]) == 1 and words[num].encode(
  62. 'utf-8').isalpha():
  63. abbr_end.pop()
  64. abbr_end.append(num)
  65. last_num = num
  66. else:
  67. break
  68. else:
  69. break
  70. last_num = -1
  71. for num in range(words_size):
  72. if num <= last_num:
  73. continue
  74. if num in abbr_begin:
  75. word_lists.append(words[num].upper())
  76. num += 1
  77. while num < words_size:
  78. if num in abbr_end:
  79. word_lists.append(words[num].upper())
  80. last_num = num
  81. break
  82. else:
  83. if words[num].encode('utf-8').isalpha():
  84. word_lists.append(words[num].upper())
  85. num += 1
  86. else:
  87. word_lists.append(words[num])
  88. return word_lists
  89. def sentence_postprocess(words: List[Any]):
  90. middle_lists = []
  91. word_lists = []
  92. word_item = ''
  93. # wash words lists
  94. for i in words:
  95. word = ''
  96. if isinstance(i, str):
  97. word = i
  98. else:
  99. word = i.decode('utf-8')
  100. if word in ['<s>', '</s>', '<unk>']:
  101. continue
  102. else:
  103. middle_lists.append(word)
  104. # all chinese characters
  105. if isAllChinese(middle_lists):
  106. for ch in middle_lists:
  107. word_lists.append(ch.replace(' ', ''))
  108. # all alpha characters
  109. elif isAllAlpha(middle_lists):
  110. for ch in middle_lists:
  111. word = ''
  112. if '@@' in ch:
  113. word = ch.replace('@@', '')
  114. word_item += word
  115. else:
  116. word_item += ch
  117. word_lists.append(word_item)
  118. word_lists.append(' ')
  119. word_item = ''
  120. # mix characters
  121. else:
  122. alpha_blank = False
  123. for ch in middle_lists:
  124. word = ''
  125. if isAllChinese(ch):
  126. if alpha_blank is True:
  127. word_lists.pop()
  128. word_lists.append(ch)
  129. alpha_blank = False
  130. elif '@@' in ch:
  131. word = ch.replace('@@', '')
  132. word_item += word
  133. alpha_blank = False
  134. elif isAllAlpha(ch):
  135. word_item += ch
  136. word_lists.append(word_item)
  137. word_lists.append(' ')
  138. word_item = ''
  139. alpha_blank = True
  140. else:
  141. raise ValueError('invalid character: {}'.format(ch))
  142. word_lists = abbr_dispose(word_lists)
  143. sentence = ''.join(word_lists).strip()
  144. return sentence