# tokenize.py
#!/usr/bin/env python
import re
import numpy as np
  4. def forward_segment(text, seg_dict):
  5. word_list = []
  6. i = 0
  7. while i < len(text):
  8. longest_word = text[i]
  9. for j in range(i + 1, len(text) + 1):
  10. word = text[i:j]
  11. if word in seg_dict:
  12. if len(word) > len(longest_word):
  13. longest_word = word
  14. word_list.append(longest_word)
  15. i += len(longest_word)
  16. return word_list
  17. def seg_tokenize(txt, seg_dict):
  18. out_txt = ""
  19. for word in txt:
  20. if word in seg_dict:
  21. out_txt += seg_dict[word] + " "
  22. else:
  23. out_txt += "<unk>" + " "
  24. return out_txt.strip().split()
  25. def tokenize(data,
  26. vocab=None,
  27. seg_dict=None,
  28. punc_dict=None,
  29. bpe_tokenizer=None):
  30. assert "text" in data
  31. assert isinstance(vocab, dict)
  32. text = data["text"]
  33. token = []
  34. vad = -2
  35. if bpe_tokenizer is not None:
  36. text = bpe_tokenizer.text2tokens("".join(text))
  37. if seg_dict is not None:
  38. assert isinstance(seg_dict, dict)
  39. txt = forward_segment("".join(text).lower(), seg_dict)
  40. text = seg_tokenize(txt, seg_dict)
  41. length = len(text)
  42. for i in range(length):
  43. x = text[i]
  44. if i == length-1 and "punc" in data and text[i].startswith("vad:"):
  45. vad = x[-1][4:]
  46. if len(vad) == 0:
  47. vad = -1
  48. else:
  49. vad = int(vad)
  50. elif x in vocab:
  51. token.append(vocab[x])
  52. else:
  53. token.append(vocab['<unk>'])
  54. if "punc" in data and punc_dict is not None:
  55. punc_token = []
  56. for punc in data["punc"]:
  57. if punc in punc_dict:
  58. punc_token.append(punc_dict[punc])
  59. else:
  60. punc_token.append(punc_dict["_"])
  61. data["punc"] = np.array(punc_token)
  62. data["text"] = np.array(token)
  63. if vad is not -2:
  64. data["vad_indexes"]=np.array([vad], dtype=np.int64)
  65. return data