pack_funcs.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. from datetime import datetime
  2. from io import BytesIO
  3. from io import TextIOWrapper
  4. import os
  5. from pathlib import Path
  6. import sys
  7. import tarfile
  8. from typing import Dict
  9. from typing import Iterable
  10. from typing import Optional
  11. from typing import Union
  12. import zipfile
  13. import yaml
  14. class Archiver:
  15. def __init__(self, file, mode="r"):
  16. if Path(file).suffix == ".tar":
  17. self.type = "tar"
  18. elif Path(file).suffix == ".tgz" or Path(file).suffixes == [".tar", ".gz"]:
  19. self.type = "tar"
  20. if mode == "w":
  21. mode = "w:gz"
  22. elif Path(file).suffix == ".tbz2" or Path(file).suffixes == [".tar", ".bz2"]:
  23. self.type = "tar"
  24. if mode == "w":
  25. mode = "w:bz2"
  26. elif Path(file).suffix == ".txz" or Path(file).suffixes == [".tar", ".xz"]:
  27. self.type = "tar"
  28. if mode == "w":
  29. mode = "w:xz"
  30. elif Path(file).suffix == ".zip":
  31. self.type = "zip"
  32. else:
  33. raise ValueError(f"Cannot detect archive format: type={file}")
  34. if self.type == "tar":
  35. self.fopen = tarfile.open(file, mode=mode)
  36. elif self.type == "zip":
  37. self.fopen = zipfile.ZipFile(file, mode=mode)
  38. else:
  39. raise ValueError(f"Not supported: type={type}")
  40. def __enter__(self):
  41. return self
  42. def __exit__(self, exc_type, exc_val, exc_tb):
  43. self.fopen.close()
  44. def close(self):
  45. self.fopen.close()
  46. def __iter__(self):
  47. if self.type == "tar":
  48. return iter(self.fopen)
  49. elif self.type == "zip":
  50. return iter(self.fopen.infolist())
  51. else:
  52. raise ValueError(f"Not supported: type={self.type}")
  53. def add(self, filename, arcname=None, recursive: bool = True):
  54. if arcname is not None:
  55. print(f"adding: {arcname}")
  56. else:
  57. print(f"adding: {filename}")
  58. if recursive and Path(filename).is_dir():
  59. for f in Path(filename).glob("**/*"):
  60. if f.is_dir():
  61. continue
  62. if arcname is not None:
  63. _arcname = Path(arcname) / f
  64. else:
  65. _arcname = None
  66. self.add(f, _arcname)
  67. return
  68. if self.type == "tar":
  69. return self.fopen.add(filename, arcname)
  70. elif self.type == "zip":
  71. return self.fopen.write(filename, arcname)
  72. else:
  73. raise ValueError(f"Not supported: type={self.type}")
  74. def addfile(self, info, fileobj):
  75. print(f"adding: {self.get_name_from_info(info)}")
  76. if self.type == "tar":
  77. return self.fopen.addfile(info, fileobj)
  78. elif self.type == "zip":
  79. return self.fopen.writestr(info, fileobj.read())
  80. else:
  81. raise ValueError(f"Not supported: type={self.type}")
  82. def generate_info(self, name, size) -> Union[tarfile.TarInfo, zipfile.ZipInfo]:
  83. """Generate TarInfo using system information"""
  84. if self.type == "tar":
  85. tarinfo = tarfile.TarInfo(str(name))
  86. if os.name == "posix":
  87. tarinfo.gid = os.getgid()
  88. tarinfo.uid = os.getuid()
  89. tarinfo.mtime = datetime.now().timestamp()
  90. tarinfo.size = size
  91. # Keep mode as default
  92. return tarinfo
  93. elif self.type == "zip":
  94. zipinfo = zipfile.ZipInfo(str(name), datetime.now().timetuple()[:6])
  95. zipinfo.file_size = size
  96. return zipinfo
  97. else:
  98. raise ValueError(f"Not supported: type={self.type}")
  99. def get_name_from_info(self, info):
  100. if self.type == "tar":
  101. assert isinstance(info, tarfile.TarInfo), type(info)
  102. return info.name
  103. elif self.type == "zip":
  104. assert isinstance(info, zipfile.ZipInfo), type(info)
  105. return info.filename
  106. else:
  107. raise ValueError(f"Not supported: type={self.type}")
  108. def extract(self, info, path=None):
  109. if self.type == "tar":
  110. return self.fopen.extract(info, path)
  111. elif self.type == "zip":
  112. return self.fopen.extract(info, path)
  113. else:
  114. raise ValueError(f"Not supported: type={self.type}")
  115. def extractfile(self, info, mode="r"):
  116. if self.type == "tar":
  117. f = self.fopen.extractfile(info)
  118. if mode == "r":
  119. return TextIOWrapper(f)
  120. else:
  121. return f
  122. elif self.type == "zip":
  123. if mode == "rb":
  124. mode = "r"
  125. return self.fopen.open(info, mode)
  126. else:
  127. raise ValueError(f"Not supported: type={self.type}")
  128. def find_path_and_change_it_recursive(value, src: str, tgt: str):
  129. if isinstance(value, dict):
  130. return {
  131. k: find_path_and_change_it_recursive(v, src, tgt) for k, v in value.items()
  132. }
  133. elif isinstance(value, (list, tuple)):
  134. return [find_path_and_change_it_recursive(v, src, tgt) for v in value]
  135. elif isinstance(value, str) and Path(value) == Path(src):
  136. return tgt
  137. else:
  138. return value
  139. def get_dict_from_cache(meta: Union[Path, str]) -> Optional[Dict[str, str]]:
  140. meta = Path(meta)
  141. outpath = meta.parent.parent
  142. if not meta.exists():
  143. return None
  144. with meta.open("r", encoding="utf-8") as f:
  145. d = yaml.safe_load(f)
  146. assert isinstance(d, dict), type(d)
  147. yaml_files = d["yaml_files"]
  148. files = d["files"]
  149. assert isinstance(yaml_files, dict), type(yaml_files)
  150. assert isinstance(files, dict), type(files)
  151. retval = {}
  152. for key, value in list(yaml_files.items()) + list(files.items()):
  153. if not (outpath / value).exists():
  154. return None
  155. retval[key] = str(outpath / value)
  156. return retval
  157. def unpack(
  158. input_archive: Union[Path, str],
  159. outpath: Union[Path, str],
  160. use_cache: bool = True,
  161. ) -> Dict[str, str]:
  162. """Scan all files in the archive file and return as a dict of files.
  163. Examples:
  164. tarfile:
  165. model.pb
  166. some1.file
  167. some2.file
  168. >>> unpack("tarfile", "out")
  169. {'asr_model_file': 'out/model.pb'}
  170. """
  171. input_archive = Path(input_archive)
  172. outpath = Path(outpath)
  173. with Archiver(input_archive) as archive:
  174. for info in archive:
  175. if Path(archive.get_name_from_info(info)).name == "meta.yaml":
  176. if (
  177. use_cache
  178. and (outpath / Path(archive.get_name_from_info(info))).exists()
  179. ):
  180. retval = get_dict_from_cache(
  181. outpath / Path(archive.get_name_from_info(info))
  182. )
  183. if retval is not None:
  184. return retval
  185. d = yaml.safe_load(archive.extractfile(info))
  186. assert isinstance(d, dict), type(d)
  187. yaml_files = d["yaml_files"]
  188. files = d["files"]
  189. assert isinstance(yaml_files, dict), type(yaml_files)
  190. assert isinstance(files, dict), type(files)
  191. break
  192. else:
  193. raise RuntimeError("Format error: not found meta.yaml")
  194. for info in archive:
  195. fname = archive.get_name_from_info(info)
  196. outname = outpath / fname
  197. outname.parent.mkdir(parents=True, exist_ok=True)
  198. if fname in set(yaml_files.values()):
  199. d = yaml.safe_load(archive.extractfile(info))
  200. # Rewrite yaml
  201. for info2 in archive:
  202. name = archive.get_name_from_info(info2)
  203. d = find_path_and_change_it_recursive(d, name, str(outpath / name))
  204. with outname.open("w", encoding="utf-8") as f:
  205. yaml.safe_dump(d, f)
  206. else:
  207. archive.extract(info, path=outpath)
  208. retval = {}
  209. for key, value in list(yaml_files.items()) + list(files.items()):
  210. retval[key] = str(outpath / value)
  211. return retval
  212. def _to_relative_or_resolve(f):
  213. # Resolve to avoid symbolic link
  214. p = Path(f).resolve()
  215. try:
  216. # Change to relative if it can
  217. p = p.relative_to(Path(".").resolve())
  218. except ValueError:
  219. pass
  220. return str(p)
  221. def pack(
  222. files: Dict[str, Union[str, Path]],
  223. yaml_files: Dict[str, Union[str, Path]],
  224. outpath: Union[str, Path],
  225. option: Iterable[Union[str, Path]] = (),
  226. ):
  227. for v in list(files.values()) + list(yaml_files.values()) + list(option):
  228. if not Path(v).exists():
  229. raise FileNotFoundError(f"No such file or directory: {v}")
  230. files = {k: _to_relative_or_resolve(v) for k, v in files.items()}
  231. yaml_files = {k: _to_relative_or_resolve(v) for k, v in yaml_files.items()}
  232. option = [_to_relative_or_resolve(v) for v in option]
  233. meta_objs = dict(
  234. files=files,
  235. yaml_files=yaml_files,
  236. timestamp=datetime.now().timestamp(),
  237. python=sys.version,
  238. )
  239. try:
  240. import torch
  241. meta_objs.update(torch=str(torch.__version__))
  242. except ImportError:
  243. pass
  244. try:
  245. import espnet
  246. meta_objs.update(espnet=espnet.__version__)
  247. except ImportError:
  248. pass
  249. Path(outpath).parent.mkdir(parents=True, exist_ok=True)
  250. with Archiver(outpath, mode="w") as archive:
  251. # Write packed/meta.yaml
  252. fileobj = BytesIO(yaml.safe_dump(meta_objs).encode())
  253. info = archive.generate_info("meta.yaml", fileobj.getbuffer().nbytes)
  254. archive.addfile(info, fileobj=fileobj)
  255. for f in list(yaml_files.values()) + list(files.values()) + list(option):
  256. archive.add(f)
  257. print(f"Generate: {outpath}")