datatypes.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import enum
  2. from typing import Any, Dict, Tuple
  3. class TaskState:
  4. def __init__(
  5. self,
  6. finished: bool = False,
  7. success: bool = False,
  8. agent_action_count: dict = None,
  9. terminate_reason: str = None,
  10. latest_output: Dict[str, Any] = None,
  11. ):
  12. self.finished = finished
  13. self.success = success
  14. self.agent_action_count: Dict[str, int] = agent_action_count or {
  15. 'propose_solution': 0,
  16. 'use_tool': 0,
  17. 'invalid_action': 0,
  18. }
  19. self.terminate_reason = terminate_reason
  20. self.latest_output = latest_output
  21. def to_dict(self) -> Dict[str, Any]:
  22. return {
  23. 'finished': self.finished,
  24. 'success': self.success,
  25. 'agent_action_count': self.agent_action_count,
  26. 'terminate_reason': self.terminate_reason,
  27. 'latest_output': self.latest_output,
  28. }
  29. class ParseError(Exception):
  30. pass
  31. class FeedbackType(enum.Enum):
  32. FEEDBACK_WITH_GT = 'feedback_with_gt'
  33. FEEDBACK_WO_GT = 'feedback_wo_gt'
  34. NO_FEEDBACK = 'no_feedback'
  35. class StepOutput:
  36. def __init__(
  37. self,
  38. observation: str = None,
  39. success: bool = False,
  40. extra: Dict[str, Any] = None,
  41. turn_info: Tuple[int, int] = None,
  42. ):
  43. self.observation: str = observation
  44. self.success: bool = success
  45. self.extra: Dict[str, Any] = extra
  46. self.turn_info = turn_info
  47. def __repr__(self) -> str:
  48. return self.observation
  49. def to_str(self) -> str:
  50. output = 'Observation:\n'
  51. if self.observation is not None:
  52. output += self.observation + '\n'
  53. else:
  54. if not self.success:
  55. output += 'Your answer is wrong.\n'
  56. if self.turn_info is not None:
  57. n_steps_left, n_propose_solution_left = self.turn_info
  58. output += 'You have {} steps left and {} chances to propose solution left.\n'.format(
  59. n_steps_left, n_propose_solution_left
  60. )
  61. if n_steps_left <= 1:
  62. output += 'You should take the last step to propose a solution.\n'
  63. return output
  64. def to_dict(self) -> Dict[str, Any]:
  65. return {
  66. 'observation': self.observation,
  67. 'success': self.success,
  68. }