task.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import json
  2. import logging
  3. import os
  4. from abc import ABC, abstractmethod
  5. from typing import List, Optional, Tuple
  6. from utils import load_file
# Module-level logger obtained under the name 'MINT'; Task.load_tasks reports
# the number of loaded tasks through it.
LOGGER = logging.getLogger('MINT')
  8. class Task(ABC):
  9. """Base class for a task instance."""
  10. task_name: str = 'base'
  11. in_context_example_dir = os.path.join(
  12. os.path.dirname(os.path.abspath(__file__)),
  13. 'in_context_examples',
  14. )
  15. def __init__(self, **kwargs) -> None:
  16. if 'loaded_history' in kwargs:
  17. self.loaded_history = kwargs['loaded_history']
  18. else:
  19. self.loaded_history = None
  20. # pre-load the in-context example
  21. task_dir = os.path.join(self.in_context_example_dir, self.task_name)
  22. self._in_context_example = {
  23. 'with_tool': load_file(os.path.join(task_dir, 'with_tool.txt')),
  24. }
  25. self.metadata = {}
  26. @property
  27. def task_id(self) -> str:
  28. """Return the task id."""
  29. assert hasattr(self, '_id'), 'Task does not have an id.'
  30. return self._id
  31. def in_context_example(
  32. self, use_tool: bool = True, with_feedback: bool = False
  33. ) -> str:
  34. """Return the in-context example for the task."""
  35. if use_tool and not with_feedback:
  36. return self._in_context_example['with_tool']
  37. else:
  38. raise NotImplementedError
  39. @property
  40. def prompt(self) -> str:
  41. """Return the task prompt."""
  42. assert hasattr(self, '_prompt'), 'Task does not have a prompt.'
  43. return self._prompt
  44. @property
  45. def reference(self) -> str:
  46. """Return the reference solution for the task."""
  47. assert hasattr(self, '_reference'), 'Task does not have a reference solution.'
  48. return self._reference
  49. @abstractmethod
  50. def extract_answer(self, solution: str) -> Optional[str]:
  51. """Extract the answer from the given solution."""
  52. pass
  53. @abstractmethod
  54. def success(self, solution: str) -> bool:
  55. """This checks whether the given solution can complete the current task.
  56. Can be used to provide binary feedback.
  57. """
  58. answer = self.extract_answer(solution)
  59. return answer == self.reference
  60. @classmethod
  61. def load_tasks(cls, path: str) -> Tuple[List['Task'], int]:
  62. """Load all the tasks from a given jsonl file."""
  63. assert path.endswith('.jsonl') or path.endswith('.json')
  64. with open(path, 'r') as f:
  65. tasks = [cls(**json.loads(line)) for line in f.readlines()]
  66. LOGGER.info(f'Loaded {len(tasks)} tasks from {path}')
  67. return tasks, len(tasks)
  68. def to_dict(self) -> dict:
  69. """Convert the task to a dictionary."""
  70. return {
  71. 'task_name': self.task_name,
  72. 'task_id': self.task_id,
  73. 'prompt': self.prompt,
  74. 'reference': self.reference,
  75. 'metadata': self.metadata,
  76. }
  77. class ReasoningTask(Task):
  78. task_name = 'reasoning'
  79. def __init__(self, id: str, prompt: str, reference: str, **kwargs):
  80. super().__init__(**kwargs)
  81. self._id = id
  82. self._prompt = prompt.strip()
  83. self._reference = str(reference).strip().lower()
  84. def extract_answer(self, solution: str) -> Optional[str]:
  85. """Extract the answer from the given solution."""
  86. return solution.lower().strip()
  87. def compare_w_digits(self, reference: str, answer: str) -> bool:
  88. """Compare the reference and answer with digits."""
  89. # if reference can and answer can both be converted to floats by float()
  90. try:
  91. float(reference)
  92. float(answer)
  93. return abs(float(reference) - float(answer)) <= 0.05 * abs(float(reference))
  94. except ValueError:
  95. return reference in answer
  96. except Exception:
  97. raise ValueError(f'Cannot compare {reference} and {answer}')
  98. def success(self, solution: str) -> bool:
  99. answer = self.extract_answer(solution)
  100. return self.compare_w_digits(self._reference, answer)