edit.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. import copy
  2. import os
  3. import re
  4. import tempfile
  5. from abc import ABC, abstractmethod
  6. from openhands_aci.utils.diff import get_diff
  7. from openhands.core.config import AppConfig
  8. from openhands.core.logger import openhands_logger as logger
  9. from openhands.events.action import (
  10. FileEditAction,
  11. FileReadAction,
  12. FileWriteAction,
  13. IPythonRunCellAction,
  14. )
  15. from openhands.events.event import FileEditSource
  16. from openhands.events.observation import (
  17. ErrorObservation,
  18. FileEditObservation,
  19. FileReadObservation,
  20. FileWriteObservation,
  21. Observation,
  22. )
  23. from openhands.linter import DefaultLinter
  24. from openhands.llm.llm import LLM
  25. from openhands.llm.metrics import Metrics
  26. from openhands.utils.chunk_localizer import Chunk, get_top_k_chunk_matches
# System prompt sent to the draft-editor LLM (used as the 'system' message in
# `get_new_file_contents`). It instructs the model to merge a possibly
# incomplete draft into the old file content and to return the full result
# wrapped in a ``` block, which `_extract_code` then pulls out of the reply.
SYS_MSG = """Your job is to produce a new version of the file based on the old version and the
provided draft of the new version. The provided draft may be incomplete (it may skip lines) and/or incorrectly indented. You should try to apply the changes present in the draft to the old version, and output a new version of the file.
NOTE:
- The output file should be COMPLETE and CORRECTLY INDENTED. Do not omit any lines, and do not change any lines that are not part of the changes.
- You should output the new version of the file by wrapping the new version of the file content in a ``` block.
- If there's no explicit comment to remove the existing code, we should keep them and append the new code to the end of the file.
- If there's placeholder comments like `# no changes before` or `# no changes here`, we should replace these comments with the original code near the placeholder comments.
"""
# User prompt template. `{old_contents}` and `{draft_changes}` are filled in
# with `str.format` by `get_new_file_contents`; the trailing `.strip()` removes
# the leading/trailing newlines introduced by the triple-quoted layout.
USER_MSG = """
HERE IS THE OLD VERSION OF THE FILE:
```
{old_contents}
```
HERE IS THE DRAFT OF THE NEW VERSION OF THE FILE:
```
{draft_changes}
```
GIVE ME THE NEW VERSION OF THE FILE.
IMPORTANT:
- There should be NO placeholder comments like `# no changes before` or `# no changes here`. They should be replaced with the original code near the placeholder comments.
- The output file should be COMPLETE and CORRECTLY INDENTED. Do not omit any lines, and do not change any lines that are not part of the changes.
""".strip()
  49. def _extract_code(string):
  50. pattern = r'```(?:\w*\n)?(.*?)```'
  51. matches = re.findall(pattern, string, re.DOTALL)
  52. if not matches:
  53. return None
  54. return matches[0]
  55. def get_new_file_contents(
  56. llm: LLM, old_contents: str, draft_changes: str, num_retries: int = 3
  57. ) -> str | None:
  58. while num_retries > 0:
  59. messages = [
  60. {'role': 'system', 'content': SYS_MSG},
  61. {
  62. 'role': 'user',
  63. 'content': USER_MSG.format(
  64. old_contents=old_contents, draft_changes=draft_changes
  65. ),
  66. },
  67. ]
  68. resp = llm.completion(messages=messages)
  69. new_contents = _extract_code(resp['choices'][0]['message']['content'])
  70. if new_contents is not None:
  71. return new_contents
  72. num_retries -= 1
  73. return None
  74. class FileEditRuntimeInterface(ABC):
  75. config: AppConfig
  76. @abstractmethod
  77. def read(self, action: FileReadAction) -> Observation:
  78. pass
  79. @abstractmethod
  80. def write(self, action: FileWriteAction) -> Observation:
  81. pass
  82. @abstractmethod
  83. def run_ipython(self, action: IPythonRunCellAction) -> Observation:
  84. pass
  85. class FileEditRuntimeMixin(FileEditRuntimeInterface):
  86. # Most LLMs have output token limit of 4k tokens.
  87. # This restricts the number of lines we can edit to avoid exceeding the token limit.
  88. MAX_LINES_TO_EDIT = 300
  89. def __init__(self, *args, **kwargs):
  90. super().__init__(*args, **kwargs)
  91. llm_config = self.config.get_llm_config()
  92. if llm_config.draft_editor is None:
  93. llm_config.draft_editor = copy.deepcopy(llm_config)
  94. # manually set the model name for the draft editor LLM to distinguish token costs
  95. llm_metrics = Metrics(
  96. model_name='draft_editor:' + llm_config.draft_editor.model
  97. )
  98. if llm_config.draft_editor.caching_prompt:
  99. logger.debug(
  100. 'It is not recommended to cache draft editor LLM prompts as it may incur high costs for the same prompt. '
  101. 'Automatically setting caching_prompt=false.'
  102. )
  103. llm_config.draft_editor.caching_prompt = False
  104. self.draft_editor_llm = LLM(llm_config.draft_editor, metrics=llm_metrics)
  105. logger.debug(
  106. f'[Draft edit functionality] enabled with LLM: {self.draft_editor_llm}'
  107. )
  108. def _validate_range(
  109. self, start: int, end: int, total_lines: int
  110. ) -> Observation | None:
  111. # start and end are 1-indexed and inclusive
  112. if (
  113. (start < 1 and start != -1)
  114. or start > total_lines
  115. or (start > end and end != -1 and start != -1)
  116. ):
  117. return ErrorObservation(
  118. f'Invalid range for editing: start={start}, end={end}, total lines={total_lines}. start must be >= 1 and <={total_lines} (total lines of the edited file), start <= end, or start == -1 (append to the end of the file).'
  119. )
  120. if (
  121. (end < 1 and end != -1)
  122. or end > total_lines
  123. or (end < start and start != -1 and end != -1)
  124. ):
  125. return ErrorObservation(
  126. f'Invalid range for editing: start={start}, end={end}, total lines={total_lines}. end must be >= 1 and <= {total_lines} (total lines of the edited file), end >= start, or end == -1 (to edit till the end of the file).'
  127. )
  128. return None
  129. def _get_lint_error(
  130. self,
  131. suffix: str,
  132. old_content: str,
  133. new_content: str,
  134. filepath: str,
  135. diff: str,
  136. ) -> ErrorObservation | None:
  137. linter = DefaultLinter()
  138. # Copy the original file to a temporary file (with the same ext) and lint it
  139. with (
  140. tempfile.NamedTemporaryFile(
  141. suffix=suffix, mode='w+', encoding='utf-8'
  142. ) as original_file_copy,
  143. tempfile.NamedTemporaryFile(
  144. suffix=suffix, mode='w+', encoding='utf-8'
  145. ) as updated_file_copy,
  146. ):
  147. # Lint the original file
  148. original_file_copy.write(old_content)
  149. original_file_copy.flush()
  150. # Lint the updated file
  151. updated_file_copy.write(new_content)
  152. updated_file_copy.flush()
  153. updated_lint_error = linter.lint_file_diff(
  154. original_file_copy.name, updated_file_copy.name
  155. )
  156. if len(updated_lint_error) > 0:
  157. _obs = FileEditObservation(
  158. content=diff,
  159. path=filepath,
  160. prev_exist=True,
  161. old_content=old_content,
  162. new_content=new_content,
  163. )
  164. error_message = (
  165. (
  166. f'\n[Linting failed for edited file {filepath}. {len(updated_lint_error)} lint errors found.]\n'
  167. '[begin attempted changes]\n'
  168. f'{_obs.visualize_diff(change_applied=False)}\n'
  169. '[end attempted changes]\n'
  170. )
  171. + '-' * 40
  172. + '\n'
  173. )
  174. error_message += '-' * 20 + 'First 5 lint errors' + '-' * 20 + '\n'
  175. for i, lint_error in enumerate(updated_lint_error[:5]):
  176. error_message += f'[begin lint error {i}]\n'
  177. error_message += lint_error.visualize().strip() + '\n'
  178. error_message += f'[end lint error {i}]\n'
  179. error_message += '-' * 40 + '\n'
  180. return ErrorObservation(error_message)
  181. return None
  182. def edit(self, action: FileEditAction) -> Observation:
  183. if action.impl_source == FileEditSource.OH_ACI:
  184. # Translate to ipython command to file_editor
  185. return self.run_ipython(
  186. IPythonRunCellAction(
  187. code=action.translated_ipython_code,
  188. include_extra=False,
  189. )
  190. )
  191. obs = self.read(FileReadAction(path=action.path))
  192. if (
  193. isinstance(obs, ErrorObservation)
  194. and 'File not found'.lower() in obs.content.lower()
  195. ):
  196. logger.debug(
  197. f'Agent attempted to edit a file that does not exist. Creating the file. Error msg: {obs.content}'
  198. )
  199. # directly write the new content
  200. obs = self.write(
  201. FileWriteAction(path=action.path, content=action.content.strip())
  202. )
  203. if isinstance(obs, ErrorObservation):
  204. return obs
  205. if not isinstance(obs, FileWriteObservation):
  206. raise ValueError(
  207. f'Expected FileWriteObservation, got {type(obs)}: {str(obs)}'
  208. )
  209. return FileEditObservation(
  210. content=get_diff('', action.content, action.path),
  211. path=action.path,
  212. prev_exist=False,
  213. old_content='',
  214. new_content=action.content,
  215. )
  216. if not isinstance(obs, FileReadObservation):
  217. raise ValueError(
  218. f'Expected FileReadObservation, got {type(obs)}: {str(obs)}'
  219. )
  220. original_file_content = obs.content
  221. old_file_lines = original_file_content.split('\n')
  222. # NOTE: start and end are 1-indexed
  223. start = action.start
  224. end = action.end
  225. # validate the range
  226. error = self._validate_range(start, end, len(old_file_lines))
  227. if error is not None:
  228. return error
  229. # append to the end of the file
  230. if start == -1:
  231. updated_content = '\n'.join(old_file_lines + action.content.split('\n'))
  232. diff = get_diff(original_file_content, updated_content, action.path)
  233. # Lint the updated content
  234. if self.config.sandbox.enable_auto_lint:
  235. suffix = os.path.splitext(action.path)[1]
  236. error_obs = self._get_lint_error(
  237. suffix,
  238. original_file_content,
  239. updated_content,
  240. action.path,
  241. diff,
  242. )
  243. if error_obs is not None:
  244. return error_obs
  245. obs = self.write(FileWriteAction(path=action.path, content=updated_content))
  246. return FileEditObservation(
  247. content=diff,
  248. path=action.path,
  249. prev_exist=True,
  250. old_content=original_file_content,
  251. new_content=updated_content,
  252. )
  253. # Get the 0-indexed start and end
  254. start_idx = start - 1
  255. if end != -1:
  256. # remove 1 to make it 0-indexed
  257. # then add 1 since the `end` is inclusive
  258. end_idx = end - 1 + 1
  259. else:
  260. # end == -1 means the user wants to edit till the end of the file
  261. end_idx = len(old_file_lines)
  262. # Get the range of lines to edit - reject if too long
  263. length_of_range = end_idx - start_idx
  264. if length_of_range > self.MAX_LINES_TO_EDIT + 1:
  265. error_msg = (
  266. f'[Edit error: The range of lines to edit is too long.]\n'
  267. f'[The maximum number of lines allowed to edit at once is {self.MAX_LINES_TO_EDIT}. '
  268. f'Got (L{start_idx + 1}-L{end_idx}) {length_of_range} lines.]\n' # [start_idx, end_idx), so no need to + 1
  269. )
  270. # search for relevant ranges to hint the agent
  271. topk_chunks: list[Chunk] = get_top_k_chunk_matches(
  272. text=original_file_content,
  273. query=action.content, # edit draft as query
  274. k=3,
  275. max_chunk_size=20, # lines
  276. )
  277. error_msg += (
  278. 'Here are some snippets that maybe relevant to the provided edit.\n'
  279. )
  280. for i, chunk in enumerate(topk_chunks):
  281. error_msg += f'[begin relevant snippet {i+1}. Line range: L{chunk.line_range[0]}-L{chunk.line_range[1]}. Similarity: {chunk.normalized_lcs}]\n'
  282. error_msg += f'[Browse around it via `open_file("{action.path}", {(chunk.line_range[0] + chunk.line_range[1]) // 2})`]\n'
  283. error_msg += chunk.visualize() + '\n'
  284. error_msg += f'[end relevant snippet {i+1}]\n'
  285. error_msg += '-' * 40 + '\n'
  286. error_msg += 'Consider using `open_file` to explore around the relevant snippets if needed.\n'
  287. error_msg += f'**IMPORTANT**: Please REDUCE the range of edits to less than {self.MAX_LINES_TO_EDIT} lines by setting `start` and `end` in the edit action (e.g. `<file_edit path="{action.path}" start=[PUT LINE NUMBER HERE] end=[PUT LINE NUMBER HERE] />`). '
  288. return ErrorObservation(error_msg)
  289. content_to_edit = '\n'.join(old_file_lines[start_idx:end_idx])
  290. self.draft_editor_llm.reset()
  291. _edited_content = get_new_file_contents(
  292. self.draft_editor_llm, content_to_edit, action.content
  293. )
  294. if _edited_content is None:
  295. ret_err = ErrorObservation(
  296. 'Failed to get new file contents. '
  297. 'Please try to reduce the number of edits and try again.'
  298. )
  299. ret_err.llm_metrics = self.draft_editor_llm.metrics
  300. return ret_err
  301. # piece the updated content with the unchanged content
  302. updated_lines = (
  303. old_file_lines[:start_idx]
  304. + _edited_content.split('\n')
  305. + old_file_lines[end_idx:]
  306. )
  307. updated_content = '\n'.join(updated_lines)
  308. diff = get_diff(original_file_content, updated_content, action.path)
  309. # Lint the updated content
  310. if self.config.sandbox.enable_auto_lint:
  311. suffix = os.path.splitext(action.path)[1]
  312. error_obs = self._get_lint_error(
  313. suffix, original_file_content, updated_content, action.path, diff
  314. )
  315. if error_obs is not None:
  316. error_obs.llm_metrics = self.draft_editor_llm.metrics
  317. return error_obs
  318. obs = self.write(FileWriteAction(path=action.path, content=updated_content))
  319. ret_obs = FileEditObservation(
  320. content=diff,
  321. path=action.path,
  322. prev_exist=True,
  323. old_content=original_file_content,
  324. new_content=updated_content,
  325. )
  326. ret_obs.llm_metrics = self.draft_editor_llm.metrics
  327. return ret_obs