utils.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. import contextlib
  2. import faulthandler
  3. import functools
  4. import io
  5. import multiprocessing
  6. import os
  7. import platform
  8. import signal
  9. import tempfile
  10. from typing import Any, Dict, Optional
  11. # Use a cache to avoid loading the same file multiple times,
  12. # which can lead to a "too many open files" error.
  13. @functools.lru_cache(maxsize=128)
  14. def load_file(filepath: str) -> str:
  15. with open(filepath, 'r') as f:
  16. content = f.read()
  17. return content
  18. """Check the correctness of a program by running a test suite.
  19. Modified from: https://github.com/openai/human-eval/blob/master/human_eval/execution.py
  20. """
  21. def unsafe_execute(
  22. result: Any,
  23. solution_code: str,
  24. test_code: str,
  25. timeout: float = 10,
  26. ):
  27. with create_tempdir():
  28. # These system calls are needed when cleaning up tempdir.
  29. import os
  30. import shutil
  31. rmtree = shutil.rmtree
  32. rmdir = os.rmdir
  33. chdir = os.chdir
  34. # Disable functionalities that can make destructive changes to the test.
  35. reliability_guard()
  36. # Construct the check program and run it.
  37. check_program = solution_code + '\n' + test_code
  38. try:
  39. exec_globals = {}
  40. with swallow_io():
  41. with time_limit(timeout):
  42. # WARNING
  43. # This program exists to execute untrusted model-generated code. Although
  44. # it is highly unlikely that model-generated code will do something overtly
  45. # malicious in response to this test suite, model-generated code may act
  46. # destructively due to a lack of model capability or alignment.
  47. # Users are strongly encouraged to sandbox this evaluation suite so that it
  48. # does not perform destructive actions on their host or network. For more
  49. # information on how OpenAI sandboxes its code, see the accompanying paper.
  50. # Once you have read this disclaimer and taken appropriate precautions,
  51. # uncomment the following line and proceed at your own risk:
  52. exec(check_program, exec_globals)
  53. result.append('passed')
  54. except TimeoutException:
  55. result.append('timed out')
  56. except BaseException as e:
  57. result.append(f'failed: {e}')
  58. # Needed for cleaning up.
  59. shutil.rmtree = rmtree
  60. os.rmdir = rmdir
  61. os.chdir = chdir
  62. def check_correctness(
  63. solution_code: str,
  64. test_code: str,
  65. timeout: float = 10,
  66. completion_id: Optional[int] = None,
  67. ) -> Dict:
  68. """
  69. Evaluates the functional correctness of a completion by running the test
  70. suite provided in the problem.
  71. :param completion_id: an optional completion ID so we can match
  72. the results later even if execution finishes asynchronously.
  73. """
  74. manager = multiprocessing.Manager()
  75. result = manager.list()
  76. p = multiprocessing.Process(
  77. target=unsafe_execute, args=(result, solution_code, test_code, timeout)
  78. )
  79. p.start()
  80. p.join(timeout=timeout + 1)
  81. if p.is_alive():
  82. p.kill()
  83. if not result:
  84. result.append('timed out')
  85. return dict(
  86. success=result[0] == 'passed',
  87. result=result[0],
  88. completion_id=completion_id,
  89. )
  90. @contextlib.contextmanager
  91. def time_limit(seconds: float):
  92. def signal_handler(signum, frame):
  93. raise TimeoutException('Timed out!')
  94. signal.setitimer(signal.ITIMER_REAL, seconds)
  95. signal.signal(signal.SIGALRM, signal_handler)
  96. try:
  97. yield
  98. finally:
  99. signal.setitimer(signal.ITIMER_REAL, 0)
  100. @contextlib.contextmanager
  101. def swallow_io():
  102. stream = WriteOnlyStringIO()
  103. with contextlib.redirect_stdout(stream):
  104. with contextlib.redirect_stderr(stream):
  105. with redirect_stdin(stream):
  106. yield
  107. @contextlib.contextmanager
  108. def create_tempdir():
  109. # with tempfile.TemporaryDirectory() as dirname:
  110. # Manually do this to avoid too many open files error caused by TemporaryDirectory
  111. dirname = tempfile.mkdtemp()
  112. with chdir(dirname):
  113. yield dirname
  114. os.rmdir(dirname)
  115. class TimeoutException(Exception):
  116. pass
  117. class WriteOnlyStringIO(io.StringIO):
  118. """StringIO that throws an exception when it's read from"""
  119. def read(self, *args, **kwargs):
  120. raise IOError
  121. def readline(self, *args, **kwargs):
  122. raise IOError
  123. def readlines(self, *args, **kwargs):
  124. raise IOError
  125. def readable(self, *args, **kwargs):
  126. """Returns True if the IO object can be read."""
  127. return False
  128. class redirect_stdin(contextlib._RedirectStream): # type: ignore
  129. _stream = 'stdin'
  130. @contextlib.contextmanager
  131. def chdir(root):
  132. if root == '.':
  133. yield
  134. return
  135. cwd = os.getcwd()
  136. os.chdir(root)
  137. try:
  138. yield
  139. except BaseException as exc:
  140. raise exc
  141. finally:
  142. os.chdir(cwd)
  143. def reliability_guard(maximum_memory_bytes: Optional[int] = None):
  144. """
  145. This disables various destructive functions and prevents the generated code
  146. from interfering with the test (e.g. fork bomb, killing other processes,
  147. removing filesystem files, etc.)
  148. WARNING
  149. This function is NOT a security sandbox. Untrusted code, including, model-
  150. generated code, should not be blindly executed outside of one. See the
  151. Codex paper for more information about OpenAI's code sandbox, and proceed
  152. with caution.
  153. """
  154. if maximum_memory_bytes is not None:
  155. import resource
  156. resource.setrlimit(
  157. resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)
  158. )
  159. resource.setrlimit(
  160. resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)
  161. )
  162. if not platform.uname().system == 'Darwin':
  163. resource.setrlimit(
  164. resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)
  165. )
  166. faulthandler.disable()
  167. import builtins
  168. builtins.exit = None
  169. builtins.quit = None
  170. import os
  171. os.environ['OMP_NUM_THREADS'] = '1'
  172. os.kill = None
  173. os.system = None
  174. os.putenv = None
  175. os.remove = None
  176. os.removedirs = None
  177. os.rmdir = None
  178. os.fchdir = None
  179. os.setuid = None
  180. os.fork = None
  181. os.forkpty = None
  182. os.killpg = None
  183. os.rename = None
  184. os.renames = None
  185. os.truncate = None
  186. os.replace = None
  187. os.unlink = None
  188. os.fchmod = None
  189. os.fchown = None
  190. os.chmod = None
  191. os.chown = None
  192. os.chroot = None
  193. os.fchdir = None
  194. os.lchflags = None
  195. os.lchmod = None
  196. os.lchown = None
  197. os.getcwd = None
  198. os.chdir = None
  199. import shutil
  200. shutil.rmtree = None
  201. shutil.move = None
  202. shutil.chown = None
  203. import subprocess
  204. subprocess.Popen = None # type: ignore
  205. __builtins__['help'] = None
  206. import sys
  207. sys.modules['ipdb'] = None
  208. sys.modules['joblib'] = None
  209. sys.modules['resource'] = None
  210. sys.modules['psutil'] = None
  211. sys.modules['tkinter'] = None