
fix(llm): fallback when model is out of function calling supported list (#4617)

Co-authored-by: openhands <openhands@all-hands.dev>
Xingyao Wang 1 year ago
parent commit 2587220b12

+ 10 - 11
openhands/agenthub/codeact_agent/codeact_agent.py

@@ -93,17 +93,16 @@ class CodeActAgent(Agent):
             if config.micro_agent_name
             else None
         )
-        if (
-            self.config.function_calling
-            and not self.llm.config.supports_function_calling
-        ):
+
+        self.function_calling_active = self.config.function_calling
+        if self.function_calling_active and not self.llm.is_function_calling_active():
             logger.warning(
                 f'Function calling not supported for model {self.llm.config.model}. '
                 'Disabling function calling.'
             )
-            self.config.function_calling = False
+            self.function_calling_active = False
 
-        if self.config.function_calling:
+        if self.function_calling_active:
             # Function calling mode
             self.tools = codeact_function_calling.get_tools(
                 codeact_enable_browsing_delegate=self.config.codeact_enable_browsing_delegate,
@@ -172,7 +171,7 @@ class CodeActAgent(Agent):
                 FileEditAction,
             ),
         ) or (isinstance(action, AgentFinishAction) and action.source == 'agent'):
-            if self.config.function_calling:
+            if self.function_calling_active:
                 tool_metadata = action.tool_call_metadata
                 assert tool_metadata is not None, (
                     'Tool call metadata should NOT be None when function calling is enabled. Action: '
@@ -286,7 +285,7 @@ class CodeActAgent(Agent):
             # when the LLM tries to return the next message
             raise ValueError(f'Unknown observation type: {type(obs)}')
 
-        if self.config.function_calling:
+        if self.function_calling_active:
             # Update the message as tool response properly
             if (tool_call_metadata := obs.tool_call_metadata) is not None:
                 tool_call_id_to_message[tool_call_metadata.tool_call_id] = Message(
@@ -334,7 +333,7 @@ class CodeActAgent(Agent):
         params: dict = {
             'messages': self.llm.format_messages_for_llm(messages),
         }
-        if self.config.function_calling:
+        if self.function_calling_active:
             params['tools'] = self.tools
         else:
             params['stop'] = [
@@ -345,7 +344,7 @@ class CodeActAgent(Agent):
             ]
         response = self.llm.completion(**params)
 
-        if self.config.function_calling:
+        if self.function_calling_active:
             actions = codeact_function_calling.response_to_actions(response)
             for action in actions:
                 self.pending_actions.append(action)
@@ -479,7 +478,7 @@ class CodeActAgent(Agent):
                     else:
                         break
 
-        if not self.config.function_calling:
+        if not self.function_calling_active:
             # The latest user message is important:
             # we want to remind the agent of the environment constraints
             latest_user_message = next(

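The diff above swaps the mutable `self.config.function_calling` flag for a per-instance `self.function_calling_active` attribute, so the fallback no longer mutates shared agent config. A minimal sketch of the resulting behavior, with a hypothetical `FakeLLM` stub standing in for `openhands.llm.llm.LLM`:

```python
import logging

logger = logging.getLogger(__name__)


class FakeLLM:
    """Hypothetical stand-in for openhands.llm.llm.LLM."""

    def __init__(self, model: str, supported: bool):
        self.model = model
        self._supported = supported

    def is_function_calling_active(self) -> bool:
        return self._supported


def resolve_function_calling(llm: FakeLLM, requested: bool) -> bool:
    # Mirrors the agent's fallback: honor the request only when the model
    # actually supports function calling; otherwise warn and fall back to
    # the prompt-based (stop-token) tool mode.
    if requested and not llm.is_function_calling_active():
        logger.warning(
            f'Function calling not supported for model {llm.model}. '
            'Disabling function calling.'
        )
        return False
    return requested


assert resolve_function_calling(FakeLLM('gpt-4o', True), True) is True
assert resolve_function_calling(FakeLLM('my-local-model', False), True) is False
```

Because the resolved value lives on the agent instance rather than on `self.config`, constructing a second agent from the same config starts from the user's original request again.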
+ 0 - 2
openhands/core/config/llm_config.py

@@ -42,7 +42,6 @@ class LLMConfig:
         log_completions: Whether to log LLM completions to the state.
         log_completions_folder: The folder to log LLM completions to. Required if log_completions is True.
         draft_editor: A more efficient LLM to use for file editing. Introduced in [PR 3985](https://github.com/All-Hands-AI/OpenHands/pull/3985).
-        supports_function_calling: Whether the model supports function calling.
     """
 
     model: str = 'claude-3-5-sonnet-20241022'
@@ -77,7 +76,6 @@ class LLMConfig:
     log_completions: bool = False
     log_completions_folder: str | None = None
     draft_editor: Optional['LLMConfig'] = None
-    supports_function_calling: bool = False
 
     def defaults_to_dict(self) -> dict:
         """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""

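Removing `supports_function_calling` from `LLMConfig` means the capability is no longer stored (and potentially stale) configuration state; it is derived per-instance at runtime. A hypothetical reduction of that design choice:

```python
from dataclasses import dataclass


@dataclass
class LLMConfigSketch:
    # Hypothetical reduction of LLMConfig: the capability flag is gone.
    model: str = 'claude-3-5-sonnet-20241022'
    # supports_function_calling: bool = False  # removed by this commit


class LLMSketch:
    def __init__(self, config: LLMConfigSketch):
        self.config = config

    def is_function_calling_active(self) -> bool:
        # Derived from the model name at call time; there is no stored
        # flag that can drift out of sync with the actual model.
        return any(
            name in self.config.model
            for name in ('claude-3-5-sonnet', 'gpt-4o')
        )


assert LLMSketch(LLMConfigSketch()).is_function_calling_active()
```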
+ 21 - 6
openhands/llm/llm.py

@@ -53,6 +53,14 @@ CACHE_PROMPT_SUPPORTED_MODELS = [
     'claude-3-opus-20240229',
 ]
 
+# function calling supporting models
+FUNCTION_CALLING_SUPPORTED_MODELS = [
+    'claude-3-5-sonnet-20240620',
+    'claude-3-5-sonnet-20241022',
+    'gpt-4o',
+    'gpt-4o-mini',
+]
+
 
 class LLM(RetryMixin, DebugMixin):
     """The LLM class represents a Language Model instance.
@@ -163,11 +171,6 @@ class LLM(RetryMixin, DebugMixin):
                 ):
                     self.config.max_output_tokens = self.model_info['max_tokens']
 
-        self.config.supports_function_calling = (
-            self.model_info is not None
-            and self.model_info.get('supports_function_calling', False)
-        )
-
         self._completion = partial(
             litellm_completion,
             model=self.config.model,
@@ -186,7 +189,7 @@ class LLM(RetryMixin, DebugMixin):
             logger.debug('LLM: model has vision enabled')
         if self.is_caching_prompt_active():
             logger.debug('LLM: caching prompt enabled')
-        if self.config.supports_function_calling:
+        if self.is_function_calling_active():
             logger.debug('LLM: model supports function calling')
 
         completion_unwrapped = self._completion
@@ -327,6 +330,18 @@ class LLM(RetryMixin, DebugMixin):
             )
         )
 
+    def is_function_calling_active(self) -> bool:
+        # Check if model name is in supported list before checking model_info
+        model_name_supported = (
+            self.config.model in FUNCTION_CALLING_SUPPORTED_MODELS
+            or self.config.model.split('/')[-1] in FUNCTION_CALLING_SUPPORTED_MODELS
+            or any(m in self.config.model for m in FUNCTION_CALLING_SUPPORTED_MODELS)
+        )
+        return model_name_supported and (
+            self.model_info is not None
+            and self.model_info.get('supports_function_calling', False)
+        )
+
     def _post_completion(self, response: ModelResponse) -> None:
         """Post-process the completion response.
 

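The new `is_function_calling_active()` requires two things to agree: the model name must match the allow-list, and litellm's `model_info` must report `supports_function_calling`. The name matching can be exercised standalone; this sketch duplicates it outside the class (the test model names are illustrative):

```python
FUNCTION_CALLING_SUPPORTED_MODELS = [
    'claude-3-5-sonnet-20240620',
    'claude-3-5-sonnet-20241022',
    'gpt-4o',
    'gpt-4o-mini',
]


def model_name_supported(model: str) -> bool:
    # Same three checks as is_function_calling_active(): exact match,
    # last path segment (strips provider prefixes such as 'openai/' or
    # 'litellm_proxy/'), and substring containment (catches versioned
    # names such as 'gpt-4o-2024-08-06').
    return (
        model in FUNCTION_CALLING_SUPPORTED_MODELS
        or model.split('/')[-1] in FUNCTION_CALLING_SUPPORTED_MODELS
        or any(m in model for m in FUNCTION_CALLING_SUPPORTED_MODELS)
    )


assert model_name_supported('openai/gpt-4o')
assert model_name_supported('gpt-4o-2024-08-06')
assert not model_name_supported('llama-3.1-70b')
```

Note that the substring check subsumes the first two; they are kept here to match the diff verbatim. In the real method this name check is additionally ANDed with `model_info.get('supports_function_calling', False)`.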
+ 3 - 0
tests/unit/test_prompt_caching.py

@@ -137,6 +137,9 @@ def test_get_messages_prompt_caching(codeact_agent, mock_event_stream):
 
 
 def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
+    if codeact_agent.config.function_calling:
+        pytest.skip('Skipping this test for function calling')
+
     # Add a mix of actions and observations
     message_action_1 = MessageAction(
         "Let's list the contents of the current directory."