Преглед изворног кода

[Feat] Custom MicroAgents. (#4983)

Co-authored-by: diwu-sf <di.wu@shadowfaxdata.com>
Raj Maheshwari пре 1 година
родитељ
комит
2b06e4e5d0

+ 8 - 1
frontend/src/context/ws-client-provider.tsx

@@ -38,6 +38,7 @@ interface WsClientProviderProps {
   enabled: boolean;
   token: string | null;
   ghToken: string | null;
+  selectedRepository: string | null;
   settings: Settings | null;
 }
 
@@ -45,12 +46,14 @@ export function WsClientProvider({
   enabled,
   token,
   ghToken,
+  selectedRepository,
   settings,
   children,
 }: React.PropsWithChildren<WsClientProviderProps>) {
   const sioRef = React.useRef<Socket | null>(null);
   const tokenRef = React.useRef<string | null>(token);
   const ghTokenRef = React.useRef<string | null>(ghToken);
+  const selectedRepositoryRef = React.useRef<string | null>(selectedRepository);
   const disconnectRef = React.useRef<ReturnType<typeof setTimeout> | null>(
     null,
   );
@@ -81,6 +84,9 @@ export function WsClientProvider({
     if (ghToken) {
       initEvent.github_token = ghToken;
     }
+    if (selectedRepository) {
+      initEvent.selected_repository = selectedRepository;
+    }
     const lastEvent = lastEventRef.current;
     if (lastEvent) {
       initEvent.latest_event_id = lastEvent.id;
@@ -158,6 +164,7 @@ export function WsClientProvider({
     sioRef.current = sio;
     tokenRef.current = token;
     ghTokenRef.current = ghToken;
+    selectedRepositoryRef.current = selectedRepository;
 
     return () => {
       sio.off("connect", handleConnect);
@@ -166,7 +173,7 @@ export function WsClientProvider({
       sio.off("connect_failed", handleError);
       sio.off("disconnect", handleDisconnect);
     };
-  }, [enabled, token, ghToken]);
+  }, [enabled, token, ghToken, selectedRepository]);
 
   // Strict mode mounts and unmounts each component twice, so we have to wait in the destructor
   // before actually disconnecting the socket and cancel the operation if the component gets remounted.

+ 1 - 8
frontend/src/routes/_oh.app/hooks/use-ws-status-change.ts

@@ -6,7 +6,6 @@ import {
   WsClientProviderStatus,
 } from "#/context/ws-client-provider";
 import { createChatMessage } from "#/services/chat-service";
-import { getCloneRepoCommand } from "#/services/terminal-service";
 import { setCurrentAgentState } from "#/state/agent-slice";
 import { addUserMessage } from "#/state/chat-slice";
 import {
@@ -37,11 +36,6 @@ export const useWSStatusChange = () => {
     send(createChatMessage(query, base64Files, timestamp));
   };
 
-  const dispatchCloneRepoCommand = (ghToken: string, repository: string) => {
-    send(getCloneRepoCommand(ghToken, repository));
-    dispatch(clearSelectedRepository());
-  };
-
   const dispatchInitialQuery = (query: string, additionalInfo: string) => {
     if (additionalInfo) {
       sendInitialQuery(`${query}\n\n[${additionalInfo}]`, files);
@@ -57,8 +51,7 @@ export const useWSStatusChange = () => {
     let additionalInfo = "";
 
     if (gitHubToken && selectedRepository) {
-      dispatchCloneRepoCommand(gitHubToken, selectedRepository);
-      additionalInfo = `Repository ${selectedRepository} has been cloned to /workspace. Please check the /workspace for files.`;
+      dispatch(clearSelectedRepository());
     } else if (importedProjectZip) {
       // if there's an uploaded project zip, add it to the chat
       additionalInfo =

+ 1 - 0
frontend/src/routes/_oh.app/route.tsx

@@ -64,6 +64,7 @@ function App() {
       enabled
       token={token}
       ghToken={gitHubToken}
+      selectedRepository={selectedRepository}
       settings={settings}
     >
       <EventHandler>

+ 0 - 8
frontend/src/services/terminal-service.ts

@@ -10,11 +10,3 @@ export function getGitHubTokenCommand(gitHubToken: string) {
   const event = getTerminalCommand(command, true);
   return event;
 }
-
-export function getCloneRepoCommand(gitHubToken: string, repository: string) {
-  const url = `https://${gitHubToken}@github.com/${repository}.git`;
-  const dirName = repository.split("/")[1];
-  const command = `git clone ${url} ${dirName} ; cd ${dirName} ; git checkout -b openhands-workspace`;
-  const event = getTerminalCommand(command, true);
-  return event;
-}

+ 3 - 0
openhands/agenthub/codeact_agent/codeact_agent.py

@@ -398,6 +398,9 @@ class CodeActAgent(Agent):
             - Messages from the same role are combined to prevent consecutive same-role messages
             - For Anthropic models, specific messages are cached according to their documentation
         """
+        if not self.prompt_manager:
+            raise Exception('Prompt Manager not instantiated.')
+
         messages: list[Message] = [
             Message(
                 role='system',

+ 2 - 0
openhands/controller/agent.py

@@ -11,6 +11,7 @@ from openhands.core.exceptions import (
 )
 from openhands.llm.llm import LLM
 from openhands.runtime.plugins import PluginRequirement
+from openhands.utils.prompt import PromptManager
 
 
 class Agent(ABC):
@@ -33,6 +34,7 @@ class Agent(ABC):
         self.llm = llm
         self.config = config
         self._complete = False
+        self.prompt_manager: PromptManager | None = None
 
     @property
     def complete(self) -> bool:

+ 41 - 0
openhands/runtime/base.py

@@ -213,6 +213,47 @@ class Runtime(FileEditRuntimeMixin):
             source = event.source if event.source else EventSource.AGENT
             self.event_stream.add_event(observation, source)  # type: ignore[arg-type]
 
+    def clone_repo(self, github_token: str | None, selected_repository: str | None):
+        """Clone the selected GitHub repository inside the runtime.
+
+        Does nothing unless both a token and a repository are supplied. The clone,
+        the `cd` into the checkout, and the creation of the `openhands-workspace`
+        branch are sent as a single shell command through `run_action`.
+
+        Args:
+            github_token: Token embedded in the HTTPS clone URL.
+            selected_repository: Repository identifier; the segment after the
+                first `/` is used as the checkout directory name.
+        """
+        if not github_token or not selected_repository:
+            return
+        url = f'https://{github_token}@github.com/{selected_repository}.git'
+        dir_name = selected_repository.split('/')[1]
+        action = CmdRunAction(
+            command=f'git clone {url} {dir_name} ; cd {dir_name} ; git checkout -b openhands-workspace'
+        )
+        # Must be an f-string so the repository name is interpolated. This line
+        # logs the repository name only, not the token-bearing URL.
+        self.log('info', f'Cloning repo: {selected_repository}')
+        self.run_action(action)
+
+    def get_custom_microagents(self, selected_repository: str | None) -> list[str]:
+        """Collect the raw text of user-provided microagent definitions.
+
+        The result contains, in order:
+          1. the contents of `.openhands_instructions` (if readable), prefixed
+             with a frontmatter header naming it `openhands_instructions`;
+          2. the contents of every readable file under `.openhands/microagents`,
+             looked up inside the cloned repository directory when
+             `selected_repository` is given.
+
+        Args:
+            selected_repository: Repository identifier; the segment after the
+                first `/` is taken as the checkout directory. `None` means the
+                microagents directory is resolved without a repository prefix.
+
+        Returns:
+            A list of file contents, one string per microagent.
+        """
+        custom_microagents_content = []
+
+        # Resolve the microagents directory once, so that listing and reading
+        # always target the same location.
+        custom_microagents_dir = Path('.openhands') / 'microagents'
+        if selected_repository:
+            custom_microagents_dir = (
+                Path(selected_repository.split('/')[1]) / custom_microagents_dir
+            )
+
+        oh_instructions_header = '---\nname: openhands_instructions\nagent: CodeActAgent\ntriggers:\n- ""\n---\n'
+        # NOTE(review): this path is not prefixed with the repository directory;
+        # confirm whether instructions inside a cloned repo should be picked up.
+        obs = self.read(FileReadAction(path='.openhands_instructions'))
+        if isinstance(obs, ErrorObservation):
+            self.log('error', 'Failed to read openhands_instructions')
+        else:
+            openhands_instructions = oh_instructions_header + obs.content
+            self.log('info', f'openhands_instructions: {openhands_instructions}')
+            custom_microagents_content.append(openhands_instructions)
+
+        files = self.list_files(str(custom_microagents_dir))
+
+        self.log('info', f'Found {len(files)} custom microagents.')
+
+        for fname in files:
+            file_obs = self.read(
+                FileReadAction(path=str(custom_microagents_dir / fname))
+            )
+            # Skip unreadable entries rather than passing an error message on
+            # as if it were microagent content.
+            if isinstance(file_obs, ErrorObservation):
+                self.log('error', f'Failed to read custom microagent: {fname}')
+                continue
+            custom_microagents_content.append(file_obs.content)
+
+        return custom_microagents_content
+
     def run_action(self, action: Action) -> Observation:
         """Run an action and return the resulting observation.
         If the action is not runnable in any runtime, a NullObservation is returned.

+ 2 - 0
openhands/server/listen_socket.py

@@ -32,6 +32,8 @@ async def oh_action(connection_id: str, data: dict):
         latest_event_id = int(data.pop('latest_event_id', -1))
         kwargs = {k.lower(): v for k, v in (data.get('args') or {}).items()}
         session_init_data = SessionInitData(**kwargs)
+        session_init_data.github_token = github_token
+        session_init_data.selected_repository = data.get('selected_repository', None)
         await init_connection(
             connection_id, token, github_token, session_init_data, latest_event_id
         )

+ 18 - 1
openhands/server/session/agent_session.py

@@ -7,7 +7,7 @@ from openhands.controller.state.state import State
 from openhands.core.config import AgentConfig, AppConfig, LLMConfig
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.schema.agent import AgentState
-from openhands.events.action.agent import ChangeAgentStateAction
+from openhands.events.action import ChangeAgentStateAction
 from openhands.events.event import EventSource
 from openhands.events.stream import EventStream
 from openhands.runtime import get_runtime_cls
@@ -60,6 +60,8 @@ class AgentSession:
         max_budget_per_task: float | None = None,
         agent_to_llm_config: dict[str, LLMConfig] | None = None,
         agent_configs: dict[str, AgentConfig] | None = None,
+        github_token: str | None = None,
+        selected_repository: str | None = None,
     ):
         """Starts the Agent session
         Parameters:
@@ -86,6 +88,8 @@ class AgentSession:
             max_budget_per_task,
             agent_to_llm_config,
             agent_configs,
+            github_token,
+            selected_repository,
         )
 
     def _start_thread(self, *args):
@@ -104,13 +108,18 @@ class AgentSession:
         max_budget_per_task: float | None = None,
         agent_to_llm_config: dict[str, LLMConfig] | None = None,
         agent_configs: dict[str, AgentConfig] | None = None,
+        github_token: str | None = None,
+        selected_repository: str | None = None,
     ):
         self._create_security_analyzer(config.security.security_analyzer)
         await self._create_runtime(
             runtime_name=runtime_name,
             config=config,
             agent=agent,
+            github_token=github_token,
+            selected_repository=selected_repository,
         )
+
         self._create_controller(
             agent,
             config.security.confirmation_mode,
@@ -165,6 +174,8 @@ class AgentSession:
         runtime_name: str,
         config: AppConfig,
         agent: Agent,
+        github_token: str | None = None,
+        selected_repository: str | None = None,
     ):
         """Creates a runtime instance
 
@@ -199,6 +210,12 @@ class AgentSession:
             return
 
         if self.runtime is not None:
+            self.runtime.clone_repo(github_token, selected_repository)
+            if agent.prompt_manager:
+                agent.prompt_manager.load_microagent_files(
+                    self.runtime.get_custom_microagents(selected_repository)
+                )
+
             logger.debug(
                 f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
             )

+ 2 - 1
openhands/server/session/session.py

@@ -72,7 +72,6 @@ class Session:
         self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer
         max_iterations = session_init_data.max_iterations or self.config.max_iterations
         # override default LLM config
-        
 
         default_llm_config = self.config.get_llm_config()
         default_llm_config.model = session_init_data.llm_model or default_llm_config.model
@@ -94,6 +93,8 @@ class Session:
                 max_budget_per_task=self.config.max_budget_per_task,
                 agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
                 agent_configs=self.config.get_agent_configs(),
+                github_token=session_init_data.github_token,
+                selected_repository=session_init_data.selected_repository,
             )
         except Exception as e:
             logger.exception(f'Error creating controller: {e}')

+ 2 - 0
openhands/server/session/session_init_data.py

@@ -16,3 +16,5 @@ class SessionInitData:
     llm_model: str | None = None
     llm_api_key: str | None = None
     llm_base_url: str | None = None
+    github_token: str | None = None
+    selected_repository: str | None = None

+ 14 - 8
openhands/utils/microagent.py

@@ -11,14 +11,20 @@ class MicroAgentMetadata(pydantic.BaseModel):
 
 
 class MicroAgent:
-    def __init__(self, path: str):
-        self.path = path
-        if not os.path.exists(path):
-            raise FileNotFoundError(f'Micro agent file {path} is not found')
-        with open(path, 'r') as file:
-            self._loaded = frontmatter.load(file)
-            self._content = self._loaded.content
-            self._metadata = MicroAgentMetadata(**self._loaded.metadata)
+    def __init__(self, path: str | None = None, content: str | None = None):
+        """Build a micro agent from a file on disk or from in-memory text.
+
+        Exactly one of `path` and `content` must be provided (and non-empty).
+
+        Args:
+            path: Location of a frontmatter file to load.
+            content: Raw frontmatter text to parse directly.
+
+        Raises:
+            FileNotFoundError: If `path` is given but does not exist.
+            Exception: If neither or both of `path` and `content` are given.
+        """
+        # Assign unconditionally so `path` exists (as None) on instances built
+        # from content as well.
+        self.path = path
+        if path and not content:
+            if not os.path.exists(path):
+                raise FileNotFoundError(f'Micro agent file {path} is not found')
+            with open(path, 'r') as file:
+                self._loaded = frontmatter.load(file)
+                self._content = self._loaded.content
+                self._metadata = MicroAgentMetadata(**self._loaded.metadata)
+        elif content and not path:
+            # frontmatter.parse returns (metadata dict, body text).
+            raw_metadata, self._content = frontmatter.parse(content)
+            self._metadata = MicroAgentMetadata(**raw_metadata)
+        else:
+            raise Exception('You must pass either path or file content, but not both.')
 
     def get_trigger(self, message: str) -> str | None:
         message = message.lower()

+ 6 - 1
openhands/utils/prompt.py

@@ -42,13 +42,18 @@ class PromptManager:
                 if f.endswith('.md')
             ]
         for microagent_file in microagent_files:
-            microagent = MicroAgent(microagent_file)
+            microagent = MicroAgent(path=microagent_file)
             if (
                 disabled_microagents is None
                 or microagent.name not in disabled_microagents
             ):
                 self.microagents[microagent.name] = microagent
 
+    def load_microagent_files(self, microagent_files: list[str]):
+        """Register microagents supplied as raw file contents.
+
+        Each string is parsed into a `MicroAgent` and stored in
+        `self.microagents` under the agent's name, one at a time.
+        """
+        parsed_agents = (MicroAgent(content=raw_text) for raw_text in microagent_files)
+        for agent in parsed_agents:
+            self.microagents[agent.name] = agent
+
     def _load_template(self, template_name: str) -> Template:
         if self.prompt_dir is None:
             raise ValueError('Prompt directory is not set')