浏览代码

Add Handling of Cache Prompt When Formatting Messages (#3773)

* Add Handling of Cache Prompt When Formatting Messages

* Fix Value for Cache Control

* Fix Value for Cache Control

* Update openhands/core/message.py

Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>

* Fix lint error

* Serialize Messages if Propt Caching Is Enabled

* Remove formatting message change

---------

Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
Co-authored-by: tobitege <10787084+tobitege@users.noreply.github.com>
Cole Murray 1 年之前
父节点
当前提交
97a03faf33
共有 3 个文件被更改,包括 11 次插入4 次删除
  1. 5 2
      openhands/core/message.py
  2. 3 1
      openhands/llm/llm.py
  3. 3 1
      tests/integration/conftest.py

+ 5 - 2
openhands/core/message.py

@@ -72,12 +72,14 @@ class Message(BaseModel):
 
 
 def format_messages(
-    messages: Union[Message, list[Message]], with_images: bool
+    messages: Union[Message, list[Message]],
+    with_images: bool,
+    with_prompt_caching: bool,
 ) -> list[dict]:
     if not isinstance(messages, list):
         messages = [messages]
 
-    if with_images:
+    if with_images or with_prompt_caching:
         return [message.model_dump() for message in messages]
 
     converted_messages = []
@@ -113,4 +115,5 @@ def format_messages(
                     'content': content_str,
                 }
             )
+
     return converted_messages

+ 3 - 1
openhands/llm/llm.py

@@ -597,4 +597,6 @@ class LLM:
     def format_messages_for_llm(
         self, messages: Union[Message, list[Message]]
     ) -> list[dict]:
-        return format_messages(messages, self.vision_is_active())
+        return format_messages(
+            messages, self.vision_is_active(), self.is_caching_prompt_active()
+        )

+ 3 - 1
tests/integration/conftest.py

@@ -185,7 +185,9 @@ def mock_user_response(*args, test_name, **kwargs):
 def mock_completion(*args, test_name, **kwargs):
     global cur_id
     messages = kwargs['messages']
-    plain_messages = format_messages(messages, with_images=False)
+    plain_messages = format_messages(
+        messages, with_images=False, with_prompt_caching=False
+    )
     message_str = message_separator.join(msg['content'] for msg in plain_messages)
 
     # this assumes all response_(*).log filenames are in numerical order, starting from one