codegen.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import logging
  2. from typing import Optional
  3. from utils import check_correctness
  4. from .base import Task
# Module-level logger registered under the name 'MINT'; the task classes in
# this module send their debug output (extracted code, test code) through it.
LOGGER = logging.getLogger('MINT')
  6. class CodeGenTask(Task):
  7. """Generic code generation task instance."""
  8. def __init__(self, id: str, prompt: str, reference: str, **kwargs):
  9. super().__init__(**kwargs)
  10. self._id = id
  11. self._prompt = prompt
  12. self._reference = reference
  13. def success(self, solution: str) -> bool:
  14. """This checks whether the given solution can complete the current task.
  15. Can be used to provides binary feedback.
  16. """
  17. code_to_exec = self.extract_answer(solution)
  18. LOGGER.debug(f'CODE_TO_EXEC:\n{code_to_exec}')
  19. LOGGER.debug(f'TEST_CODE:\n{self._reference}')
  20. res = check_correctness(
  21. solution_code=code_to_exec, test_code=self._reference, timeout=10
  22. )
  23. return res['success']
  24. class MBPPTask(CodeGenTask):
  25. task_name = 'mbpp'
  26. @property
  27. def prompt(self) -> str:
  28. """Return the prompt for this task.
  29. MBPP prompt contains \"\"\" enclosed at both ends. Need to remove it.
  30. """
  31. return self._prompt.replace('"""', '').strip()
  32. def extract_answer(self, solution: str) -> Optional[str]:
  33. """Extract the answer from the given solution.
  34. Split off first block of code by scanning for class, def etc. on newlines.
  35. Modified from:
  36. https://github.com/bigcode-project/bigcode-evaluation-harness/blob/d61afde130005ecc65cf800ad8eca790a9bc2115/lm_eval/tasks/mbpp.py#L67
  37. """
  38. # STOP_WORDS = ["\nclass", "\nassert", '\n"""', "\nprint", "\nif", "\n<|/"]
  39. # return re.split("|".join(STOP_WORDS), solution)[0].rstrip()
  40. return solution
  41. class HumanEvalTask(CodeGenTask):
  42. task_name = 'humaneval'
  43. @property
  44. def prompt(self) -> str:
  45. """Return the prompt for this task.
  46. MBPP prompt contains \"\"\" enclosed at both ends. Need to remove it.
  47. """
  48. return 'Complete the following code:\n\n' + self._prompt
  49. def extract_answer(self, solution: str) -> Optional[str]:
  50. """Extract the answer from the given solution.
  51. Split off first block of code by scanning for class, def etc. on newlines.
  52. Modified from:
  53. https://github.com/bigcode-project/bigcode-evaluation-harness/blob/d61afde130005ecc65cf800ad8eca790a9bc2115/lm_eval/tasks/humaneval.py#L56
  54. """
  55. # STOP_WORDS = ["\nclass", "\ndef", "\n#", "\n@", "\nprint", "\nif"]
  56. # # Remove the last block of the code containing stop_words for HumanEval
  57. # string_list = re.split("(%s)" % "|".join(STOP_WORDS), solution)
  58. # # last string should be ""
  59. # return "".join(string_list[:-2])
  60. return solution