chat-slice.ts 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. import { createSlice, PayloadAction } from "@reduxjs/toolkit";
  2. import { ActionSecurityRisk } from "#/state/security-analyzer-slice";
  3. import { OpenHandsObservation } from "#/types/core/observations";
  4. import { OpenHandsAction } from "#/types/core/actions";
  5. type SliceState = { messages: Message[] };
  6. const MAX_CONTENT_LENGTH = 1000;
  7. const HANDLED_ACTIONS = ["run", "run_ipython", "write", "read", "browse"];
  8. function getRiskText(risk: ActionSecurityRisk) {
  9. switch (risk) {
  10. case ActionSecurityRisk.LOW:
  11. return "Low Risk";
  12. case ActionSecurityRisk.MEDIUM:
  13. return "Medium Risk";
  14. case ActionSecurityRisk.HIGH:
  15. return "High Risk";
  16. case ActionSecurityRisk.UNKNOWN:
  17. default:
  18. return "Unknown Risk";
  19. }
  20. }
  21. const initialState: SliceState = {
  22. messages: [],
  23. };
  24. export const chatSlice = createSlice({
  25. name: "chat",
  26. initialState,
  27. reducers: {
  28. addUserMessage(
  29. state,
  30. action: PayloadAction<{
  31. content: string;
  32. imageUrls: string[];
  33. timestamp: string;
  34. pending?: boolean;
  35. }>,
  36. ) {
  37. const message: Message = {
  38. type: "thought",
  39. sender: "user",
  40. content: action.payload.content,
  41. imageUrls: action.payload.imageUrls,
  42. timestamp: action.payload.timestamp || new Date().toISOString(),
  43. pending: !!action.payload.pending,
  44. };
  45. // Remove any pending messages
  46. let i = state.messages.length;
  47. while (i) {
  48. i -= 1;
  49. const m = state.messages[i] as Message;
  50. if (m.pending) {
  51. state.messages.splice(i, 1);
  52. }
  53. }
  54. state.messages.push(message);
  55. },
  56. addAssistantMessage(state, action: PayloadAction<string>) {
  57. const message: Message = {
  58. type: "thought",
  59. sender: "assistant",
  60. content: action.payload,
  61. imageUrls: [],
  62. timestamp: new Date().toISOString(),
  63. pending: false,
  64. };
  65. state.messages.push(message);
  66. },
  67. addAssistantAction(state, action: PayloadAction<OpenHandsAction>) {
  68. const actionID = action.payload.action;
  69. if (!HANDLED_ACTIONS.includes(actionID)) {
  70. return;
  71. }
  72. const translationID = `ACTION_MESSAGE$${actionID.toUpperCase()}`;
  73. let text = "";
  74. if (actionID === "run") {
  75. text = `\`${action.payload.args.command}\``;
  76. } else if (actionID === "run_ipython") {
  77. text = `\`\`\`\n${action.payload.args.code}\n\`\`\``;
  78. } else if (actionID === "write") {
  79. let { content } = action.payload.args;
  80. if (content.length > MAX_CONTENT_LENGTH) {
  81. content = `${content.slice(0, MAX_CONTENT_LENGTH)}...`;
  82. }
  83. text = `${action.payload.args.path}\n${content}`;
  84. } else if (actionID === "read") {
  85. text = action.payload.args.path;
  86. } else if (actionID === "browse") {
  87. text = `Browsing ${action.payload.args.url}`;
  88. }
  89. if (actionID === "run" || actionID === "run_ipython") {
  90. if (
  91. action.payload.args.confirmation_state === "awaiting_confirmation"
  92. ) {
  93. text += `\n\n${getRiskText(action.payload.args.security_risk as unknown as ActionSecurityRisk)}`;
  94. }
  95. }
  96. const message: Message = {
  97. type: "action",
  98. sender: "assistant",
  99. translationID,
  100. eventID: action.payload.id,
  101. content: text,
  102. imageUrls: [],
  103. timestamp: new Date().toISOString(),
  104. };
  105. state.messages.push(message);
  106. },
  107. addAssistantObservation(
  108. state,
  109. observation: PayloadAction<OpenHandsObservation>,
  110. ) {
  111. const observationID = observation.payload.observation;
  112. if (!HANDLED_ACTIONS.includes(observationID)) {
  113. return;
  114. }
  115. const translationID = `OBSERVATION_MESSAGE$${observationID.toUpperCase()}`;
  116. const causeID = observation.payload.cause;
  117. const causeMessage = state.messages.find(
  118. (message) => message.eventID === causeID,
  119. );
  120. if (!causeMessage) {
  121. return;
  122. }
  123. causeMessage.translationID = translationID;
  124. if (observationID === "run" || observationID === "run_ipython") {
  125. let { content } = observation.payload;
  126. if (content.length > MAX_CONTENT_LENGTH) {
  127. content = `${content.slice(0, MAX_CONTENT_LENGTH)}...`;
  128. }
  129. content = `\`\`\`\n${content}\n\`\`\``;
  130. causeMessage.content = content; // Observation content includes the action
  131. } else if (observationID === "browse") {
  132. let content = `**URL:** ${observation.payload.extras.url}\n`;
  133. if (observation.payload.extras.error) {
  134. content += `**Error:**\n${observation.payload.extras.error}\n`;
  135. }
  136. content += `**Output:**\n${observation.payload.content}`;
  137. if (content.length > MAX_CONTENT_LENGTH) {
  138. content = `${content.slice(0, MAX_CONTENT_LENGTH)}...`;
  139. }
  140. causeMessage.content = content;
  141. }
  142. },
  143. addErrorMessage(
  144. state,
  145. action: PayloadAction<{ id?: string; message: string }>,
  146. ) {
  147. const { id, message } = action.payload;
  148. state.messages.push({
  149. translationID: id,
  150. content: message,
  151. type: "error",
  152. sender: "assistant",
  153. timestamp: new Date().toISOString(),
  154. });
  155. },
  156. clearMessages(state) {
  157. state.messages = [];
  158. },
  159. },
  160. });
  161. export const {
  162. addUserMessage,
  163. addAssistantMessage,
  164. addAssistantAction,
  165. addAssistantObservation,
  166. addErrorMessage,
  167. clearMessages,
  168. } = chatSlice.actions;
  169. export default chatSlice.reducer;