# impl.py
import shlex
from collections import defaultdict
from pathlib import Path
from typing import Literal, get_args

from .base import CLIResult, ToolError, ToolResult
from .run import maybe_truncate, run
  6. Command = Literal[
  7. 'view',
  8. 'create',
  9. 'str_replace',
  10. 'insert',
  11. 'undo_edit',
  12. ]
  13. SNIPPET_LINES: int = 4
  14. class EditTool:
  15. """
  16. An filesystem editor tool that allows the agent to view, create, and edit files.
  17. The tool parameters are defined by Anthropic and are not editable.
  18. Original implementation: https://github.com/anthropics/anthropic-quickstarts/blob/main/computer-use-demo/computer_use_demo/tools/edit.py
  19. """
  20. _file_history: dict[Path, list[str]]
  21. def __init__(self):
  22. self._file_history = defaultdict(list)
  23. super().__init__()
  24. def __call__(
  25. self,
  26. *,
  27. command: Command,
  28. path: str,
  29. file_text: str | None = None,
  30. view_range: list[int] | None = None,
  31. old_str: str | None = None,
  32. new_str: str | None = None,
  33. insert_line: int | None = None,
  34. **kwargs,
  35. ):
  36. _path = Path(path)
  37. self.validate_path(command, _path)
  38. if command == 'view':
  39. return self.view(_path, view_range)
  40. elif command == 'create':
  41. if file_text is None:
  42. raise ToolError('Parameter `file_text` is required for command: create')
  43. self.write_file(_path, file_text)
  44. self._file_history[_path].append(file_text)
  45. return ToolResult(output=f'File created successfully at: {_path}')
  46. elif command == 'str_replace':
  47. if old_str is None:
  48. raise ToolError(
  49. 'Parameter `old_str` is required for command: str_replace'
  50. )
  51. return self.str_replace(_path, old_str, new_str)
  52. elif command == 'insert':
  53. if insert_line is None:
  54. raise ToolError(
  55. 'Parameter `insert_line` is required for command: insert'
  56. )
  57. if new_str is None:
  58. raise ToolError('Parameter `new_str` is required for command: insert')
  59. return self.insert(_path, insert_line, new_str)
  60. elif command == 'undo_edit':
  61. return self.undo_edit(_path)
  62. raise ToolError(
  63. f'Unrecognized command {command}. The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}'
  64. )
  65. def validate_path(self, command: str, path: Path):
  66. """
  67. Check that the path/command combination is valid.
  68. """
  69. # Check if its an absolute path
  70. if not path.is_absolute():
  71. suggested_path = Path('') / path
  72. raise ToolError(
  73. f'The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?'
  74. )
  75. # Check if path exists
  76. if not path.exists() and command != 'create':
  77. raise ToolError(
  78. f'The path {path} does not exist. Please provide a valid path.'
  79. )
  80. if path.exists() and command == 'create':
  81. raise ToolError(
  82. f'File already exists at: {path}. Cannot overwrite files using command `create`.'
  83. )
  84. # Check if the path points to a directory
  85. if path.is_dir():
  86. if command != 'view':
  87. raise ToolError(
  88. f'The path {path} is a directory and only the `view` command can be used on directories'
  89. )
  90. def view(self, path: Path, view_range: list[int] | None = None):
  91. """Implement the view command"""
  92. if path.is_dir():
  93. if view_range:
  94. raise ToolError(
  95. 'The `view_range` parameter is not allowed when `path` points to a directory.'
  96. )
  97. _, stdout, stderr = run(rf"find {path} -maxdepth 2 -not -path '*/\.*'")
  98. if not stderr:
  99. stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n"
  100. return CLIResult(output=stdout, error=stderr)
  101. file_content = self.read_file(path)
  102. init_line = 1
  103. if view_range:
  104. if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range):
  105. raise ToolError(
  106. 'Invalid `view_range`. It should be a list of two integers.'
  107. )
  108. file_lines = file_content.split('\n')
  109. n_lines_file = len(file_lines)
  110. init_line, final_line = view_range
  111. if init_line < 1 or init_line > n_lines_file:
  112. raise ToolError(
  113. f"Invalid `view_range`: {view_range}. It's first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
  114. )
  115. if final_line > n_lines_file:
  116. raise ToolError(
  117. f"Invalid `view_range`: {view_range}. It's second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`"
  118. )
  119. if final_line != -1 and final_line < init_line:
  120. raise ToolError(
  121. f"Invalid `view_range`: {view_range}. It's second element `{final_line}` should be larger or equal than its first `{init_line}`"
  122. )
  123. if final_line == -1:
  124. file_content = '\n'.join(file_lines[init_line - 1 :])
  125. else:
  126. file_content = '\n'.join(file_lines[init_line - 1 : final_line])
  127. return CLIResult(
  128. output=self._make_output(file_content, str(path), init_line=init_line)
  129. )
  130. def str_replace(self, path: Path, old_str: str, new_str: str | None):
  131. """Implement the str_replace command, which replaces old_str with new_str in the file content"""
  132. # Read the file content
  133. file_content = self.read_file(path).expandtabs()
  134. old_str = old_str.expandtabs()
  135. new_str = new_str.expandtabs() if new_str is not None else ''
  136. # Check if old_str is unique in the file
  137. occurrences = file_content.count(old_str)
  138. if occurrences == 0:
  139. raise ToolError(
  140. f'No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}.'
  141. )
  142. elif occurrences > 1:
  143. file_content_lines = file_content.split('\n')
  144. lines = [
  145. idx + 1
  146. for idx, line in enumerate(file_content_lines)
  147. if old_str in line
  148. ]
  149. raise ToolError(
  150. f'No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique'
  151. )
  152. # Replace old_str with new_str
  153. new_file_content = file_content.replace(old_str, new_str)
  154. # Write the new content to the file
  155. self.write_file(path, new_file_content)
  156. # Save the content to history
  157. self._file_history[path].append(file_content)
  158. # Create a snippet of the edited section
  159. replacement_line = file_content.split(old_str)[0].count('\n')
  160. start_line = max(0, replacement_line - SNIPPET_LINES)
  161. end_line = replacement_line + SNIPPET_LINES + new_str.count('\n')
  162. snippet = '\n'.join(new_file_content.split('\n')[start_line : end_line + 1])
  163. # Prepare the success message
  164. success_msg = f'The file {path} has been edited. '
  165. success_msg += self._make_output(
  166. snippet, f'a snippet of {path}', start_line + 1
  167. )
  168. success_msg += 'Review the changes and make sure they are as expected. Edit the file again if necessary.'
  169. return CLIResult(output=success_msg)
  170. def insert(self, path: Path, insert_line: int, new_str: str):
  171. """Implement the insert command, which inserts new_str at the specified line in the file content."""
  172. file_text = self.read_file(path).expandtabs()
  173. new_str = new_str.expandtabs()
  174. file_text_lines = file_text.split('\n')
  175. n_lines_file = len(file_text_lines)
  176. if insert_line < 0 or insert_line > n_lines_file:
  177. raise ToolError(
  178. f'Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}'
  179. )
  180. new_str_lines = new_str.split('\n')
  181. new_file_text_lines = (
  182. file_text_lines[:insert_line]
  183. + new_str_lines
  184. + file_text_lines[insert_line:]
  185. )
  186. snippet_lines = (
  187. file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
  188. + new_str_lines
  189. + file_text_lines[insert_line : insert_line + SNIPPET_LINES]
  190. )
  191. new_file_text = '\n'.join(new_file_text_lines)
  192. snippet = '\n'.join(snippet_lines)
  193. self.write_file(path, new_file_text)
  194. self._file_history[path].append(file_text)
  195. success_msg = f'The file {path} has been edited. '
  196. success_msg += self._make_output(
  197. snippet,
  198. 'a snippet of the edited file',
  199. max(1, insert_line - SNIPPET_LINES + 1),
  200. )
  201. success_msg += 'Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary.'
  202. return CLIResult(output=success_msg)
  203. def undo_edit(self, path: Path):
  204. """Implement the undo_edit command."""
  205. if not self._file_history[path]:
  206. raise ToolError(f'No edit history found for {path}.')
  207. old_text = self._file_history[path].pop()
  208. self.write_file(path, old_text)
  209. return CLIResult(
  210. output=f'Last edit to {path} undone successfully. {self._make_output(old_text, str(path))}'
  211. )
  212. def read_file(self, path: Path):
  213. """Read the content of a file from a given path; raise a ToolError if an error occurs."""
  214. try:
  215. return path.read_text()
  216. except Exception as e:
  217. raise ToolError(f'Ran into {e} while trying to read {path}') from None
  218. def write_file(self, path: Path, file: str):
  219. """Write the content of a file to a given path; raise a ToolError if an error occurs."""
  220. try:
  221. path.write_text(file)
  222. except Exception as e:
  223. raise ToolError(f'Ran into {e} while trying to write to {path}') from None
  224. def _make_output(
  225. self,
  226. file_content: str,
  227. file_descriptor: str,
  228. init_line: int = 1,
  229. expand_tabs: bool = True,
  230. ):
  231. """Generate output for the CLI based on the content of a file."""
  232. file_content = maybe_truncate(file_content)
  233. if expand_tabs:
  234. file_content = file_content.expandtabs()
  235. file_content = '\n'.join(
  236. [
  237. f'{i + init_line:6}\t{line}'
  238. for i, line in enumerate(file_content.split('\n'))
  239. ]
  240. )
  241. return (
  242. f"Here's the result of running `cat -n` on {file_descriptor}:\n"
  243. + file_content
  244. + '\n'
  245. )