edit.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. import copy
  2. import os
  3. import re
  4. import tempfile
  5. from abc import ABC, abstractmethod
  6. from openhands.core.config import AppConfig
  7. from openhands.core.logger import openhands_logger as logger
  8. from openhands.events.action import (
  9. FileEditAction,
  10. FileReadAction,
  11. FileWriteAction,
  12. )
  13. from openhands.events.observation import (
  14. ErrorObservation,
  15. FileEditObservation,
  16. FileReadObservation,
  17. FileWriteObservation,
  18. Observation,
  19. )
  20. from openhands.linter import DefaultLinter
  21. from openhands.llm.llm import LLM
  22. from openhands.llm.metrics import Metrics
  23. from openhands.utils.chunk_localizer import Chunk, get_top_k_chunk_matches
  24. from openhands.utils.diff import get_diff
# System prompt for the draft-editor LLM. It asks the model to merge a draft
# (which may skip lines or be mis-indented) into the old file content and to
# reply with the complete result wrapped in a ``` block. The text is sent
# verbatim to the model, so wording changes here change runtime behaviour.
SYS_MSG = """Your job is to produce a new version of the file based on the old version and the
provided draft of the new version. The provided draft may be incomplete (it may skip lines) and/or incorrectly indented. You should try to apply the changes present in the draft to the old version, and output a new version of the file.
NOTE:
- The output file should be COMPLETE and CORRECTLY INDENTED. Do not omit any lines, and do not change any lines that are not part of the changes.
- You should output the new version of the file by wrapping the new version of the file content in a ``` block.
- If there's no explicit comment to remove the existing code, we should keep them and append the new code to the end of the file.
- If there's placeholder comments like `# no changes before` or `# no changes here`, we should replace these comments with the original code near the placeholder comments.
"""

