Просмотр исходного кода

feat: implement basic planner UI (#1173)

* feat: implement basic planner UI

* update planner UI

* lint

* fix type

* fixes

---------

Co-authored-by: Robert Brennan <accounts@rbren.io>
Alex Bäuerle 1 год назад
Родитель
Сommit
959d91c9d6

+ 76 - 1
frontend/src/components/Planner.tsx

@@ -1,7 +1,82 @@
 import React from "react";
+import {
+  FaCheckCircle,
+  FaQuestionCircle,
+  FaRegCheckCircle,
+  FaRegCircle,
+  FaRegClock,
+  FaRegTimesCircle,
+} from "react-icons/fa";
+import { useSelector } from "react-redux";
+import { Plan, Task, TaskState } from "../services/planService";
+import { RootState } from "../store";
+
+function StatusIcon({ status }: { status: TaskState }): JSX.Element {
+  switch (status) {
+    case TaskState.OPEN_STATE:
+      return <FaRegCircle />;
+    case TaskState.COMPLETED_STATE:
+      return <FaRegCheckCircle className="text-green-200" />;
+    case TaskState.ABANDONED_STATE:
+      return <FaRegTimesCircle className="text-red-200" />;
+    case TaskState.IN_PROGRESS_STATE:
+      return <FaRegClock className="text-yellow-200" />;
+    case TaskState.VERIFIED_STATE:
+      return <FaCheckCircle className="text-green-200" />;
+    default:
+      return <FaQuestionCircle />;
+  }
+}
+
+function TaskCard({ task, level }: { task: Task; level: number }): JSX.Element {
+  return (
+    <div
+      className={`flex flex-col rounded-r bg-neutral-700 p-2 border-neutral-600 ${level < 2 ? "border-l-3" : ""}`}
+    >
+      <div className="flex items-center">
+        <div className="px-2">
+          <StatusIcon status={task.state} />
+        </div>
+        <div>{task.goal}</div>
+      </div>
+      {task.subtasks.length > 0 && (
+        <div className="flex flex-col pt-2 pl-2">
+          {task.subtasks.map((subtask) => (
+            <TaskCard key={subtask.id} task={subtask} level={level + 1} />
+          ))}
+        </div>
+      )}
+    </div>
+  );
+}
+
+interface PlanProps {
+  plan: Plan;
+}
+
+function PlanContainer({ plan }: PlanProps): JSX.Element {
+  if (plan.mainGoal === undefined) {
+    return (
+      <div className="p-2">
+        Nothing is currently planned. Start a task for this to change.
+      </div>
+    );
+  }
+  return (
+    <div className="p-2 overflow-y-auto h-full flex flex-col gap-2">
+      <TaskCard task={plan.task} level={0} />
+    </div>
+  );
+}
 
 function Planner(): JSX.Element {
-  return <div className="h-full w-full bg-neutral-700">Coming soon...</div>;
+  const plan = useSelector((state: RootState) => state.plan.plan);
+
+  return (
+    <div className="h-full w-full bg-neutral-800">
+      <PlanContainer plan={plan} />
+    </div>
+  );
 }
 
 export default Planner;

+ 10 - 2
frontend/src/services/actions.ts

@@ -1,14 +1,16 @@
+import { changeTaskState } from "../state/agentSlice";
 import { setScreenshotSrc, setUrl } from "../state/browserSlice";
 import { appendAssistantMessage } from "../state/chatSlice";
 import { setCode, updatePath } from "../state/codeSlice";
 import { appendInput } from "../state/commandSlice";
+import { setPlan } from "../state/planSlice";
 import { setInitialized } from "../state/taskSlice";
 import store from "../store";
+import ActionType from "../types/ActionType";
 import { ActionMessage } from "../types/Message";
 import { SocketMessage } from "../types/ResponseType";
 import { handleObservationMessage } from "./observations";
-import ActionType from "../types/ActionType";
-import { changeTaskState } from "../state/agentSlice";
+import { getPlan } from "./planService";
 
 const messageActions = {
   [ActionType.INIT]: () => {
@@ -33,6 +35,12 @@ const messageActions = {
   [ActionType.RUN]: (message: ActionMessage) => {
     store.dispatch(appendInput(message.args.command));
   },
+  [ActionType.ADD_TASK]: () => {
+    getPlan().then((fetchedPlan) => store.dispatch(setPlan(fetchedPlan)));
+  },
+  [ActionType.MODIFY_TASK]: () => {
+    getPlan().then((fetchedPlan) => store.dispatch(setPlan(fetchedPlan)));
+  },
   [ActionType.CHANGE_TASK_STATE]: (message: ActionMessage) => {
     store.dispatch(changeTaskState(message.args.task_state));
   },

+ 33 - 0
frontend/src/services/planService.ts

@@ -0,0 +1,33 @@
+export type Plan = {
+  mainGoal: string | undefined;
+  task: Task;
+};
+
+export type Task = {
+  id: string;
+  goal: string;
+  parent: "Task | None";
+  subtasks: Task[];
+  state: TaskState;
+};
+
+export enum TaskState {
+  OPEN_STATE = "open",
+  COMPLETED_STATE = "completed",
+  ABANDONED_STATE = "abandoned",
+  IN_PROGRESS_STATE = "in_progress",
+  VERIFIED_STATE = "verified",
+}
+
+export async function getPlan(): Promise<Plan | undefined> {
+  const headers = new Headers({
+    "Content-Type": "application/json",
+    Authorization: `Bearer ${localStorage.getItem("token")}`,
+  });
+  const res = await fetch("/api/plan", { headers });
+  if (res.status !== 200) {
+    return undefined;
+  }
+  const data = await res.json();
+  return JSON.parse(data) as Plan;
+}

+ 27 - 0
frontend/src/state/planSlice.ts

@@ -0,0 +1,27 @@
+import { createSlice } from "@reduxjs/toolkit";
+import { Plan, TaskState } from "../services/planService";
+
+export const planSlice = createSlice({
+  name: "plan",
+  initialState: {
+    plan: {
+      mainGoal: undefined,
+      task: {
+        id: "",
+        goal: "",
+        parent: "Task | None",
+        subtasks: [],
+        state: TaskState.OPEN_STATE,
+      },
+    } as Plan,
+  },
+  reducers: {
+    setPlan: (state, action) => {
+      state.plan = action.payload as Plan;
+    },
+  },
+});
+
+export const { setPlan } = planSlice.actions;
+
+export default planSlice.reducer;

+ 4 - 2
frontend/src/store.ts

@@ -1,12 +1,13 @@
 import { combineReducers, configureStore } from "@reduxjs/toolkit";
+import agentReducer from "./state/agentSlice";
 import browserReducer from "./state/browserSlice";
 import chatReducer from "./state/chatSlice";
 import codeReducer from "./state/codeSlice";
 import commandReducer from "./state/commandSlice";
-import taskReducer from "./state/taskSlice";
 import errorsReducer from "./state/errorsSlice";
+import planReducer from "./state/planSlice";
 import settingsReducer from "./state/settingsSlice";
-import agentReducer from "./state/agentSlice";
+import taskReducer from "./state/taskSlice";
 
 export const rootReducer = combineReducers({
   browser: browserReducer,
@@ -16,6 +17,7 @@ export const rootReducer = combineReducers({
   task: taskReducer,
   errors: errorsReducer,
   settings: settingsReducer,
+  plan: planReducer,
   agent: agentReducer,
 });
 

+ 6 - 0
frontend/src/types/ActionType.tsx

@@ -30,6 +30,12 @@ enum ActionType {
   // use the finish action to stop working.
   FINISH = "finish",
 
+  // Adds a task to the plan.
+  ADD_TASK = "add_task",
+
+  // Updates a task in the plan.
+  MODIFY_TASK = "modify_task",
+
   CHANGE_TASK_STATE = "change_task_state",
 }
 

+ 1 - 1
frontend/src/types/TabOption.tsx

@@ -6,6 +6,6 @@ enum TabOption {
 
 type TabType = TabOption.PLANNER | TabOption.CODE | TabOption.BROWSER;
 
-const AllTabs = [TabOption.CODE, TabOption.BROWSER];
+const AllTabs = [TabOption.CODE, TabOption.BROWSER, TabOption.PLANNER];
 
 export { AllTabs, TabOption, type TabType };

+ 14 - 12
opendevin/controller/agent_controller.py

@@ -1,27 +1,26 @@
 import asyncio
 import time
-from typing import List, Callable
-from opendevin.plan import Plan
-from opendevin.state import State
-from opendevin.agent import Agent
-from opendevin.observation import Observation, AgentErrorObservation, NullObservation
+from typing import Callable, List
+
 from litellm.exceptions import APIConnectionError
 from openai import AuthenticationError
 
 from opendevin import config
-from opendevin.logger import opendevin_logger as logger
-
-from opendevin.exceptions import MaxCharsExceedError
-from .action_manager import ActionManager
-
 from opendevin.action import (
     Action,
-    NullAction,
     AgentFinishAction,
+    NullAction,
 )
-from opendevin.exceptions import AgentNoActionError
+from opendevin.agent import Agent
+from opendevin.exceptions import AgentNoActionError, MaxCharsExceedError
+from opendevin.logger import opendevin_logger as logger
+from opendevin.observation import AgentErrorObservation, NullObservation, Observation
+from opendevin.plan import Plan
+from opendevin.state import State
+
 from ..action.tasks import TaskStateChangedAction
 from ..schema import TaskState
+from .action_manager import ActionManager
 
 MAX_ITERATIONS = config.get('MAX_ITERATIONS')
 MAX_CHARS = config.get('MAX_CHARS')
@@ -219,3 +218,6 @@ class AgentController:
         await asyncio.sleep(
             0.001
         )  # Give back control for a tick, so we can await in callbacks
+
+    def get_state(self):
+        return self.state

+ 38 - 7
opendevin/server/listen.py

@@ -1,12 +1,13 @@
+import json
 import uuid
 from pathlib import Path
 
 import litellm
-from fastapi import Depends, FastAPI, WebSocket, HTTPException, Query, status
+from fastapi import Depends, FastAPI, HTTPException, Query, Response, WebSocket, status
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse, RedirectResponse
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 from fastapi.staticfiles import StaticFiles
-from fastapi.responses import RedirectResponse, JSONResponse
 
 import agenthub  # noqa F401 (we import this to get the agents registered)
 from opendevin import config, files
@@ -45,9 +46,9 @@ async def websocket_endpoint(websocket: WebSocket):
 
 @app.get('/api/litellm-models')
 async def get_litellm_models():
-    """
+    '''
     Get all models supported by LiteLLM.
-    """
+    '''
     return list(set(litellm.model_list + list(litellm.model_cost.keys())))
 
 
@@ -72,7 +73,9 @@ async def get_token(
         sid = get_sid_from_token(credentials.credentials)
         if not sid:
             sid = str(uuid.uuid4())
-            logger.info(f'Invalid or missing credentials, generating new session ID: {sid}')
+            logger.info(
+                f'Invalid or missing credentials, generating new session ID: {sid}'
+            )
     else:
         sid = str(uuid.uuid4())
         logger.info(f'No credentials provided, generating new session ID: {sid}')
@@ -117,7 +120,9 @@ def refresh_files():
 
 
 @app.get('/api/list-files')
-def list_files(relpath: str = Query(None, description='Relative path from workspace base')):
+def list_files(
+    relpath: str = Query(None, description='Relative path from workspace base')
+):
     """Refreshes and returns the files and directories from a specified subdirectory or the base directory if no subdirectory is specified, limited to one level deep."""
     base_path = Path(config.get('WORKSPACE_BASE')).resolve()
     full_path = (base_path / relpath).resolve() if relpath is not None else base_path
@@ -127,7 +132,11 @@ def list_files(relpath: str = Query(None, description='Relative path from worksp
     # Ensure path exists, is a directory,
     # And is within the workspace base directory - to prevent directory traversal attacks
     # https://owasp.org/www-community/attacks/Path_Traversal
-    if not full_path.exists() or not full_path.is_dir() or not str(full_path).startswith(str(base_path)):
+    if (
+        not full_path.exists()
+        or not full_path.is_dir()
+        or not str(full_path).startswith(str(base_path))
+    ):
         raise HTTPException(status_code=400, detail='Invalid path provided.')
 
     structure = files.get_single_level_folder_structure(base_path, full_path)
@@ -153,6 +162,28 @@ def select_file(file: str):
     return {'code': content}
 
 
+@app.get('/api/plan')
+def get_plan(
+    credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
+):
+    sid = get_sid_from_token(credentials.credentials)
+    agent = agent_manager.sid_to_agent[sid]
+    controller = agent.controller
+    if controller is not None:
+        state = controller.get_state()
+        if state is not None:
+            return JSONResponse(
+                status_code=status.HTTP_200_OK,
+                content=json.dumps(
+                    {
+                        'mainGoal': state.plan.main_goal,
+                        'task': state.plan.task.to_dict(),
+                    }
+                ),
+            )
+    return Response(status_code=status.HTTP_204_NO_CONTENT)
+
+
 @app.get('/')
 async def docs_redirect():
     response = RedirectResponse(url='/index.html')