llm.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import os
  2. import uuid
  3. from litellm import completion as litellm_completion
  4. from functools import partial
  5. DEFAULT_MODEL = os.getenv("LLM_MODEL", "gpt-4-0125-preview")
  6. DEFAULT_API_KEY = os.getenv("LLM_API_KEY")
  7. DEFAULT_BASE_URL = os.getenv("LLM_BASE_URL")
  8. PROMPT_DEBUG_DIR = os.getenv("PROMPT_DEBUG_DIR", "")
  9. class LLM:
  10. def __init__(self, model=DEFAULT_MODEL, api_key=DEFAULT_API_KEY, base_url=DEFAULT_BASE_URL, debug_dir=PROMPT_DEBUG_DIR):
  11. self.model = model if model else DEFAULT_MODEL
  12. self.api_key = api_key if api_key else DEFAULT_API_KEY
  13. self.base_url = base_url if base_url else DEFAULT_BASE_URL
  14. self._debug_dir = debug_dir if debug_dir else PROMPT_DEBUG_DIR
  15. self._debug_idx = 0
  16. self._debug_id = uuid.uuid4().hex
  17. self._completion = partial(litellm_completion, model=self.model, api_key=self.api_key, base_url=self.base_url)
  18. if self._debug_dir:
  19. print(f"Logging prompts to {self._debug_dir}/{self._debug_id}")
  20. completion_unwrapped = self._completion
  21. def wrapper(*args, **kwargs):
  22. if "messages" in kwargs:
  23. messages = kwargs["messages"]
  24. else:
  25. messages = args[1]
  26. resp = completion_unwrapped(*args, **kwargs)
  27. message_back = resp['choices'][0]['message']['content']
  28. self.write_debug(messages, message_back)
  29. return resp
  30. self._completion = wrapper # type: ignore
  31. @property
  32. def completion(self):
  33. """
  34. Decorator for the litellm completion function.
  35. """
  36. return self._completion
  37. def write_debug(self, messages, response):
  38. if not self._debug_dir:
  39. return
  40. dir = self._debug_dir + "/" + self._debug_id + "/" + str(self._debug_idx)
  41. os.makedirs(dir, exist_ok=True)
  42. prompt_out = ""
  43. for message in messages:
  44. prompt_out += "<" + message["role"] + ">\n"
  45. prompt_out += message["content"] + "\n\n"
  46. with open(f"{dir}/prompt.md", "w") as f:
  47. f.write(prompt_out)
  48. with open(f"{dir}/response.md", "w") as f:
  49. f.write(response)
  50. self._debug_idx += 1