edit.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. import copy
  2. import os
  3. import re
  4. import tempfile
  5. from abc import ABC, abstractmethod
  6. from openhands_aci.utils.diff import get_diff
  7. from openhands.core.config import AppConfig
  8. from openhands.core.logger import openhands_logger as logger
  9. from openhands.events.action import FileEditAction, FileReadAction, FileWriteAction
  10. from openhands.events.observation import (
  11. ErrorObservation,
  12. FileEditObservation,
  13. FileReadObservation,
  14. FileWriteObservation,
  15. Observation,
  16. )
  17. from openhands.linter import DefaultLinter
  18. from openhands.llm.llm import LLM
  19. from openhands.llm.metrics import Metrics
  20. from openhands.utils.chunk_localizer import Chunk, get_top_k_chunk_matches
# System prompt for the draft-editor LLM: it is given the old file text plus a
# (possibly partial / mis-indented) draft and must return the complete merged
# file wrapped in a ``` fence, which `_extract_code` later pulls out.
SYS_MSG = """Your job is to produce a new version of the file based on the old version and the
provided draft of the new version. The provided draft may be incomplete (it may skip lines) and/or incorrectly indented. You should try to apply the changes present in the draft to the old version, and output a new version of the file.
NOTE:
- The output file should be COMPLETE and CORRECTLY INDENTED. Do not omit any lines, and do not change any lines that are not part of the changes.
- You should output the new version of the file by wrapping the new version of the file content in a ``` block.
- If there's no explicit comment to remove the existing code, we should keep them and append the new code to the end of the file.
- If there's placeholder comments like `# no changes before` or `# no changes here`, we should replace these comments with the original code near the placeholder comments.
"""
# User prompt template. It is filled via `str.format` with two fields:
# `old_contents` (the text being edited) and `draft_changes` (the edit draft).
# The trailing `.strip()` removes the leading/trailing newlines of the literal.
USER_MSG = """
HERE IS THE OLD VERSION OF THE FILE:
```
{old_contents}
```
HERE IS THE DRAFT OF THE NEW VERSION OF THE FILE:
```
{draft_changes}
```
GIVE ME THE NEW VERSION OF THE FILE.
IMPORTANT:
- There should be NO placeholder comments like `# no changes before` or `# no changes here`. They should be replaced with the original code near the placeholder comments.
- The output file should be COMPLETE and CORRECTLY INDENTED. Do not omit any lines, and do not change any lines that are not part of the changes.
""".strip()
  43. def _extract_code(string):
  44. pattern = r'```(?:\w*\n)?(.*?)```'
  45. matches = re.findall(pattern, string, re.DOTALL)
  46. if not matches:
  47. return None
  48. return matches[0]
  49. def get_new_file_contents(
  50. llm: LLM, old_contents: str, draft_changes: str, num_retries: int = 3
  51. ) -> str | None:
  52. while num_retries > 0:
  53. messages = [
  54. {'role': 'system', 'content': SYS_MSG},
  55. {
  56. 'role': 'user',
  57. 'content': USER_MSG.format(
  58. old_contents=old_contents, draft_changes=draft_changes
  59. ),
  60. },
  61. ]
  62. resp = llm.completion(messages=messages)
  63. new_contents = _extract_code(resp['choices'][0]['message']['content'])
  64. if new_contents is not None:
  65. return new_contents
  66. num_retries -= 1
  67. return None
  68. class FileEditRuntimeInterface(ABC):
  69. config: AppConfig
  70. @abstractmethod
  71. def read(self, action: FileReadAction) -> Observation:
  72. pass
  73. @abstractmethod
  74. def write(self, action: FileWriteAction) -> Observation:
  75. pass
  76. class FileEditRuntimeMixin(FileEditRuntimeInterface):
  77. # Most LLMs have output token limit of 4k tokens.
  78. # This restricts the number of lines we can edit to avoid exceeding the token limit.
  79. MAX_LINES_TO_EDIT = 300
  80. def __init__(self, *args, **kwargs):
  81. super().__init__(*args, **kwargs)
  82. llm_config = self.config.get_llm_config()
  83. if llm_config.draft_editor is None:
  84. llm_config.draft_editor = copy.deepcopy(llm_config)
  85. # manually set the model name for the draft editor LLM to distinguish token costs
  86. llm_metrics = Metrics(
  87. model_name='draft_editor:' + llm_config.draft_editor.model
  88. )
  89. if llm_config.draft_editor.caching_prompt:
  90. logger.debug(
  91. 'It is not recommended to cache draft editor LLM prompts as it may incur high costs for the same prompt. '
  92. 'Automatically setting caching_prompt=false.'
  93. )
  94. llm_config.draft_editor.caching_prompt = False
  95. self.draft_editor_llm = LLM(llm_config.draft_editor, metrics=llm_metrics)
  96. logger.debug(
  97. f'[Draft edit functionality] enabled with LLM: {self.draft_editor_llm}'
  98. )
  99. def _validate_range(
  100. self, start: int, end: int, total_lines: int
  101. ) -> Observation | None:
  102. # start and end are 1-indexed and inclusive
  103. if (
  104. (start < 1 and start != -1)
  105. or start > total_lines
  106. or (start > end and end != -1 and start != -1)
  107. ):
  108. return ErrorObservation(
  109. f'Invalid range for editing: start={start}, end={end}, total lines={total_lines}. start must be >= 1 and <={total_lines} (total lines of the edited file), start <= end, or start == -1 (append to the end of the file).'
  110. )
  111. if (
  112. (end < 1 and end != -1)
  113. or end > total_lines
  114. or (end < start and start != -1 and end != -1)
  115. ):
  116. return ErrorObservation(
  117. f'Invalid range for editing: start={start}, end={end}, total lines={total_lines}. end must be >= 1 and <= {total_lines} (total lines of the edited file), end >= start, or end == -1 (to edit till the end of the file).'
  118. )
  119. return None
  120. def _get_lint_error(
  121. self,
  122. suffix: str,
  123. old_content: str,
  124. new_content: str,
  125. filepath: str,
  126. diff: str,
  127. ) -> ErrorObservation | None:
  128. linter = DefaultLinter()
  129. # Copy the original file to a temporary file (with the same ext) and lint it
  130. with (
  131. tempfile.NamedTemporaryFile(
  132. suffix=suffix, mode='w+', encoding='utf-8'
  133. ) as original_file_copy,
  134. tempfile.NamedTemporaryFile(
  135. suffix=suffix, mode='w+', encoding='utf-8'
  136. ) as updated_file_copy,
  137. ):
  138. # Lint the original file
  139. original_file_copy.write(old_content)
  140. original_file_copy.flush()
  141. # Lint the updated file
  142. updated_file_copy.write(new_content)
  143. updated_file_copy.flush()
  144. updated_lint_error = linter.lint_file_diff(
  145. original_file_copy.name, updated_file_copy.name
  146. )
  147. if len(updated_lint_error) > 0:
  148. _obs = FileEditObservation(
  149. content=diff,
  150. path=filepath,
  151. prev_exist=True,
  152. old_content=old_content,
  153. new_content=new_content,
  154. )
  155. error_message = (
  156. (
  157. f'\n[Linting failed for edited file {filepath}. {len(updated_lint_error)} lint errors found.]\n'
  158. '[begin attempted changes]\n'
  159. f'{_obs.visualize_diff(change_applied=False)}\n'
  160. '[end attempted changes]\n'
  161. )
  162. + '-' * 40
  163. + '\n'
  164. )
  165. error_message += '-' * 20 + 'First 5 lint errors' + '-' * 20 + '\n'
  166. for i, lint_error in enumerate(updated_lint_error[:5]):
  167. error_message += f'[begin lint error {i}]\n'
  168. error_message += lint_error.visualize().strip() + '\n'
  169. error_message += f'[end lint error {i}]\n'
  170. error_message += '-' * 40 + '\n'
  171. return ErrorObservation(error_message)
  172. return None
  173. def edit(self, action: FileEditAction) -> Observation:
  174. obs = self.read(FileReadAction(path=action.path))
  175. if (
  176. isinstance(obs, ErrorObservation)
  177. and 'File not found'.lower() in obs.content.lower()
  178. ):
  179. logger.debug(
  180. f'Agent attempted to edit a file that does not exist. Creating the file. Error msg: {obs.content}'
  181. )
  182. # directly write the new content
  183. obs = self.write(
  184. FileWriteAction(path=action.path, content=action.content.strip())
  185. )
  186. if isinstance(obs, ErrorObservation):
  187. return obs
  188. if not isinstance(obs, FileWriteObservation):
  189. raise ValueError(
  190. f'Expected FileWriteObservation, got {type(obs)}: {str(obs)}'
  191. )
  192. return FileEditObservation(
  193. content=get_diff('', action.content, action.path),
  194. path=action.path,
  195. prev_exist=False,
  196. old_content='',
  197. new_content=action.content,
  198. )
  199. if not isinstance(obs, FileReadObservation):
  200. raise ValueError(
  201. f'Expected FileReadObservation, got {type(obs)}: {str(obs)}'
  202. )
  203. original_file_content = obs.content
  204. old_file_lines = original_file_content.split('\n')
  205. # NOTE: start and end are 1-indexed
  206. start = action.start
  207. end = action.end
  208. # validate the range
  209. error = self._validate_range(start, end, len(old_file_lines))
  210. if error is not None:
  211. return error
  212. # append to the end of the file
  213. if start == -1:
  214. updated_content = '\n'.join(old_file_lines + action.content.split('\n'))
  215. diff = get_diff(original_file_content, updated_content, action.path)
  216. # Lint the updated content
  217. if self.config.sandbox.enable_auto_lint:
  218. suffix = os.path.splitext(action.path)[1]
  219. error_obs = self._get_lint_error(
  220. suffix,
  221. original_file_content,
  222. updated_content,
  223. action.path,
  224. diff,
  225. )
  226. if error_obs is not None:
  227. return error_obs
  228. obs = self.write(FileWriteAction(path=action.path, content=updated_content))
  229. return FileEditObservation(
  230. content=diff,
  231. path=action.path,
  232. prev_exist=True,
  233. old_content=original_file_content,
  234. new_content=updated_content,
  235. )
  236. # Get the 0-indexed start and end
  237. start_idx = start - 1
  238. if end != -1:
  239. # remove 1 to make it 0-indexed
  240. # then add 1 since the `end` is inclusive
  241. end_idx = end - 1 + 1
  242. else:
  243. # end == -1 means the user wants to edit till the end of the file
  244. end_idx = len(old_file_lines)
  245. # Get the range of lines to edit - reject if too long
  246. length_of_range = end_idx - start_idx
  247. if length_of_range > self.MAX_LINES_TO_EDIT + 1:
  248. error_msg = (
  249. f'[Edit error: The range of lines to edit is too long.]\n'
  250. f'[The maximum number of lines allowed to edit at once is {self.MAX_LINES_TO_EDIT}. '
  251. f'Got (L{start_idx + 1}-L{end_idx}) {length_of_range} lines.]\n' # [start_idx, end_idx), so no need to + 1
  252. )
  253. # search for relevant ranges to hint the agent
  254. topk_chunks: list[Chunk] = get_top_k_chunk_matches(
  255. text=original_file_content,
  256. query=action.content, # edit draft as query
  257. k=3,
  258. max_chunk_size=20, # lines
  259. )
  260. error_msg += (
  261. 'Here are some snippets that maybe relevant to the provided edit.\n'
  262. )
  263. for i, chunk in enumerate(topk_chunks):
  264. error_msg += f'[begin relevant snippet {i+1}. Line range: L{chunk.line_range[0]}-L{chunk.line_range[1]}. Similarity: {chunk.normalized_lcs}]\n'
  265. error_msg += f'[Browse around it via `open_file("{action.path}", {(chunk.line_range[0] + chunk.line_range[1]) // 2})`]\n'
  266. error_msg += chunk.visualize() + '\n'
  267. error_msg += f'[end relevant snippet {i+1}]\n'
  268. error_msg += '-' * 40 + '\n'
  269. error_msg += 'Consider using `open_file` to explore around the relevant snippets if needed.\n'
  270. error_msg += f'**IMPORTANT**: Please REDUCE the range of edits to less than {self.MAX_LINES_TO_EDIT} lines by setting `start` and `end` in the edit action (e.g. `<file_edit path="{action.path}" start=[PUT LINE NUMBER HERE] end=[PUT LINE NUMBER HERE] />`). '
  271. return ErrorObservation(error_msg)
  272. content_to_edit = '\n'.join(old_file_lines[start_idx:end_idx])
  273. self.draft_editor_llm.reset()
  274. _edited_content = get_new_file_contents(
  275. self.draft_editor_llm, content_to_edit, action.content
  276. )
  277. if _edited_content is None:
  278. ret_err = ErrorObservation(
  279. 'Failed to get new file contents. '
  280. 'Please try to reduce the number of edits and try again.'
  281. )
  282. ret_err.llm_metrics = self.draft_editor_llm.metrics
  283. return ret_err
  284. # piece the updated content with the unchanged content
  285. updated_lines = (
  286. old_file_lines[:start_idx]
  287. + _edited_content.split('\n')
  288. + old_file_lines[end_idx:]
  289. )
  290. updated_content = '\n'.join(updated_lines)
  291. diff = get_diff(original_file_content, updated_content, action.path)
  292. # Lint the updated content
  293. if self.config.sandbox.enable_auto_lint:
  294. suffix = os.path.splitext(action.path)[1]
  295. error_obs = self._get_lint_error(
  296. suffix, original_file_content, updated_content, action.path, diff
  297. )
  298. if error_obs is not None:
  299. error_obs.llm_metrics = self.draft_editor_llm.metrics
  300. return error_obs
  301. obs = self.write(FileWriteAction(path=action.path, content=updated_content))
  302. ret_obs = FileEditObservation(
  303. content=diff,
  304. path=action.path,
  305. prev_exist=True,
  306. old_content=original_file_content,
  307. new_content=updated_content,
  308. )
  309. ret_obs.llm_metrics = self.draft_editor_llm.metrics
  310. return ret_obs