base.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import json
  2. import logging
  3. import os
  4. from abc import ABC, abstractmethod
  5. from utils import load_file
  6. LOGGER = logging.getLogger('MINT')
  7. class Task(ABC):
  8. """Base class for a task instance."""
  9. task_name: str = 'base'
  10. in_context_example_dir = os.path.join(
  11. os.path.dirname(os.path.abspath(__file__)),
  12. 'in_context_examples',
  13. )
  14. def __init__(self, **kwargs) -> None:
  15. if 'loaded_history' in kwargs:
  16. self.loaded_history = kwargs['loaded_history']
  17. else:
  18. self.loaded_history = None
  19. # pre-load the in-context example
  20. task_dir = os.path.join(self.in_context_example_dir, self.task_name)
  21. self._in_context_example = {
  22. 'with_tool': load_file(os.path.join(task_dir, 'with_tool.txt')),
  23. }
  24. self.metadata = {}
  25. @property
  26. def task_id(self) -> str:
  27. """Return the task id."""
  28. assert hasattr(self, '_id'), 'Task does not have an id.'
  29. return self._id
  30. def in_context_example(
  31. self, use_tool: bool = True, with_feedback: bool = False
  32. ) -> str:
  33. """Return the in-context example for the task."""
  34. if use_tool and not with_feedback:
  35. return self._in_context_example['with_tool']
  36. else:
  37. raise NotImplementedError
  38. @property
  39. def prompt(self) -> str:
  40. """Return the task prompt."""
  41. assert hasattr(self, '_prompt'), 'Task does not have a prompt.'
  42. return self._prompt
  43. @property
  44. def reference(self) -> str:
  45. """Return the reference solution for the task."""
  46. assert hasattr(self, '_reference'), 'Task does not have a reference solution.'
  47. return self._reference
  48. @abstractmethod
  49. def extract_answer(self, solution: str) -> str | None:
  50. """Extract the answer from the given solution."""
  51. pass
  52. @abstractmethod
  53. def success(self, solution: str) -> bool:
  54. """This checks whether the given solution can complete the current task.
  55. Can be used to provide binary feedback.
  56. """
  57. answer = self.extract_answer(solution)
  58. return answer == self.reference
  59. @classmethod
  60. def load_tasks(cls, path: str) -> tuple[list['Task'], int]:
  61. """Load all the tasks from a given jsonl file."""
  62. assert path.endswith('.jsonl') or path.endswith('.json')
  63. with open(path, 'r') as f:
  64. tasks = [cls(**json.loads(line)) for line in f.readlines()]
  65. LOGGER.info(f'Loaded {len(tasks)} tasks from {path}')
  66. return tasks, len(tasks)
  67. def to_dict(self) -> dict:
  68. """Convert the task to a dictionary."""
  69. return {
  70. 'task_name': self.task_name,
  71. 'task_id': self.task_id,
  72. 'prompt': self.prompt,
  73. 'reference': self.reference,
  74. 'metadata': self.metadata,
  75. }