files.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. from dataclasses import dataclass
  2. from difflib import SequenceMatcher
  3. from openhands.core.schema import ObservationType
  4. from openhands.events.event import FileEditSource, FileReadSource
  5. from openhands.events.observation.observation import Observation
  6. @dataclass
  7. class FileReadObservation(Observation):
  8. """This data class represents the content of a file."""
  9. path: str
  10. observation: str = ObservationType.READ
  11. impl_source: FileReadSource = FileReadSource.DEFAULT
  12. @property
  13. def message(self) -> str:
  14. return f'I read the file {self.path}.'
  15. @dataclass
  16. class FileWriteObservation(Observation):
  17. """This data class represents a file write operation"""
  18. path: str
  19. observation: str = ObservationType.WRITE
  20. @property
  21. def message(self) -> str:
  22. return f'I wrote to the file {self.path}.'
  23. @dataclass
  24. class FileEditObservation(Observation):
  25. """This data class represents a file edit operation"""
  26. # content: str will be a unified diff patch string include NO context lines
  27. path: str
  28. prev_exist: bool
  29. old_content: str
  30. new_content: str
  31. observation: str = ObservationType.EDIT
  32. impl_source: FileEditSource = FileEditSource.LLM_BASED_EDIT
  33. formatted_output_and_error: str = ''
  34. @property
  35. def message(self) -> str:
  36. return f'I edited the file {self.path}.'
  37. def get_edit_groups(self, n_context_lines: int = 2) -> list[dict[str, list[str]]]:
  38. """Get the edit groups of the file edit."""
  39. old_lines = self.old_content.split('\n')
  40. new_lines = self.new_content.split('\n')
  41. # Borrowed from difflib.unified_diff to directly parse into structured format.
  42. edit_groups: list[dict] = []
  43. for group in SequenceMatcher(None, old_lines, new_lines).get_grouped_opcodes(
  44. n_context_lines
  45. ):
  46. # take the max line number in the group
  47. _indent_pad_size = len(str(group[-1][3])) + 1 # +1 for the "*" prefix
  48. cur_group: dict[str, list[str]] = {
  49. 'before_edits': [],
  50. 'after_edits': [],
  51. }
  52. for tag, i1, i2, j1, j2 in group:
  53. if tag == 'equal':
  54. for idx, line in enumerate(old_lines[i1:i2]):
  55. cur_group['before_edits'].append(
  56. f'{i1+idx+1:>{_indent_pad_size}}|{line}'
  57. )
  58. for idx, line in enumerate(new_lines[j1:j2]):
  59. cur_group['after_edits'].append(
  60. f'{j1+idx+1:>{_indent_pad_size}}|{line}'
  61. )
  62. continue
  63. if tag in {'replace', 'delete'}:
  64. for idx, line in enumerate(old_lines[i1:i2]):
  65. cur_group['before_edits'].append(
  66. f'-{i1+idx+1:>{_indent_pad_size-1}}|{line}'
  67. )
  68. if tag in {'replace', 'insert'}:
  69. for idx, line in enumerate(new_lines[j1:j2]):
  70. cur_group['after_edits'].append(
  71. f'+{j1+idx+1:>{_indent_pad_size-1}}|{line}'
  72. )
  73. edit_groups.append(cur_group)
  74. return edit_groups
  75. def visualize_diff(
  76. self,
  77. n_context_lines: int = 2,
  78. change_applied: bool = True,
  79. ) -> str:
  80. """Visualize the diff of the file edit.
  81. Instead of showing the diff line by line, this function
  82. shows each hunk of changes as a separate entity.
  83. Args:
  84. n_context_lines: The number of lines of context to show before and after the changes.
  85. change_applied: Whether the changes are applied to the file. If true, the file have been modified. If not, the file is not modified (due to linting errors).
  86. """
  87. if change_applied and self.content.strip() == '':
  88. # diff patch is empty
  89. return '(no changes detected. Please make sure your edits changes the content of the existing file.)\n'
  90. edit_groups = self.get_edit_groups(n_context_lines=n_context_lines)
  91. result = [
  92. f'[Existing file {self.path} is edited with {len(edit_groups)} changes.]'
  93. if change_applied
  94. else f"[Changes are NOT applied to {self.path} - Here's how the file looks like if changes are applied.]"
  95. ]
  96. op_type = 'edit' if change_applied else 'ATTEMPTED edit'
  97. for i, cur_edit_group in enumerate(edit_groups):
  98. if i != 0:
  99. result.append('-------------------------')
  100. result.append(f'[begin of {op_type} {i+1} / {len(edit_groups)}]')
  101. result.append(f'(content before {op_type})')
  102. result.extend(cur_edit_group['before_edits'])
  103. result.append(f'(content after {op_type})')
  104. result.extend(cur_edit_group['after_edits'])
  105. result.append(f'[end of {op_type} {i+1} / {len(edit_groups)}]')
  106. return '\n'.join(result)
  107. def __str__(self) -> str:
  108. if self.impl_source == FileEditSource.OH_ACI:
  109. return self.formatted_output_and_error
  110. ret = ''
  111. if not self.prev_exist:
  112. assert (
  113. self.old_content == ''
  114. ), 'old_content should be empty if the file is new (prev_exist=False).'
  115. ret += f'[New file {self.path} is created with the provided content.]\n'
  116. return ret.rstrip() + '\n'
  117. ret += self.visualize_diff()
  118. return ret.rstrip() + '\n'