Bläddra i källkod

(fix) confirmation mode bugfix for the EventStreamRuntime (#3695)

Mislav Balunovic 1 år sedan
förälder
incheckning
f979d612ec

+ 5 - 0
agenthub/codeact_agent/codeact_agent.py

@@ -17,6 +17,7 @@ from openhands.events.observation import (
     AgentDelegateObservation,
     AgentDelegateObservation,
     CmdOutputObservation,
     CmdOutputObservation,
     IPythonRunCellObservation,
     IPythonRunCellObservation,
+    UserRejectObservation,
 )
 )
 from openhands.events.observation.error import ErrorObservation
 from openhands.events.observation.error import ErrorObservation
 from openhands.events.observation.observation import Observation
 from openhands.events.observation.observation import Observation
@@ -153,6 +154,10 @@ class CodeActAgent(Agent):
             text = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars)
             text = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars)
             text += '\n[Error occurred in processing last action]'
             text += '\n[Error occurred in processing last action]'
             return Message(role='user', content=[TextContent(text=text)])
             return Message(role='user', content=[TextContent(text=text)])
+        elif isinstance(obs, UserRejectObservation):
+            text = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars)
+            text += '\n[Last action has been rejected by the user]'
+            return Message(role='user', content=[TextContent(text=text)])
         else:
         else:
             # If an observation message is not returned, it will cause an error
             # If an observation message is not returned, it will cause an error
             # when the LLM tries to return the next message
             # when the LLM tries to return the next message

+ 1 - 1
openhands/events/observation/reject.py

@@ -6,7 +6,7 @@ from openhands.events.observation.observation import Observation
 
 
 @dataclass
 @dataclass
 class UserRejectObservation(Observation):
 class UserRejectObservation(Observation):
-    """This data class represents the result of a successful action."""
+    """This data class represents the result of a rejected action."""
 
 
     observation: str = ObservationType.USER_REJECTED
     observation: str = ObservationType.USER_REJECTED
 
 

+ 15 - 0
openhands/runtime/client/runtime.py

@@ -13,6 +13,7 @@ from openhands.core.config import AppConfig
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.logger import openhands_logger as logger
 from openhands.events import EventStream
 from openhands.events import EventStream
 from openhands.events.action import (
 from openhands.events.action import (
+    ActionConfirmationStatus,
     BrowseInteractiveAction,
     BrowseInteractiveAction,
     BrowseURLAction,
     BrowseURLAction,
     CmdRunAction,
     CmdRunAction,
@@ -25,6 +26,7 @@ from openhands.events.observation import (
     ErrorObservation,
     ErrorObservation,
     NullObservation,
     NullObservation,
     Observation,
     Observation,
+    UserRejectObservation,
 )
 )
 from openhands.events.serialization import event_to_dict, observation_from_dict
 from openhands.events.serialization import event_to_dict, observation_from_dict
 from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
 from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
@@ -333,6 +335,12 @@ class EventStreamRuntime(Runtime):
         with self.action_semaphore:
         with self.action_semaphore:
             if not action.runnable:
             if not action.runnable:
                 return NullObservation('')
                 return NullObservation('')
+            if (
+                hasattr(action, 'is_confirmed')
+                and action.is_confirmed
+                == ActionConfirmationStatus.AWAITING_CONFIRMATION
+            ):
+                return NullObservation('')
             action_type = action.action  # type: ignore[attr-defined]
             action_type = action.action  # type: ignore[attr-defined]
             if action_type not in ACTION_TYPE_TO_CLASS:
             if action_type not in ACTION_TYPE_TO_CLASS:
                 return ErrorObservation(f'Action {action_type} does not exist.')
                 return ErrorObservation(f'Action {action_type} does not exist.')
@@ -340,6 +348,13 @@ class EventStreamRuntime(Runtime):
                 return ErrorObservation(
                 return ErrorObservation(
                     f'Action {action_type} is not supported in the current runtime.'
                     f'Action {action_type} is not supported in the current runtime.'
                 )
                 )
+            if (
+                hasattr(action, 'is_confirmed')
+                and action.is_confirmed == ActionConfirmationStatus.REJECTED
+            ):
+                return UserRejectObservation(
+                    'Action has been rejected by the user! Waiting for further user input.'
+                )
 
 
             logger.info('Awaiting session')
             logger.info('Awaiting session')
             self._wait_until_alive()
             self._wait_until_alive()