utils.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import csv
  2. import os
  3. from typing import Union
  4. import inflect
# CJK numeral characters for the decimal units; the suffix of each name is the unit's value.
UNIT_1e01 = '十'  # 10
UNIT_1e02 = '百'  # 100
UNIT_1e03 = '千'  # 1,000
UNIT_1e04 = '万'  # 10,000
# Single inflect engine created once at import time; num_to_word uses it.
_inflect = inflect.engine()
  10. def num_to_word(x: Union[str, int]):
  11. """
  12. converts integer to spoken representation
  13. Args
  14. x: integer
  15. Returns: spoken representation
  16. """
  17. if isinstance(x, int):
  18. x = str(x)
  19. x = _inflect.number_to_words(str(x)).replace("-", " ").replace(",", "")
  20. return x
  21. def get_abs_path(rel_path):
  22. """
  23. Get absolute path
  24. Args:
  25. rel_path: relative path to this file
  26. Returns absolute path
  27. """
  28. return os.path.dirname(os.path.abspath(__file__)) + '/' + rel_path
  29. def load_labels(abs_path):
  30. """
  31. loads relative path file as dictionary
  32. Args:
  33. abs_path: absolute path
  34. Returns dictionary of mappings
  35. """
  36. label_tsv = open(abs_path, encoding="utf-8")
  37. labels = list(csv.reader(label_tsv, delimiter="\t"))
  38. return labels
  39. def augment_labels_with_punct_at_end(labels):
  40. """
  41. augments labels: if key ends on a punctuation that value does not have, add a new label
  42. where the value maintains the punctuation
  43. Args:
  44. labels : input labels
  45. Returns:
  46. additional labels
  47. """
  48. res = []
  49. for label in labels:
  50. if len(label) > 1:
  51. if label[0][-1] == "." and label[1][-1] != ".":
  52. res.append([label[0], label[1] + "."] + label[2:])
  53. return res