datadir_writer.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from pathlib import Path
  2. from typing import Union
  3. import warnings
  4. from typeguard import check_argument_types
  5. from typeguard import check_return_type
  6. class DatadirWriter:
  7. """Writer class to create kaldi like data directory.
  8. Examples:
  9. >>> with DatadirWriter("output") as writer:
  10. ... # output/sub.txt is created here
  11. ... subwriter = writer["sub.txt"]
  12. ... # Write "uttidA some/where/a.wav"
  13. ... subwriter["uttidA"] = "some/where/a.wav"
  14. ... subwriter["uttidB"] = "some/where/b.wav"
  15. """
  16. def __init__(self, p: Union[Path, str]):
  17. assert check_argument_types()
  18. self.path = Path(p)
  19. self.chilidren = {}
  20. self.fd = None
  21. self.has_children = False
  22. self.keys = set()
  23. def __enter__(self):
  24. return self
  25. def __getitem__(self, key: str) -> "DatadirWriter":
  26. assert check_argument_types()
  27. if self.fd is not None:
  28. raise RuntimeError("This writer points out a file")
  29. if key not in self.chilidren:
  30. w = DatadirWriter((self.path / key))
  31. self.chilidren[key] = w
  32. self.has_children = True
  33. retval = self.chilidren[key]
  34. assert check_return_type(retval)
  35. return retval
  36. def __setitem__(self, key: str, value: str):
  37. assert check_argument_types()
  38. if self.has_children:
  39. raise RuntimeError("This writer points out a directory")
  40. if key in self.keys:
  41. warnings.warn(f"Duplicated: {key}")
  42. if self.fd is None:
  43. self.path.parent.mkdir(parents=True, exist_ok=True)
  44. self.fd = self.path.open("w", encoding="utf-8")
  45. self.keys.add(key)
  46. self.fd.write(f"{key} {value}\n")
  47. self.fd.flush()
  48. def __exit__(self, exc_type, exc_val, exc_tb):
  49. self.close()
  50. def close(self):
  51. if self.has_children:
  52. prev_child = None
  53. for child in self.chilidren.values():
  54. child.close()
  55. if prev_child is not None and prev_child.keys != child.keys:
  56. warnings.warn(
  57. f"Ids are mismatching between "
  58. f"{prev_child.path} and {child.path}"
  59. )
  60. prev_child = child
  61. elif self.fd is not None:
  62. self.fd.close()