# codegen.py (2.7 KB)
import logging

from utils import check_correctness

from .base import Task
  4. LOGGER = logging.getLogger('MINT')
  5. class CodeGenTask(Task):
  6. """Generic code generation task instance."""
  7. def __init__(self, id: str, prompt: str, reference: str, **kwargs):
  8. super().__init__(**kwargs)
  9. self._id = id
  10. self._prompt = prompt
  11. self._reference = reference
  12. def success(self, solution: str) -> bool:
  13. """This checks whether the given solution can complete the current task.
  14. Can be used to provides binary feedback.
  15. """
  16. code_to_exec = self.extract_answer(solution)
  17. LOGGER.debug(f'CODE_TO_EXEC:\n{code_to_exec}')
  18. LOGGER.debug(f'TEST_CODE:\n{self._reference}')
  19. res = check_correctness(
  20. solution_code=code_to_exec, test_code=self._reference, timeout=10
  21. )
  22. return res['success']
  23. class MBPPTask(CodeGenTask):
  24. task_name = 'mbpp'
  25. @property
  26. def prompt(self) -> str:
  27. """Return the prompt for this task.
  28. MBPP prompt contains \"\"\" enclosed at both ends. Need to remove it.
  29. """
  30. return self._prompt.replace('"""', '').strip()
  31. def extract_answer(self, solution: str) -> str | None:
  32. """Extract the answer from the given solution.
  33. Split off first block of code by scanning for class, def etc. on newlines.
  34. Modified from:
  35. https://github.com/bigcode-project/bigcode-evaluation-harness/blob/d61afde130005ecc65cf800ad8eca790a9bc2115/lm_eval/tasks/mbpp.py#L67
  36. """
  37. # STOP_WORDS = ["\nclass", "\nassert", '\n"""', "\nprint", "\nif", "\n<|/"]
  38. # return re.split("|".join(STOP_WORDS), solution)[0].rstrip()
  39. return solution
  40. class HumanEvalTask(CodeGenTask):
  41. task_name = 'humaneval'
  42. @property
  43. def prompt(self) -> str:
  44. """Return the prompt for this task.
  45. MBPP prompt contains \"\"\" enclosed at both ends. Need to remove it.
  46. """
  47. return 'Complete the following code:\n\n' + self._prompt
  48. def extract_answer(self, solution: str) -> str | None:
  49. """Extract the answer from the given solution.
  50. Split off first block of code by scanning for class, def etc. on newlines.
  51. Modified from:
  52. https://github.com/bigcode-project/bigcode-evaluation-harness/blob/d61afde130005ecc65cf800ad8eca790a9bc2115/lm_eval/tasks/humaneval.py#L56
  53. """
  54. # STOP_WORDS = ["\nclass", "\ndef", "\n#", "\n@", "\nprint", "\nif"]
  55. # # Remove the last block of the code containing stop_words for HumanEval
  56. # string_list = re.split("(%s)" % "|".join(STOP_WORDS), solution)
  57. # # last string should be ""
  58. # return "".join(string_list[:-2])
  59. return solution