edit.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. import copy
  2. import os
  3. import re
  4. import tempfile
  5. from abc import ABC, abstractmethod
  6. from openhands.core.config import AppConfig
  7. from openhands.core.logger import openhands_logger as logger
  8. from openhands.events.action import (
  9. FileEditAction,
  10. FileReadAction,
  11. FileWriteAction,
  12. )
  13. from openhands.events.observation import (
  14. ErrorObservation,
  15. FatalErrorObservation,
  16. FileEditObservation,
  17. FileReadObservation,
  18. FileWriteObservation,
  19. Observation,
  20. )
  21. from openhands.linter import DefaultLinter
  22. from openhands.llm.llm import LLM
  23. from openhands.llm.metrics import Metrics
  24. from openhands.utils.chunk_localizer import Chunk, get_top_k_chunk_matches
  25. from openhands.utils.diff import get_diff
# System prompt for the draft-editor LLM. It states the task (apply a possibly
# incomplete / mis-indented draft onto the old content) and the output format:
# the complete new content wrapped in a ``` block, which `_extract_code` parses.
SYS_MSG = """Your job is to produce a new version of the file based on the old version and the
provided draft of the new version. The provided draft may be incomplete (it may skip lines) and/or incorrectly indented. You should try to apply the changes present in the draft to the old version, and output a new version of the file.
NOTE:
- The output file should be COMPLETE and CORRECTLY INDENTED. Do not omit any lines, and do not change any lines that are not part of the changes.
- You should output the new version of the file by wrapping the new version of the file content in a ``` block.
- If there's no explicit comment to remove the existing code, we should keep them and append the new code to the end of the file.
- If there's placeholder comments like `# no changes before` or `# no changes here`, we should replace these comments with the original code near the placeholder comments.
"""
# User prompt template. `get_new_file_contents` fills `{old_contents}` and
# `{draft_changes}` through `str.format`; `.strip()` removes the leading and
# trailing whitespace of the literal once, at import time.
USER_MSG = """
HERE IS THE OLD VERSION OF THE FILE:
```
{old_contents}
```
HERE IS THE DRAFT OF THE NEW VERSION OF THE FILE:
```
{draft_changes}
```
GIVE ME THE NEW VERSION OF THE FILE.
IMPORTANT:
- There should be NO placeholder comments like `# no changes before` or `# no changes here`. They should be replaced with the original code near the placeholder comments.
- The output file should be COMPLETE and CORRECTLY INDENTED. Do not omit any lines, and do not change any lines that are not part of the changes.
""".strip()
  48. def _extract_code(string):
  49. pattern = r'```(?:\w*\n)?(.*?)```'
  50. matches = re.findall(pattern, string, re.DOTALL)
  51. if not matches:
  52. return None
  53. return matches[0]
  54. def get_new_file_contents(
  55. llm: LLM, old_contents: str, draft_changes: str, num_retries: int = 3
  56. ) -> str | None:
  57. while num_retries > 0:
  58. messages = [
  59. {'role': 'system', 'content': SYS_MSG},
  60. {
  61. 'role': 'user',
  62. 'content': USER_MSG.format(
  63. old_contents=old_contents, draft_changes=draft_changes
  64. ),
  65. },
  66. ]
  67. resp = llm.completion(messages=messages)
  68. new_contents = _extract_code(resp['choices'][0]['message']['content'])
  69. if new_contents is not None:
  70. return new_contents
  71. num_retries -= 1
  72. return None
  73. class FileEditRuntimeInterface(ABC):
  74. config: AppConfig
  75. @abstractmethod
  76. def read(self, action: FileReadAction) -> Observation:
  77. pass
  78. @abstractmethod
  79. def write(self, action: FileWriteAction) -> Observation:
  80. pass
  81. class FileEditRuntimeMixin(FileEditRuntimeInterface):
  82. # Most LLMs have output token limit of 4k tokens.
  83. # This restricts the number of lines we can edit to avoid exceeding the token limit.
  84. MAX_LINES_TO_EDIT = 300
  85. def __init__(self, *args, **kwargs):
  86. super().__init__(*args, **kwargs)
  87. llm_config = self.config.get_llm_config()
  88. if llm_config.draft_editor is None:
  89. llm_config.draft_editor = copy.deepcopy(llm_config)
  90. # manually set the model name for the draft editor LLM to distinguish token costs
  91. llm_metrics = Metrics(
  92. model_name='draft_editor:' + llm_config.draft_editor.model
  93. )
  94. if llm_config.draft_editor.caching_prompt:
  95. logger.info(
  96. 'It is not recommended to cache draft editor LLM prompts as it may incur high costs for the same prompt. '
  97. 'Automatically setting caching_prompt=false.'
  98. )
  99. llm_config.draft_editor.caching_prompt = False
  100. self.draft_editor_llm = LLM(llm_config.draft_editor, metrics=llm_metrics)
  101. logger.info(
  102. f'[Draft edit functionality] enabled with LLM: {self.draft_editor_llm}'
  103. )
  104. def _validate_range(
  105. self, start: int, end: int, total_lines: int
  106. ) -> Observation | None:
  107. # start and end are 1-indexed and inclusive
  108. if (
  109. (start < 1 and start != -1)
  110. or start > total_lines
  111. or (start > end and end != -1 and start != -1)
  112. ):
  113. return ErrorObservation(
  114. f'Invalid range for editing: start={start}, end={end}, total lines={total_lines}. start must be >= 1 and <={total_lines} (total lines of the edited file), start <= end, or start == -1 (append to the end of the file).'
  115. )
  116. if (
  117. (end < 1 and end != -1)
  118. or end > total_lines
  119. or (end < start and start != -1 and end != -1)
  120. ):
  121. return ErrorObservation(
  122. f'Invalid range for editing: start={start}, end={end}, total lines={total_lines}. end must be >= 1 and <= {total_lines} (total lines of the edited file), end >= start, or end == -1 (to edit till the end of the file).'
  123. )
  124. return None
  125. def _get_lint_error(
  126. self,
  127. suffix: str,
  128. old_content: str,
  129. new_content: str,
  130. filepath: str,
  131. diff: str,
  132. ) -> ErrorObservation | None:
  133. linter = DefaultLinter()
  134. # Copy the original file to a temporary file (with the same ext) and lint it
  135. with tempfile.NamedTemporaryFile(
  136. suffix=suffix, mode='w+', encoding='utf-8'
  137. ) as original_file_copy, tempfile.NamedTemporaryFile(
  138. suffix=suffix, mode='w+', encoding='utf-8'
  139. ) as updated_file_copy:
  140. # Lint the original file
  141. original_file_copy.write(old_content)
  142. original_file_copy.flush()
  143. # Lint the updated file
  144. updated_file_copy.write(new_content)
  145. updated_file_copy.flush()
  146. updated_lint_error = linter.lint_file_diff(
  147. original_file_copy.name, updated_file_copy.name
  148. )
  149. if len(updated_lint_error) > 0:
  150. _obs = FileEditObservation(
  151. content=diff,
  152. path=filepath,
  153. prev_exist=True,
  154. old_content=old_content,
  155. new_content=new_content,
  156. )
  157. error_message = (
  158. (
  159. f'\n[Linting failed for edited file {filepath}. {len(updated_lint_error)} lint errors found.]\n'
  160. '[begin attempted changes]\n'
  161. f'{_obs.visualize_diff(change_applied=False)}\n'
  162. '[end attempted changes]\n'
  163. )
  164. + '-' * 40
  165. + '\n'
  166. )
  167. error_message += '-' * 20 + 'First 5 lint errors' + '-' * 20 + '\n'
  168. for i, lint_error in enumerate(updated_lint_error[:5]):
  169. error_message += f'[begin lint error {i}]\n'
  170. error_message += lint_error.visualize().strip() + '\n'
  171. error_message += f'[end lint error {i}]\n'
  172. error_message += '-' * 40 + '\n'
  173. return ErrorObservation(error_message)
  174. return None
  175. def edit(self, action: FileEditAction) -> Observation:
  176. obs = self.read(FileReadAction(path=action.path))
  177. if (
  178. isinstance(obs, ErrorObservation)
  179. and 'File not found'.lower() in obs.content.lower()
  180. ):
  181. logger.debug(
  182. f'Agent attempted to edit a file that does not exist. Creating the file. Error msg: {obs.content}'
  183. )
  184. # directly write the new content
  185. obs = self.write(
  186. FileWriteAction(path=action.path, content=action.content.strip())
  187. )
  188. if isinstance(obs, ErrorObservation):
  189. return obs
  190. if not isinstance(obs, FileWriteObservation):
  191. return FatalErrorObservation(
  192. f'Fatal Runtime in editing: Expected FileWriteObservation, got {type(obs)}: {str(obs)}'
  193. )
  194. return FileEditObservation(
  195. content=get_diff('', action.content, action.path),
  196. path=action.path,
  197. prev_exist=False,
  198. old_content='',
  199. new_content=action.content,
  200. )
  201. if not isinstance(obs, FileReadObservation):
  202. return FatalErrorObservation(
  203. f'Fatal Runtime in editing: Expected FileReadObservation, got {type(obs)}: {str(obs)}'
  204. )
  205. original_file_content = obs.content
  206. old_file_lines = original_file_content.split('\n')
  207. # NOTE: start and end are 1-indexed
  208. start = action.start
  209. end = action.end
  210. # validate the range
  211. error = self._validate_range(start, end, len(old_file_lines))
  212. if error is not None:
  213. return error
  214. # append to the end of the file
  215. if start == -1:
  216. updated_content = '\n'.join(old_file_lines + action.content.split('\n'))
  217. diff = get_diff(original_file_content, updated_content, action.path)
  218. # Lint the updated content
  219. if self.config.sandbox.enable_auto_lint:
  220. suffix = os.path.splitext(action.path)[1]
  221. error_obs = self._get_lint_error(
  222. suffix,
  223. original_file_content,
  224. updated_content,
  225. action.path,
  226. diff,
  227. )
  228. if error_obs is not None:
  229. return error_obs
  230. obs = self.write(FileWriteAction(path=action.path, content=updated_content))
  231. return FileEditObservation(
  232. content=diff,
  233. path=action.path,
  234. prev_exist=True,
  235. old_content=original_file_content,
  236. new_content=updated_content,
  237. )
  238. # Get the 0-indexed start and end
  239. start_idx = start - 1
  240. if end != -1:
  241. # remove 1 to make it 0-indexed
  242. # then add 1 since the `end` is inclusive
  243. end_idx = end - 1 + 1
  244. else:
  245. # end == -1 means the user wants to edit till the end of the file
  246. end_idx = len(old_file_lines)
  247. # Get the range of lines to edit - reject if too long
  248. length_of_range = end_idx - start_idx
  249. if length_of_range > self.MAX_LINES_TO_EDIT + 1:
  250. error_msg = (
  251. f'[Edit error: The range of lines to edit is too long.]\n'
  252. f'[The maximum number of lines allowed to edit at once is {self.MAX_LINES_TO_EDIT}. '
  253. f'Got (L{start_idx + 1}-L{end_idx}) {length_of_range} lines.]\n' # [start_idx, end_idx), so no need to + 1
  254. )
  255. # search for relevant ranges to hint the agent
  256. topk_chunks: list[Chunk] = get_top_k_chunk_matches(
  257. text=original_file_content,
  258. query=action.content, # edit draft as query
  259. k=3,
  260. max_chunk_size=20, # lines
  261. )
  262. error_msg += (
  263. 'Here are some snippets that maybe relevant to the provided edit.\n'
  264. )
  265. for i, chunk in enumerate(topk_chunks):
  266. error_msg += f'[begin relevant snippet {i+1}. Line range: L{chunk.line_range[0]}-L{chunk.line_range[1]}. Similarity: {chunk.normalized_lcs}]\n'
  267. error_msg += f'[Browse around it via `open_file("{action.path}", {(chunk.line_range[0] + chunk.line_range[1]) // 2})`]\n'
  268. error_msg += chunk.visualize() + '\n'
  269. error_msg += f'[end relevant snippet {i+1}]\n'
  270. error_msg += '-' * 40 + '\n'
  271. error_msg += 'Consider using `open_file` to explore around the relevant snippets if needed.\n'
  272. error_msg += f'**IMPORTANT**: Please REDUCE the range of edits to less than {self.MAX_LINES_TO_EDIT} lines by setting `start` and `end` in the edit action (e.g. `<file_edit path="{action.path}" start=[PUT LINE NUMBER HERE] end=[PUT LINE NUMBER HERE] />`). '
  273. return ErrorObservation(error_msg)
  274. content_to_edit = '\n'.join(old_file_lines[start_idx:end_idx])
  275. self.draft_editor_llm.reset()
  276. _edited_content = get_new_file_contents(
  277. self.draft_editor_llm, content_to_edit, action.content
  278. )
  279. if _edited_content is None:
  280. ret_err = ErrorObservation(
  281. 'Failed to get new file contents. '
  282. 'Please try to reduce the number of edits and try again.'
  283. )
  284. ret_err.llm_metrics = self.draft_editor_llm.metrics
  285. return ret_err
  286. # piece the updated content with the unchanged content
  287. updated_lines = (
  288. old_file_lines[:start_idx]
  289. + _edited_content.split('\n')
  290. + old_file_lines[end_idx:]
  291. )
  292. updated_content = '\n'.join(updated_lines)
  293. diff = get_diff(original_file_content, updated_content, action.path)
  294. # Lint the updated content
  295. if self.config.sandbox.enable_auto_lint:
  296. suffix = os.path.splitext(action.path)[1]
  297. error_obs = self._get_lint_error(
  298. suffix, original_file_content, updated_content, action.path, diff
  299. )
  300. if error_obs is not None:
  301. error_obs.llm_metrics = self.draft_editor_llm.metrics
  302. return error_obs
  303. obs = self.write(FileWriteAction(path=action.path, content=updated_content))
  304. ret_obs = FileEditObservation(
  305. content=diff,
  306. path=action.path,
  307. prev_exist=True,
  308. old_content=original_file_content,
  309. new_content=updated_content,
  310. )
  311. ret_obs.llm_metrics = self.draft_editor_llm.metrics
  312. return ret_obs