Просмотр исходного кода

Fix non-function calls messages (#5026)

Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
Engel Nyst 1 год назад
Родитель
Сommit
d08886f30e
3 измененных файлов с 49 добавлено и 32 удалено
  1. 31 17
      openhands/core/message.py
  2. 7 5
      openhands/llm/fn_call_converter.py
  3. 11 10
      openhands/llm/llm.py

+ 31 - 17
openhands/core/message.py

@@ -56,6 +56,7 @@ class Message(BaseModel):
     cache_enabled: bool = False
     vision_enabled: bool = False
     # function calling
+    function_calling_enabled: bool = False
     # - tool calls (from LLM)
     tool_calls: list[ChatCompletionMessageToolCall] | None = None
     # - tool execution result (to LLM)
@@ -72,22 +73,22 @@ class Message(BaseModel):
         # - into a single string: for providers that don't support list of content items (e.g. no vision, no tool calls)
         # - into a list of content items: the new APIs of providers with vision/prompt caching/tool calls
         # NOTE: remove this when litellm or providers support the new API
-        if (
-            self.cache_enabled
-            or self.vision_enabled
-            or self.tool_call_id is not None
-            or self.tool_calls is not None
-        ):
+        if self.cache_enabled or self.vision_enabled or self.function_calling_enabled:
             return self._list_serializer()
+        # some providers, like HF and Groq/llama, don't support a list here, but a single string
         return self._string_serializer()
 
-    def _string_serializer(self):
+    def _string_serializer(self) -> dict:
+        # convert content to a single string
         content = '\n'.join(
             item.text for item in self.content if isinstance(item, TextContent)
         )
-        return {'content': content, 'role': self.role}
+        message_dict: dict = {'content': content, 'role': self.role}
+
+        # add tool call keys if we have a tool call or response
+        return self._add_tool_call_keys(message_dict)
 
-    def _list_serializer(self):
+    def _list_serializer(self) -> dict:
         content: list[dict] = []
         role_tool_with_prompt_caching = False
         for item in self.content:
@@ -102,24 +103,37 @@ class Message(BaseModel):
             elif isinstance(item, ImageContent) and self.vision_enabled:
                 content.extend(d)
 
-        ret: dict = {'content': content, 'role': self.role}
+        message_dict: dict = {'content': content, 'role': self.role}
+
         # pop content if it's empty
         if not content or (
             len(content) == 1
             and content[0]['type'] == 'text'
             and content[0]['text'] == ''
         ):
-            ret.pop('content')
+            message_dict.pop('content')
 
         if role_tool_with_prompt_caching:
-            ret['cache_control'] = {'type': 'ephemeral'}
+            message_dict['cache_control'] = {'type': 'ephemeral'}
+
+        # add tool call keys if we have a tool call or response
+        return self._add_tool_call_keys(message_dict)
 
+    def _add_tool_call_keys(self, message_dict: dict) -> dict:
+        """Add tool call keys if we have a tool call or response.
+
+        NOTE: this is necessary for both native and non-native tool calling"""
+
+        # an assistant message calling a tool
+        if self.tool_calls is not None:
+            message_dict['tool_calls'] = self.tool_calls
+
+        # an observation message with tool response
         if self.tool_call_id is not None:
             assert (
                 self.name is not None
             ), 'name is required when tool_call_id is not None'
-            ret['tool_call_id'] = self.tool_call_id
-            ret['name'] = self.name
-        if self.tool_calls:
-            ret['tool_calls'] = self.tool_calls
-        return ret
+            message_dict['tool_call_id'] = self.tool_call_id
+            message_dict['name'] = self.name
+
+        return message_dict

+ 7 - 5
openhands/llm/fn_call_converter.py

@@ -320,9 +320,8 @@ def convert_fncall_messages_to_non_fncall_messages(
     converted_messages = []
     first_user_message_encountered = False
     for message in messages:
-        role, content = message['role'], message['content']
-        if content is None:
-            content = ''
+        role = message['role']
+        content = message.get('content', '')
 
         # 1. SYSTEM MESSAGES
         # append system prompt suffix to content
@@ -339,6 +338,7 @@ def convert_fncall_messages_to_non_fncall_messages(
                     f'Unexpected content type {type(content)}. Expected str or list. Content: {content}'
                 )
             converted_messages.append({'role': 'system', 'content': content})
+
         # 2. USER MESSAGES (no change)
         elif role == 'user':
             # Add in-context learning example for the first user message
@@ -447,10 +447,12 @@ def convert_fncall_messages_to_non_fncall_messages(
                         f'Unexpected content type {type(content)}. Expected str or list. Content: {content}'
                     )
             converted_messages.append({'role': 'assistant', 'content': content})
+
         # 4. TOOL MESSAGES (tool outputs)
         elif role == 'tool':
-            # Convert tool result as assistant message
-            prefix = f'EXECUTION RESULT of [{message["name"]}]:\n'
+            # Convert tool result as user message
+            tool_name = message.get('name', 'function')
+            prefix = f'EXECUTION RESULT of [{tool_name}]:\n'
             # and omit "tool_call_id" AND "name"
             if isinstance(content, str):
                 content = prefix + content

+ 11 - 10
openhands/llm/llm.py

@@ -122,6 +122,9 @@ class LLM(RetryMixin, DebugMixin):
             drop_params=self.config.drop_params,
         )
 
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore')
+            self.init_model_info()
         if self.vision_is_active():
             logger.debug('LLM: model has vision enabled')
         if self.is_caching_prompt_active():
@@ -143,16 +146,6 @@ class LLM(RetryMixin, DebugMixin):
             drop_params=self.config.drop_params,
         )
 
-        with warnings.catch_warnings():
-            warnings.simplefilter('ignore')
-            self.init_model_info()
-        if self.vision_is_active():
-            logger.debug('LLM: model has vision enabled')
-        if self.is_caching_prompt_active():
-            logger.debug('LLM: caching prompt enabled')
-        if self.is_function_calling_active():
-            logger.debug('LLM: model supports function calling')
-
         self._completion_unwrapped = self._completion
 
         @self.retry_decorator(
@@ -342,6 +335,13 @@ class LLM(RetryMixin, DebugMixin):
                 pass
         logger.debug(f'Model info: {self.model_info}')
 
+        if self.config.model.startswith('huggingface'):
+            # HF doesn't support the OpenAI default value for top_p (1)
+            logger.debug(
+                f'Setting top_p to 0.9 for Hugging Face model: {self.config.model}'
+            )
+            self.config.top_p = 0.9 if self.config.top_p == 1 else self.config.top_p
+
         # Set the max tokens in an LM-specific way if not set
         if self.config.max_input_tokens is None:
             if (
@@ -566,6 +566,7 @@ class LLM(RetryMixin, DebugMixin):
         for message in messages:
             message.cache_enabled = self.is_caching_prompt_active()
             message.vision_enabled = self.vision_is_active()
+            message.function_calling_enabled = self.is_function_calling_active()
 
         # let pydantic handle the serialization
         return [message.model_dump() for message in messages]