Просмотр исходного кода

Transitioned to use LiteLLM Router to support retries and backoffs (#501)

Patrick Nercessian 1 год назад
Родитель
Коммит
64281c4cc4
1 изменённый файл: 33 добавления и 5 удалений
  1. 33 5
      opendevin/llm/llm.py

+ 33 - 5
opendevin/llm/llm.py

@@ -1,26 +1,54 @@
 import os
 import uuid
 
-from litellm import completion as litellm_completion
+from litellm.router import Router
 from functools import partial
 
 from opendevin import config
 
-DEFAULT_MODEL = config.get_or_default("LLM_MODEL", "gpt-4-0125-preview")
+DEFAULT_MODEL_NAME = config.get_or_default("LLM_MODEL", "gpt-4-0125-preview")
 DEFAULT_API_KEY = config.get_or_none("LLM_API_KEY")
 DEFAULT_BASE_URL = config.get_or_none("LLM_BASE_URL")
+DEFAULT_LLM_NUM_RETRIES = config.get_or_default("LLM_NUM_RETRIES", 6)
+DEFAULT_LLM_COOLDOWN_TIME = config.get_or_default("LLM_COOLDOWN_TIME", 1)
 PROMPT_DEBUG_DIR = config.get_or_default("PROMPT_DEBUG_DIR", "")
 
 class LLM:
-    def __init__(self, model=DEFAULT_MODEL, api_key=DEFAULT_API_KEY, base_url=DEFAULT_BASE_URL, debug_dir=PROMPT_DEBUG_DIR):
-        self.model = model if model else DEFAULT_MODEL
+    def __init__(self,
+            model=DEFAULT_MODEL_NAME,
+            api_key=DEFAULT_API_KEY,
+            base_url=DEFAULT_BASE_URL,
+            num_retries=DEFAULT_LLM_NUM_RETRIES,
+            cooldown_time=DEFAULT_LLM_COOLDOWN_TIME,
+            debug_dir=PROMPT_DEBUG_DIR
+    ):
+        self.model_name = model if model else DEFAULT_MODEL_NAME
         self.api_key = api_key if api_key else DEFAULT_API_KEY
         self.base_url = base_url if base_url else DEFAULT_BASE_URL
+        self.num_retries = num_retries if num_retries else DEFAULT_LLM_NUM_RETRIES
+        self.cooldown_time = cooldown_time if cooldown_time else DEFAULT_LLM_COOLDOWN_TIME
         self._debug_dir = debug_dir if debug_dir else PROMPT_DEBUG_DIR
         self._debug_idx = 0
         self._debug_id = uuid.uuid4().hex
 
-        self._completion = partial(litellm_completion, model=self.model, api_key=self.api_key, base_url=self.base_url)
+        # We use litellm's Router in order to support retries (especially rate limit backoff retries). 
+        # Typically you would use a whole model list, but it's unnecessary with our implementation's structure
+        self._router = Router(
+            model_list=[{
+                "model_name": self.model_name,
+                "litellm_params": {
+                    "model": self.model_name,
+                    "api_key": self.api_key,
+                    "api_base": self.base_url
+                }
+            }],
+            num_retries=self.num_retries,
+            allowed_fails=self.num_retries, # We allow all retries to fail, so they can retry instead of going into "cooldown"
+            cooldown_time=self.cooldown_time,
+            # set_verbose=True,
+            # debug_level="DEBUG"
+        )
+        self._completion = partial(self._router.completion, model=self.model_name)
 
         if self._debug_dir:
             print(f"Logging prompts to {self._debug_dir}/{self._debug_id}")