utils.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. import contextlib
  2. import faulthandler
  3. import functools
  4. import io
  5. import multiprocessing
  6. import os
  7. import platform
  8. import signal
  9. import tempfile
  10. from typing import Any
  11. # use cache to avoid loading the same file multiple times
  12. # which can leads to too many open files error
  13. @functools.lru_cache(maxsize=128)
  14. def load_file(filepath: str) -> str:
  15. with open(filepath, 'r') as f:
  16. content = f.read()
  17. return content
  18. """Check the correctness of a program by running a test suite.
  19. Modified from: https://github.com/openai/human-eval/blob/master/human_eval/execution.py
  20. """
  21. def unsafe_execute(
  22. result: Any,
  23. solution_code: str,
  24. test_code: str,
  25. timeout: float = 10,
  26. ):
  27. with create_tempdir():
  28. # These system calls are needed when cleaning up tempdir.
  29. import os
  30. import shutil
  31. rmtree = shutil.rmtree
  32. rmdir = os.rmdir
  33. chdir = os.chdir
  34. # Disable functionalities that can make destructive changes to the test.
  35. reliability_guard()
  36. # Construct the check program and run it.
  37. check_program = solution_code + '\n' + test_code
  38. try:
  39. exec_globals = {}
  40. with swallow_io():
  41. with time_limit(timeout):
  42. # WARNING
  43. # This program exists to execute untrusted model-generated code. Although
  44. # it is highly unlikely that model-generated code will do something overtly
  45. # malicious in response to this test suite, model-generated code may act
  46. # destructively due to a lack of model capability or alignment.
  47. # Users are strongly encouraged to sandbox this evaluation suite so that it
  48. # does not perform destructive actions on their host or network. For more
  49. # information on how OpenAI sandboxes its code, see the accompanying paper.
  50. # Once you have read this disclaimer and taken appropriate precautions,
  51. # uncomment the following line and proceed at your own risk:
  52. exec(check_program, exec_globals)
  53. result.append('passed')
  54. except TimeoutException:
  55. result.append('timed out')
  56. except BaseException as e:
  57. result.append(f'failed: {e}')
  58. # Needed for cleaning up.
  59. shutil.rmtree = rmtree
  60. os.rmdir = rmdir
  61. os.chdir = chdir
  62. def check_correctness(
  63. solution_code: str,
  64. test_code: str,
  65. timeout: float = 10,
  66. completion_id: int | None = None,
  67. ) -> dict:
  68. """Evaluates the functional correctness of a completion by running the test
  69. suite provided in the problem.
  70. :param completion_id: an optional completion ID so we can match
  71. the results later even if execution finishes asynchronously.
  72. """
  73. manager = multiprocessing.Manager()
  74. result = manager.list()
  75. p = multiprocessing.Process(
  76. target=unsafe_execute, args=(result, solution_code, test_code, timeout)
  77. )
  78. p.start()
  79. p.join(timeout=timeout + 1)
  80. if p.is_alive():
  81. p.kill()
  82. if not result:
  83. result.append('timed out')
  84. return dict(
  85. success=result[0] == 'passed',
  86. result=result[0],
  87. completion_id=completion_id,
  88. )
  89. @contextlib.contextmanager
  90. def time_limit(seconds: float):
  91. def signal_handler(signum, frame):
  92. raise TimeoutException('Timed out!')
  93. signal.setitimer(signal.ITIMER_REAL, seconds)
  94. signal.signal(signal.SIGALRM, signal_handler)
  95. try:
  96. yield
  97. finally:
  98. signal.setitimer(signal.ITIMER_REAL, 0)
  99. @contextlib.contextmanager
  100. def swallow_io():
  101. stream = WriteOnlyStringIO()
  102. with contextlib.redirect_stdout(stream):
  103. with contextlib.redirect_stderr(stream):
  104. with redirect_stdin(stream):
  105. yield
  106. @contextlib.contextmanager
  107. def create_tempdir():
  108. # with tempfile.TemporaryDirectory() as dirname:
  109. # Manually do this to avoid too many open files error caused by TemporaryDirectory
  110. dirname = tempfile.mkdtemp()
  111. with chdir(dirname):
  112. yield dirname
  113. os.rmdir(dirname)
  114. class TimeoutException(Exception):
  115. pass
  116. class WriteOnlyStringIO(io.StringIO):
  117. """StringIO that throws an exception when it's read from"""
  118. def read(self, *args, **kwargs):
  119. raise IOError
  120. def readline(self, *args, **kwargs):
  121. raise IOError
  122. def readlines(self, *args, **kwargs):
  123. raise IOError
  124. def readable(self, *args, **kwargs):
  125. """Returns True if the IO object can be read."""
  126. return False
  127. class redirect_stdin(contextlib._RedirectStream): # type: ignore
  128. _stream = 'stdin'
  129. @contextlib.contextmanager
  130. def chdir(root):
  131. if root == '.':
  132. yield
  133. return
  134. cwd = os.getcwd()
  135. os.chdir(root)
  136. try:
  137. yield
  138. except BaseException as exc:
  139. raise exc
  140. finally:
  141. os.chdir(cwd)
  142. def reliability_guard(maximum_memory_bytes: int | None = None):
  143. """This disables various destructive functions and prevents the generated code
  144. from interfering with the test (e.g. fork bomb, killing other processes,
  145. removing filesystem files, etc.)
  146. Warning:
  147. This function is NOT a security sandbox. Untrusted code, including, model-
  148. generated code, should not be blindly executed outside of one. See the
  149. Codex paper for more information about OpenAI's code sandbox, and proceed
  150. with caution.
  151. """
  152. if maximum_memory_bytes is not None:
  153. import resource
  154. resource.setrlimit(
  155. resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)
  156. )
  157. resource.setrlimit(
  158. resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)
  159. )
  160. if not platform.uname().system == 'Darwin':
  161. resource.setrlimit(
  162. resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)
  163. )
  164. faulthandler.disable()
  165. import builtins
  166. builtins.exit = None
  167. builtins.quit = None
  168. import os
  169. os.environ['OMP_NUM_THREADS'] = '1'
  170. os.kill = None
  171. os.system = None
  172. os.putenv = None
  173. os.remove = None
  174. os.removedirs = None
  175. os.rmdir = None
  176. os.fchdir = None
  177. os.setuid = None
  178. os.fork = None
  179. os.forkpty = None
  180. os.killpg = None
  181. os.rename = None
  182. os.renames = None
  183. os.truncate = None
  184. os.replace = None
  185. os.unlink = None
  186. os.fchmod = None
  187. os.fchown = None
  188. os.chmod = None
  189. os.chown = None
  190. os.chroot = None
  191. os.fchdir = None
  192. os.lchflags = None
  193. os.lchmod = None
  194. os.lchown = None
  195. os.getcwd = None
  196. os.chdir = None
  197. import shutil
  198. shutil.rmtree = None
  199. shutil.move = None
  200. shutil.chown = None
  201. import subprocess
  202. subprocess.Popen = None # type: ignore
  203. __builtins__['help'] = None
  204. import sys
  205. sys.modules['ipdb'] = None
  206. sys.modules['joblib'] = None
  207. sys.modules['resource'] = None
  208. sys.modules['psutil'] = None
  209. sys.modules['tkinter'] = None