datatypes.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
import enum
from typing import Any, Optional
  3. class TaskState:
  4. def __init__(
  5. self,
  6. finished: bool = False,
  7. success: bool = False,
  8. agent_action_count: dict = None,
  9. terminate_reason: str = None,
  10. latest_output: dict[str, Any] = None,
  11. ):
  12. self.finished = finished
  13. self.success = success
  14. self.agent_action_count: dict[str, int] = (
  15. agent_action_count
  16. if agent_action_count
  17. else {
  18. 'propose_solution': 0,
  19. 'use_tool': 0,
  20. 'invalid_action': 0,
  21. }
  22. )
  23. self.terminate_reason = terminate_reason
  24. self.latest_output = latest_output
  25. def to_dict(self) -> dict[str, Any]:
  26. return {
  27. 'finished': self.finished,
  28. 'success': self.success,
  29. 'agent_action_count': self.agent_action_count,
  30. 'terminate_reason': self.terminate_reason,
  31. 'latest_output': self.latest_output,
  32. }
  33. class ParseError(Exception):
  34. pass
  35. class FeedbackType(enum.Enum):
  36. FEEDBACK_WITH_GT = 'feedback_with_gt'
  37. FEEDBACK_WO_GT = 'feedback_wo_gt'
  38. NO_FEEDBACK = 'no_feedback'
  39. class StepOutput:
  40. def __init__(
  41. self,
  42. observation: str = None,
  43. success: bool = False,
  44. extra: dict[str, Any] = None,
  45. turn_info: tuple[int, int] = None,
  46. ):
  47. self.observation: str = observation
  48. self.success: bool = success
  49. self.extra: dict[str, Any] = extra
  50. self.turn_info = turn_info
  51. def __repr__(self) -> str:
  52. return self.observation
  53. def to_str(self) -> str:
  54. output = 'Observation:\n'
  55. if self.observation is not None:
  56. output += self.observation + '\n'
  57. else:
  58. if not self.success:
  59. output += 'Your answer is wrong.\n'
  60. if self.turn_info is not None:
  61. n_steps_left, n_propose_solution_left = self.turn_info
  62. output += 'You have {} steps left and {} chances to propose solution left.\n'.format(
  63. n_steps_left, n_propose_solution_left
  64. )
  65. if n_steps_left <= 1:
  66. output += 'You should take the last step to propose a solution.\n'
  67. return output
  68. def to_dict(self) -> dict[str, Any]:
  69. return {
  70. 'observation': self.observation,
  71. 'success': self.success,
  72. }