files.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. from dataclasses import dataclass
  2. from difflib import SequenceMatcher
  3. from openhands.core.schema import ObservationType
  4. from openhands.events.observation.observation import Observation
  5. @dataclass
  6. class FileReadObservation(Observation):
  7. """This data class represents the content of a file."""
  8. path: str
  9. observation: str = ObservationType.READ
  10. @property
  11. def message(self) -> str:
  12. return f'I read the file {self.path}.'
  13. @dataclass
  14. class FileWriteObservation(Observation):
  15. """This data class represents a file write operation"""
  16. path: str
  17. observation: str = ObservationType.WRITE
  18. @property
  19. def message(self) -> str:
  20. return f'I wrote to the file {self.path}.'
  21. @dataclass
  22. class FileEditObservation(Observation):
  23. """This data class represents a file edit operation"""
  24. # content: str will be a unified diff patch string include NO context lines
  25. path: str
  26. prev_exist: bool
  27. old_content: str
  28. new_content: str
  29. observation: str = ObservationType.EDIT
  30. @property
  31. def message(self) -> str:
  32. return f'I edited the file {self.path}.'
  33. def get_edit_groups(self, n_context_lines: int = 2) -> list[dict[str, list[str]]]:
  34. """Get the edit groups of the file edit."""
  35. old_lines = self.old_content.split('\n')
  36. new_lines = self.new_content.split('\n')
  37. # Borrowed from difflib.unified_diff to directly parse into structured format.
  38. edit_groups: list[dict] = []
  39. for group in SequenceMatcher(None, old_lines, new_lines).get_grouped_opcodes(
  40. n_context_lines
  41. ):
  42. # take the max line number in the group
  43. _indent_pad_size = len(str(group[-1][3])) + 1 # +1 for the "*" prefix
  44. cur_group: dict[str, list[str]] = {
  45. 'before_edits': [],
  46. 'after_edits': [],
  47. }
  48. for tag, i1, i2, j1, j2 in group:
  49. if tag == 'equal':
  50. for idx, line in enumerate(old_lines[i1:i2]):
  51. cur_group['before_edits'].append(
  52. f'{i1+idx+1:>{_indent_pad_size}}|{line}'
  53. )
  54. for idx, line in enumerate(new_lines[j1:j2]):
  55. cur_group['after_edits'].append(
  56. f'{j1+idx+1:>{_indent_pad_size}}|{line}'
  57. )
  58. continue
  59. if tag in {'replace', 'delete'}:
  60. for idx, line in enumerate(old_lines[i1:i2]):
  61. cur_group['before_edits'].append(
  62. f'-{i1+idx+1:>{_indent_pad_size-1}}|{line}'
  63. )
  64. if tag in {'replace', 'insert'}:
  65. for idx, line in enumerate(new_lines[j1:j2]):
  66. cur_group['after_edits'].append(
  67. f'+{j1+idx+1:>{_indent_pad_size-1}}|{line}'
  68. )
  69. edit_groups.append(cur_group)
  70. return edit_groups
  71. def visualize_diff(
  72. self,
  73. n_context_lines: int = 2,
  74. change_applied: bool = True,
  75. ) -> str:
  76. """Visualize the diff of the file edit.
  77. Instead of showing the diff line by line, this function
  78. shows each hunk of changes as a separate entity.
  79. Args:
  80. n_context_lines: The number of lines of context to show before and after the changes.
  81. change_applied: Whether the changes are applied to the file. If true, the file have been modified. If not, the file is not modified (due to linting errors).
  82. """
  83. if change_applied and self.content.strip() == '':
  84. # diff patch is empty
  85. return '(no changes detected. Please make sure your edits changes the content of the existing file.)\n'
  86. edit_groups = self.get_edit_groups(n_context_lines=n_context_lines)
  87. result = [
  88. f'[Existing file {self.path} is edited with {len(edit_groups)} changes.]'
  89. if change_applied
  90. else f"[Changes are NOT applied to {self.path} - Here's how the file looks like if changes are applied.]"
  91. ]
  92. op_type = 'edit' if change_applied else 'ATTEMPTED edit'
  93. for i, cur_edit_group in enumerate(edit_groups):
  94. if i != 0:
  95. result.append('-------------------------')
  96. result.append(f'[begin of {op_type} {i+1} / {len(edit_groups)}]')
  97. result.append(f'(content before {op_type})')
  98. result.extend(cur_edit_group['before_edits'])
  99. result.append(f'(content after {op_type})')
  100. result.extend(cur_edit_group['after_edits'])
  101. result.append(f'[end of {op_type} {i+1} / {len(edit_groups)}]')
  102. return '\n'.join(result)
  103. def __str__(self) -> str:
  104. ret = ''
  105. if not self.prev_exist:
  106. assert (
  107. self.old_content == ''
  108. ), 'old_content should be empty if the file is new (prev_exist=False).'
  109. ret += f'[New file {self.path} is created with the provided content.]\n'
  110. return ret.rstrip() + '\n'
  111. ret += self.visualize_diff()
  112. return ret.rstrip() + '\n'