
(feat) LLM class: add safety_settings for Gemini; improve max_output_tokens defaulting (#3925)

tobitege · 1 year ago · commit b4408b41c9
2 changed files with 44 additions and 14 deletions
  1. openhands/llm/llm.py: +40 -10
  2. tests/unit/test_llm.py: +4 -4

openhands/llm/llm.py (+40 -10)

@@ -101,23 +101,51 @@ class LLM:
             ):
                 self.config.max_input_tokens = self.model_info['max_input_tokens']
             else:
-                # Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
+                # Safe fallback for any potentially viable model
                 self.config.max_input_tokens = 4096
 
         if self.config.max_output_tokens is None:
-            if (
-                self.model_info is not None
-                and 'max_output_tokens' in self.model_info
-                and isinstance(self.model_info['max_output_tokens'], int)
-            ):
-                self.config.max_output_tokens = self.model_info['max_output_tokens']
-            else:
-                # Max output tokens for gpt3.5, so this is a safe fallback for any potentially viable model
-                self.config.max_output_tokens = 1024
+            # Safe default for any potentially viable model
+            self.config.max_output_tokens = 4096
+            if self.model_info is not None:
+                # max_output_tokens takes precedence over max_tokens, if either exists.
+                # litellm model info may include both, one, or neither of these two fields!
+                if 'max_output_tokens' in self.model_info and isinstance(
+                    self.model_info['max_output_tokens'], int
+                ):
+                    self.config.max_output_tokens = self.model_info['max_output_tokens']
+                elif 'max_tokens' in self.model_info and isinstance(
+                    self.model_info['max_tokens'], int
+                ):
+                    self.config.max_output_tokens = self.model_info['max_tokens']
 
         if self.config.drop_params:
             litellm.drop_params = self.config.drop_params
 
+        # This only seems to work with Google as the provider, not with OpenRouter!
+        gemini_safety_settings = (
+            [
+                {
+                    'category': 'HARM_CATEGORY_HARASSMENT',
+                    'threshold': 'BLOCK_NONE',
+                },
+                {
+                    'category': 'HARM_CATEGORY_HATE_SPEECH',
+                    'threshold': 'BLOCK_NONE',
+                },
+                {
+                    'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
+                    'threshold': 'BLOCK_NONE',
+                },
+                {
+                    'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
+                    'threshold': 'BLOCK_NONE',
+                },
+            ]
+            if self.config.model.lower().startswith('gemini')
+            else None
+        )
+
         self._completion = partial(
             litellm_completion,
             model=self.config.model,
@@ -129,6 +157,7 @@ class LLM:
             timeout=self.config.timeout,
             temperature=self.config.temperature,
             top_p=self.config.top_p,
+            safety_settings=gemini_safety_settings,
         )
 
         if self.vision_is_active():
@@ -235,6 +264,7 @@ class LLM:
             temperature=self.config.temperature,
             top_p=self.config.top_p,
             drop_params=True,
+            safety_settings=gemini_safety_settings,
         )
 
         async_completion_unwrapped = self._async_completion
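Taken together, the new logic resolves max_output_tokens in the order model_info['max_output_tokens'], then model_info['max_tokens'], then a default of 4096, and Gemini models additionally get permissive safety_settings forwarded to litellm. A minimal sketch of that fallback order in isolation (the helper name is illustrative and not part of the commit; model_info is a plain dict standing in for litellm's model metadata):

    # Sketch only: mirrors the fallback order introduced above, outside the LLM class.
    from typing import Optional

    def resolve_max_output_tokens(model_info: Optional[dict]) -> int:
        # Safe default for any potentially viable model
        max_output_tokens = 4096
        if model_info is not None:
            # max_output_tokens wins over max_tokens; litellm metadata may
            # contain both, one, or neither of these fields.
            if isinstance(model_info.get('max_output_tokens'), int):
                max_output_tokens = model_info['max_output_tokens']
            elif isinstance(model_info.get('max_tokens'), int):
                max_output_tokens = model_info['max_tokens']
        return max_output_tokens

    assert resolve_max_output_tokens(None) == 4096
    assert resolve_max_output_tokens({'max_tokens': 8192}) == 8192
    assert resolve_max_output_tokens({'max_output_tokens': 2048, 'max_tokens': 8192}) == 2048

Note that the Gemini check is self.config.model.lower().startswith('gemini'), so it matches names like gemini-1.5-pro or gemini/gemini-1.5-flash but not openrouter/google/gemini-pro, which is consistent with the in-code comment that the safety settings only take effect with Google as the provider.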

tests/unit/test_llm.py (+4 -4)

@@ -9,12 +9,12 @@ from openhands.llm.llm import LLM
 
 @pytest.fixture
 def default_config():
-    return LLMConfig(model='gpt-3.5-turbo', api_key='test_key')
+    return LLMConfig(model='gpt-4o', api_key='test_key')
 
 
 def test_llm_init_with_default_config(default_config):
     llm = LLM(default_config)
-    assert llm.config.model == 'gpt-3.5-turbo'
+    assert llm.config.model == 'gpt-4o'
     assert llm.config.api_key == 'test_key'
     assert isinstance(llm.metrics, Metrics)
 
@@ -35,7 +35,7 @@ def test_llm_init_without_model_info(mock_get_model_info, default_config):
     mock_get_model_info.side_effect = Exception('Model info not available')
     llm = LLM(default_config)
     assert llm.config.max_input_tokens == 4096
-    assert llm.config.max_output_tokens == 1024
+    assert llm.config.max_output_tokens == 4096
 
 
 def test_llm_init_with_custom_config():
@@ -57,7 +57,7 @@ def test_llm_init_with_custom_config():
 
 
 def test_llm_init_with_metrics():
-    config = LLMConfig(model='gpt-3.5-turbo', api_key='test_key')
+    config = LLMConfig(model='gpt-4o', api_key='test_key')
     metrics = Metrics()
     llm = LLM(config, metrics=metrics)
     assert llm.metrics is metrics
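The updated assertions cover the 4096 default when model info is unavailable, but not the new max_tokens fallback branch. A possible follow-up test, mirroring test_llm_init_without_model_info above; the patch target and the LLMConfig import path are assumptions inferred from the existing tests, not part of this commit:

    from unittest.mock import patch

    from openhands.core.config import LLMConfig  # import path assumed
    from openhands.llm.llm import LLM

    @patch('openhands.llm.llm.litellm.get_model_info')  # patch target assumed
    def test_llm_init_with_max_tokens_fallback(mock_get_model_info):
        # Metadata that only carries 'max_tokens': the new code should fall back to it.
        mock_get_model_info.return_value = {
            'max_input_tokens': 8192,
            'max_tokens': 2048,
        }
        llm = LLM(LLMConfig(model='some-model', api_key='test_key'))
        assert llm.config.max_input_tokens == 8192
        assert llm.config.max_output_tokens == 2048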