# User prompt template. It is filled via str.format with two fields:
# `old_contents` (the text being edited) and `draft_changes` (the proposed
# draft). The trailing .strip() removes the newline that follows the opening
# triple quote and the one that precedes the closing triple quote.
USER_MSG = """
HERE IS THE OLD VERSION OF THE FILE:
```
{old_contents}
```
HERE IS THE DRAFT OF THE NEW VERSION OF THE FILE:
```
{draft_changes}
```
GIVE ME THE NEW VERSION OF THE FILE.
IMPORTANT:
- There should be NO placeholder comments like `# no changes before` or `# no changes here`. They should be replaced with the original code near the placeholder comments.
- The output file should be COMPLETE and CORRECTLY INDENTED. Do not omit any lines, and do not change any lines that are not part of the changes.
""".strip()
  47. def _extract_code(string):
  48. pattern = r'```(?:\w*\n)?(.*?)```'
  49. matches = re.findall(pattern, string, re.DOTALL)
  50. if not matches:
  51. return None
  52. return matches[0]
  53. def get_new_file_contents(
  54. llm: LLM, old_contents: str, draft_changes: str, num_retries: int = 3
  55. ) -> str | None:
  56. while num_retries > 0:
  57. messages = [
  58. {'role': 'system', 'content': SYS_MSG},
  59. {
  60. 'role': 'user',
  61. 'content': USER_MSG.format(
  62. old_contents=old_contents, draft_changes=draft_changes
  63. ),
  64. },
  65. ]
  66. resp = llm.completion(messages=messages)
  67. new_contents = _extract_code(resp['choices'][0]['message']['content'])
  68. if new_contents is not None:
  69. return new_contents
  70. num_retries -= 1
  71. return None
  72. class FileEditRuntimeInterface(ABC):
  73. config: AppConfig
  74. @abstractmethod
  75. def read(self, action: FileReadAction) -> Observation:
  76. pass
  77. @abstractmethod
  78. def write(self, action: FileWriteAction) -> Observation:
  79. pass
  80. class FileEditRuntimeMixin(FileEditRuntimeInterface):
  81. # Most LLMs have output token limit of 4k tokens.
  82. # This restricts the number of lines we can edit to avoid exceeding the token limit.
  83. MAX_LINES_TO_EDIT = 300
  84. def __init__(self, *args, **kwargs):
  85. super().__init__(*args, **kwargs)
  86. llm_config = self.config.get_llm_config()
  87. if llm_config.draft_editor is None:
  88. llm_config.draft_editor = copy.deepcopy(llm_config)
  89. # manually set the model name for the draft editor LLM to distinguish token costs
  90. llm_metrics = Metrics(
  91. model_name='draft_editor:' + llm_config.draft_editor.model
  92. )
  93. if llm_config.draft_editor.caching_prompt:
  94. logger.debug(
  95. 'It is not recommended to cache draft editor LLM prompts as it may incur high costs for the same prompt. '
  96. 'Automatically setting caching_prompt=false.'
  97. )
  98. llm_config.draft_editor.caching_prompt = False
  99. self.draft_editor_llm = LLM(llm_config.draft_editor, metrics=llm_metrics)
  100. logger.debug(
  101. f'[Draft edit functionality] enabled with LLM: {self.draft_editor_llm}'
  102. )
  103. def _validate_range(
  104. self, start: int, end: int, total_lines: int
  105. ) -> Observation | None:
  106. # start and end are 1-indexed and inclusive
  107. if (
  108. (start < 1 and start != -1)
  109. or start > total_lines
  110. or (start > end and end != -1 and start != -1)
  111. ):
  112. return ErrorObservation(
  113. f'Invalid range for editing: start={start}, end={end}, total lines={total_lines}. start must be >= 1 and <={total_lines} (total lines of the edited file), start <= end, or start == -1 (append to the end of the file).'
  114. )
  115. if (
  116. (end < 1 and end != -1)
  117. or end > total_lines
  118. or (end < start and start != -1 and end != -1)
  119. ):
  120. return ErrorObservation(
  121. f'Invalid range for editing: start={start}, end={end}, total lines={total_lines}. end must be >= 1 and <= {total_lines} (total lines of the edited file), end >= start, or end == -1 (to edit till the end of the file).'
  122. )
  123. return None
  124. def _get_lint_error(
  125. self,
  126. suffix: str,
  127. old_content: str,
  128. new_content: str,
  129. filepath: str,
  130. diff: str,
  131. ) -> ErrorObservation | None:
  132. linter = DefaultLinter()
  133. # Copy the original file to a temporary file (with the same ext) and lint it
  134. with tempfile.NamedTemporaryFile(
  135. suffix=suffix, mode='w+', encoding='utf-8'
  136. ) as original_file_copy, tempfile.NamedTemporaryFile(
  137. suffix=suffix, mode='w+', encoding='utf-8'
  138. ) as updated_file_copy:
  139. # Lint the original file
  140. original_file_copy.write(old_content)
  141. original_file_copy.flush()
  142. # Lint the updated file
  143. updated_file_copy.write(new_content)
  144. updated_file_copy.flush()
  145. updated_lint_error = linter.lint_file_diff(
  146. original_file_copy.name, updated_file_copy.name
  147. )
  148. if len(updated_lint_error) > 0:
  149. _obs = FileEditObservation(
  150. content=diff,
  151. path=filepath,
  152. prev_exist=True,
  153. old_content=old_content,
  154. new_content=new_content,
  155. )
  156. error_message = (
  157. (
  158. f'\n[Linting failed for edited file {filepath}. {len(updated_lint_error)} lint errors found.]\n'
  159. '[begin attempted changes]\n'
  160. f'{_obs.visualize_diff(change_applied=False)}\n'
  161. '[end attempted changes]\n'
  162. )
  163. + '-' * 40
  164. + '\n'
  165. )
  166. error_message += '-' * 20 + 'First 5 lint errors' + '-' * 20 + '\n'
  167. for i, lint_error in enumerate(updated_lint_error[:5]):
  168. error_message += f'[begin lint error {i}]\n'
  169. error_message += lint_error.visualize().strip() + '\n'
  170. error_message += f'[end lint error {i}]\n'
  171. error_message += '-' * 40 + '\n'
  172. return ErrorObservation(error_message)
  173. return None
  174. def edit(self, action: FileEditAction) -> Observation:
  175. obs = self.read(FileReadAction(path=action.path))
  176. if (
  177. isinstance(obs, ErrorObservation)
  178. and 'File not found'.lower() in obs.content.lower()
  179. ):
  180. logger.debug(
  181. f'Agent attempted to edit a file that does not exist. Creating the file. Error msg: {obs.content}'
  182. )
  183. # directly write the new content
  184. obs = self.write(
  185. FileWriteAction(path=action.path, content=action.content.strip())
  186. )
  187. if isinstance(obs, ErrorObservation):
  188. return obs
  189. if not isinstance(obs, FileWriteObservation):
  190. raise ValueError(
  191. f'Expected FileWriteObservation, got {type(obs)}: {str(obs)}'
  192. )
  193. return FileEditObservation(
  194. content=get_diff('', action.content, action.path),
  195. path=action.path,
  196. prev_exist=False,
  197. old_content='',
  198. new_content=action.content,
  199. )
  200. if not isinstance(obs, FileReadObservation):
  201. raise ValueError(
  202. f'Expected FileReadObservation, got {type(obs)}: {str(obs)}'
  203. )
  204. original_file_content = obs.content
  205. old_file_lines = original_file_content.split('\n')
  206. # NOTE: start and end are 1-indexed
  207. start = action.start
  208. end = action.end
  209. # validate the range
  210. error = self._validate_range(start, end, len(old_file_lines))
  211. if error is not None:
  212. return error
  213. # append to the end of the file
  214. if start == -1:
  215. updated_content = '\n'.join(old_file_lines + action.content.split('\n'))
  216. diff = get_diff(original_file_content, updated_content, action.path)
  217. # Lint the updated content
  218. if self.config.sandbox.enable_auto_lint:
  219. suffix = os.path.splitext(action.path)[1]
  220. error_obs = self._get_lint_error(
  221. suffix,
  222. original_file_content,
  223. updated_content,
  224. action.path,
  225. diff,
  226. )
  227. if error_obs is not None:
  228. return error_obs
  229. obs = self.write(FileWriteAction(path=action.path, content=updated_content))
  230. return FileEditObservation(
  231. content=diff,
  232. path=action.path,
  233. prev_exist=True,
  234. old_content=original_file_content,
  235. new_content=updated_content,
  236. )
  237. # Get the 0-indexed start and end
  238. start_idx = start - 1
  239. if end != -1:
  240. # remove 1 to make it 0-indexed
  241. # then add 1 since the `end` is inclusive
  242. end_idx = end - 1 + 1
  243. else:
  244. # end == -1 means the user wants to edit till the end of the file
  245. end_idx = len(old_file_lines)
  246. # Get the range of lines to edit - reject if too long
  247. length_of_range = end_idx - start_idx
  248. if length_of_range > self.MAX_LINES_TO_EDIT + 1:
  249. error_msg = (
  250. f'[Edit error: The range of lines to edit is too long.]\n'
  251. f'[The maximum number of lines allowed to edit at once is {self.MAX_LINES_TO_EDIT}. '
  252. f'Got (L{start_idx + 1}-L{end_idx}) {length_of_range} lines.]\n' # [start_idx, end_idx), so no need to + 1
  253. )
  254. # search for relevant ranges to hint the agent
  255. topk_chunks: list[Chunk] = get_top_k_chunk_matches(
  256. text=original_file_content,
  257. query=action.content, # edit draft as query
  258. k=3,
  259. max_chunk_size=20, # lines
  260. )
  261. error_msg += (
  262. 'Here are some snippets that maybe relevant to the provided edit.\n'
  263. )
  264. for i, chunk in enumerate(topk_chunks):
  265. error_msg += f'[begin relevant snippet {i+1}. Line range: L{chunk.line_range[0]}-L{chunk.line_range[1]}. Similarity: {chunk.normalized_lcs}]\n'
  266. error_msg += f'[Browse around it via `open_file("{action.path}", {(chunk.line_range[0] + chunk.line_range[1]) // 2})`]\n'
  267. error_msg += chunk.visualize() + '\n'
  268. error_msg += f'[end relevant snippet {i+1}]\n'
  269. error_msg += '-' * 40 + '\n'
  270. error_msg += 'Consider using `open_file` to explore around the relevant snippets if needed.\n'
  271. error_msg += f'**IMPORTANT**: Please REDUCE the range of edits to less than {self.MAX_LINES_TO_EDIT} lines by setting `start` and `end` in the edit action (e.g. `<file_edit path="{action.path}" start=[PUT LINE NUMBER HERE] end=[PUT LINE NUMBER HERE] />`). '
  272. return ErrorObservation(error_msg)
  273. content_to_edit = '\n'.join(old_file_lines[start_idx:end_idx])
  274. self.draft_editor_llm.reset()
  275. _edited_content = get_new_file_contents(
  276. self.draft_editor_llm, content_to_edit, action.content
  277. )
  278. if _edited_content is None:
  279. ret_err = ErrorObservation(
  280. 'Failed to get new file contents. '
  281. 'Please try to reduce the number of edits and try again.'
  282. )
  283. ret_err.llm_metrics = self.draft_editor_llm.metrics
  284. return ret_err
  285. # piece the updated content with the unchanged content
  286. updated_lines = (
  287. old_file_lines[:start_idx]
  288. + _edited_content.split('\n')
  289. + old_file_lines[end_idx:]
  290. )
  291. updated_content = '\n'.join(updated_lines)
  292. diff = get_diff(original_file_content, updated_content, action.path)
  293. # Lint the updated content
  294. if self.config.sandbox.enable_auto_lint:
  295. suffix = os.path.splitext(action.path)[1]
  296. error_obs = self._get_lint_error(
  297. suffix, original_file_content, updated_content, action.path, diff
  298. )
  299. if error_obs is not None:
  300. error_obs.llm_metrics = self.draft_editor_llm.metrics
  301. return error_obs
  302. obs = self.write(FileWriteAction(path=action.path, content=updated_content))
  303. ret_obs = FileEditObservation(
  304. content=diff,
  305. path=action.path,
  306. prev_exist=True,
  307. old_content=original_file_content,
  308. new_content=updated_content,
  309. )
  310. ret_obs.llm_metrics = self.draft_editor_llm.metrics
  311. return ret_obs