utils.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import csv
  2. import os
  3. from typing import Union
  4. import inflect
# Module-level inflect engine, created once at import time and used by num_to_word.
_inflect = inflect.engine()
  6. def num_to_word(x: Union[str, int]):
  7. """
  8. converts integer to spoken representation
  9. Args
  10. x: integer
  11. Returns: spoken representation
  12. """
  13. if isinstance(x, int):
  14. x = str(x)
  15. x = _inflect.number_to_words(str(x)).replace("-", " ").replace(",", "")
  16. return x
  17. def get_abs_path(rel_path):
  18. """
  19. Get absolute path
  20. Args:
  21. rel_path: relative path to this file
  22. Returns absolute path
  23. """
  24. return os.path.dirname(os.path.abspath(__file__)) + '/' + rel_path
  25. def load_labels(abs_path):
  26. """
  27. loads relative path file as dictionary
  28. Args:
  29. abs_path: absolute path
  30. Returns dictionary of mappings
  31. """
  32. label_tsv = open(abs_path, encoding="utf-8")
  33. labels = list(csv.reader(label_tsv, delimiter="\t"))
  34. return labels
  35. def augment_labels_with_punct_at_end(labels):
  36. """
  37. augments labels: if key ends on a punctuation that value does not have, add a new label
  38. where the value maintains the punctuation
  39. Args:
  40. labels : input labels
  41. Returns:
  42. additional labels
  43. """
  44. res = []
  45. for label in labels:
  46. if len(label) > 1:
  47. if label[0][-1] == "." and label[1][-1] != ".":
  48. res.append([label[0], label[1] + "."] + label[2:])
  49. return res