Ver código fonte

Resolver minor tweaks (#5461)

Engel Nyst 1 ano atrás
pai
commit
279e1d7abc

+ 38 - 3
openhands/resolver/issue_definitions.py

@@ -62,19 +62,23 @@ class IssueHandler(IssueHandlerInterface):
         params: dict[str, int | str] = {'state': 'open', 'per_page': 100, 'page': 1}
         all_issues = []
 
+        # Get issues, page by page
         while True:
             response = requests.get(url, headers=headers, params=params)
             response.raise_for_status()
             issues = response.json()
 
+            # No more issues, break the loop
             if not issues:
                 break
 
+            # Sanity check - the response is a list of dictionaries
             if not isinstance(issues, list) or any(
                 [not isinstance(issue, dict) for issue in issues]
             ):
                 raise ValueError('Expected list of dictionaries from Github API.')
 
+            # Add the issues to the final list
             all_issues.extend(issues)
             assert isinstance(params['page'], int)
             params['page'] += 1
@@ -107,7 +111,12 @@ class IssueHandler(IssueHandlerInterface):
     def _get_issue_comments(
         self, issue_number: int, comment_id: int | None = None
     ) -> list[str] | None:
-        """Download comments for a specific issue from Github."""
+        """Retrieve comments for a specific issue from Github.
+
+        Args:
+            issue_number: The number of the issue to get comments for
+            comment_id: If provided, only the comment with this ID is returned; otherwise all comments are returned
+        """
         url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{issue_number}/comments'
         headers = {
             'Authorization': f'token {self.token}',
@@ -116,6 +125,7 @@ class IssueHandler(IssueHandlerInterface):
         params = {'per_page': 100, 'page': 1}
         all_comments = []
 
+        # Get comments, page by page
         while True:
             response = requests.get(url, headers=headers, params=params)
             response.raise_for_status()
@@ -124,6 +134,7 @@ class IssueHandler(IssueHandlerInterface):
             if not comments:
                 break
 
+            # If a single comment ID is provided, return only that comment
             if comment_id:
                 matching_comment = next(
                     (
@@ -136,6 +147,7 @@ class IssueHandler(IssueHandlerInterface):
                 if matching_comment:
                     return [matching_comment]
             else:
+                # Otherwise, collect the bodies of all comments on this page
                 all_comments.extend([comment['body'] for comment in comments])
 
             params['page'] += 1
@@ -147,6 +159,10 @@ class IssueHandler(IssueHandlerInterface):
     ) -> list[GithubIssue]:
         """Download issues from Github.
 
+        Args:
+            issue_numbers: The numbers of the issues to download
+            comment_id: The ID of a single comment, if provided, otherwise all comments
+
         Returns:
             List of Github issues.
         """
@@ -203,7 +219,14 @@ class IssueHandler(IssueHandlerInterface):
         prompt_template: str,
         repo_instruction: str | None = None,
     ) -> tuple[str, list[str]]:
-        """Generate instruction for the agent."""
+        """Generate instruction for the agent.
+
+        Args:
+            issue: The issue to generate instruction for
+            prompt_template: The prompt template to use
+            repo_instruction: The repository instruction if it exists
+        """
+
         # Format thread comments if they exist
         thread_context = ''
         if issue.thread_comments:
@@ -211,6 +234,7 @@ class IssueHandler(IssueHandlerInterface):
                 issue.thread_comments
             )
 
+        # Extract image URLs from the issue body and thread comments
         images = []
         images.extend(self._extract_image_urls(issue.body))
         images.extend(self._extract_image_urls(thread_context))
@@ -227,8 +251,14 @@ class IssueHandler(IssueHandlerInterface):
     def guess_success(
         self, issue: GithubIssue, history: list[Event]
     ) -> tuple[bool, None | list[bool], str]:
-        """Guess if the issue is fixed based on the history and the issue description."""
+        """Guess if the issue is fixed based on the history and the issue description.
+
+        Args:
+            issue: The issue to check
+            history: The agent's history
+        """
         last_message = history[-1].message
+
         # Include thread comments in the prompt if they exist
         issue_context = issue.body
         if issue.thread_comments:
@@ -236,6 +266,7 @@ class IssueHandler(IssueHandlerInterface):
                 issue.thread_comments
             )
 
+        # Prepare the prompt
         with open(
             os.path.join(
                 os.path.dirname(__file__),
@@ -246,6 +277,7 @@ class IssueHandler(IssueHandlerInterface):
             template = jinja2.Template(f.read())
         prompt = template.render(issue_context=issue_context, last_message=last_message)
 
+        # Get the LLM response and check for 'success' and 'explanation' in the answer
         response = self.llm.completion(messages=[{'role': 'user', 'content': prompt}])
 
         answer = response.choices[0].message.content.strip()
@@ -328,6 +360,7 @@ class PRHandler(IssueHandler):
 
         variables = {'owner': self.owner, 'repo': self.repo, 'pr': pull_number}
 
+        # Run the query
         url = 'https://api.github.com/graphql'
         headers = {
             'Authorization': f'Bearer {self.token}',
@@ -394,10 +427,12 @@ class PRHandler(IssueHandler):
                             review_thread['body'] + '\n'
                         )  # Add each thread in a new line
 
+                    # Source files on which the comments were made
                     file = review_thread.get('path')
                     if file and file not in files:
                         files.append(file)
 
+                # If the comment ID is not provided or the thread contains the comment ID, add the thread to the list
                 if comment_id is None or thread_contains_comment_id:
                     unresolved_thread = ReviewThread(comment=message, files=files)
                     review_threads.append(unresolved_thread)

+ 60 - 9
openhands/resolver/send_pull_request.py

@@ -5,11 +5,11 @@ import shutil
 import subprocess
 
 import jinja2
-import litellm
 import requests
 
 from openhands.core.config import LLMConfig
 from openhands.core.logger import openhands_logger as logger
+from openhands.llm.llm import LLM
 from openhands.resolver.github_issue import GithubIssue
 from openhands.resolver.io_utils import (
     load_all_resolver_outputs,
@@ -20,6 +20,12 @@ from openhands.resolver.resolver_output import ResolverOutput
 
 
 def apply_patch(repo_dir: str, patch: str) -> None:
+    """Apply a patch to a repository.
+
+    Args:
+        repo_dir: The directory containing the repository
+        patch: The patch to apply
+    """
     diffs = parse_patch(patch)
     for diff in diffs:
         if not diff.header.new_path:
@@ -112,6 +118,14 @@ def apply_patch(repo_dir: str, patch: str) -> None:
 def initialize_repo(
     output_dir: str, issue_number: int, issue_type: str, base_commit: str | None = None
 ) -> str:
+    """Initialize the repository.
+
+    Args:
+        output_dir: The output directory to write the repository to
+        issue_number: The issue number to fix
+        issue_type: The type of the issue
+        base_commit: The base commit to check out (if issue_type is pr)
+    """
     src_dir = os.path.join(output_dir, 'repo')
     dest_dir = os.path.join(output_dir, 'patches', f'{issue_type}_{issue_number}')
 
@@ -124,6 +138,7 @@ def initialize_repo(
     shutil.copytree(src_dir, dest_dir)
     print(f'Copied repository to {dest_dir}')
 
+    # Check out the base commit if provided
     if base_commit:
         result = subprocess.run(
             f'git -C {dest_dir} checkout {base_commit}',
@@ -139,6 +154,13 @@ def initialize_repo(
 
 
 def make_commit(repo_dir: str, issue: GithubIssue, issue_type: str) -> None:
+    """Make a commit with the changes to the repository.
+
+    Args:
+        repo_dir: The directory containing the repository
+        issue: The issue to fix
+        issue_type: The type of the issue
+    """
     # Check if git username is set
     result = subprocess.run(
         f'git -C {repo_dir} config user.name',
@@ -158,6 +180,7 @@ def make_commit(repo_dir: str, issue: GithubIssue, issue_type: str) -> None:
         )
         print('Git user configured as openhands')
 
+    # Add all changes to the git index
     result = subprocess.run(
         f'git -C {repo_dir} add .', shell=True, capture_output=True, text=True
     )
@@ -165,6 +188,7 @@ def make_commit(repo_dir: str, issue: GithubIssue, issue_type: str) -> None:
         print(f'Error adding files: {result.stderr}')
         raise RuntimeError('Failed to add files to git')
 
+    # Check the status of the git index
     status_result = subprocess.run(
         f'git -C {repo_dir} status --porcelain',
         shell=True,
@@ -172,11 +196,15 @@ def make_commit(repo_dir: str, issue: GithubIssue, issue_type: str) -> None:
         text=True,
     )
 
+    # If there are no changes, raise an error
     if not status_result.stdout.strip():
         print(f'No changes to commit for issue #{issue.number}. Skipping commit.')
         raise RuntimeError('ERROR: Openhands failed to make code changes.')
 
+    # Prepare the commit message
     commit_message = f'Fix {issue_type} #{issue.number}: {issue.title}'
+
+    # Commit the changes
     result = subprocess.run(
         ['git', '-C', repo_dir, 'commit', '-m', commit_message],
         capture_output=True,
@@ -206,12 +234,23 @@ def send_pull_request(
     github_token: str,
     github_username: str | None,
     patch_dir: str,
-    llm_config: LLMConfig,
     pr_type: str,
     fork_owner: str | None = None,
     additional_message: str | None = None,
     target_branch: str | None = None,
 ) -> str:
+    """Send a pull request to a GitHub repository.
+
+    Args:
+        github_issue: The issue to send the pull request for
+        github_token: The GitHub token to use for authentication
+        github_username: The GitHub username, if provided
+        patch_dir: The directory containing the patches to apply
+        pr_type: One of 'branch' (no PR is created; a compare URL is built instead), 'draft' (draft PR created) or 'ready' (regular PR created)
+        fork_owner: The owner of the fork to push changes to (if different from the original repo owner)
+        additional_message: The additional messages to post as a comment on the PR in json list format
+        target_branch: The target branch to create the pull request against (defaults to repository default branch)
+    """
     if pr_type not in ['branch', 'draft', 'ready']:
         raise ValueError(f'Invalid pr_type: {pr_type}')
 
@@ -227,6 +266,7 @@ def send_pull_request(
     branch_name = base_branch_name
     attempt = 1
 
+    # Find a unique branch name
     print('Checking if branch exists...')
     while branch_exists(base_url, branch_name, headers):
         attempt += 1
@@ -279,6 +319,7 @@ def send_pull_request(
         print(f'Error pushing changes: {result.stderr}')
         raise RuntimeError('Failed to push changes to the remote repository')
 
+    # Prepare the PR data: title and body
     pr_title = f'Fix issue #{github_issue.number}: {github_issue.title}'
     pr_body = f'This pull request fixes #{github_issue.number}.'
     if additional_message:
@@ -290,6 +331,7 @@ def send_pull_request(
     if pr_type == 'branch':
         url = f'https://github.com/{push_owner}/{github_issue.repo}/compare/{branch_name}?expand=1'
     else:
+        # Prepare the PR for the GitHub API
         data = {
             'title': pr_title,  # No need to escape title for GitHub API
             'body': pr_body,
@@ -297,6 +339,8 @@ def send_pull_request(
             'base': base_branch,
             'draft': pr_type == 'draft',
         }
+
+        # Send the PR and get its URL to tell the user
         response = requests.post(f'{base_url}/pulls', headers=headers, json=data)
         if response.status_code == 403:
             raise RuntimeError(
@@ -314,6 +358,13 @@ def send_pull_request(
 
 
 def reply_to_comment(github_token: str, comment_id: str, reply: str):
+    """Reply to a comment on a GitHub issue or pull request.
+
+    Args:
+        github_token: The GitHub token to use for authentication
+        comment_id: The ID of the comment to reply to
+        reply: The reply message to post
+    """
     # Opting for graphql as REST API doesn't allow reply to replies in comment threads
     query = """
             mutation($body: String!, $pullRequestReviewThreadId: ID!) {
@@ -327,6 +378,7 @@ def reply_to_comment(github_token: str, comment_id: str, reply: str):
             }
             """
 
+    # Prepare the reply to the comment
     comment_reply = f'Openhands fix success summary\n\n\n{reply}'
     variables = {'body': comment_reply, 'pullRequestReviewThreadId': comment_id}
     url = 'https://api.github.com/graphql'
@@ -335,6 +387,7 @@ def reply_to_comment(github_token: str, comment_id: str, reply: str):
         'Content-Type': 'application/json',
     }
 
+    # Send the reply to the comment
     response = requests.post(
         url, json={'query': query, 'variables': variables}, headers=headers
     )
@@ -392,13 +445,14 @@ def update_existing_pull_request(
     base_url = f'https://api.github.com/repos/{github_issue.owner}/{github_issue.repo}'
     branch_name = github_issue.head_branch
 
-    # Push the changes to the existing branch
+    # Prepare the push command
     push_command = (
         f'git -C {patch_dir} push '
         f'https://{github_username}:{github_token}@github.com/'
         f'{github_issue.owner}/{github_issue.repo}.git {branch_name}'
     )
 
+    # Push the changes to the existing branch
     result = subprocess.run(push_command, shell=True, capture_output=True, text=True)
     if result.returncode != 0:
         print(f'Error pushing changes: {result.stderr}')
@@ -420,6 +474,7 @@ def update_existing_pull_request(
 
                 # Summarize with LLM if provided
                 if llm_config is not None:
+                    llm = LLM(llm_config)
                     with open(
                         os.path.join(
                             os.path.dirname(__file__),
@@ -429,16 +484,13 @@ def update_existing_pull_request(
                     ) as f:
                         template = jinja2.Template(f.read())
                     prompt = template.render(comment_message=comment_message)
-                    response = litellm.completion(
-                        model=llm_config.model,
+                    response = llm.completion(
                         messages=[{'role': 'user', 'content': prompt}],
-                        api_key=llm_config.api_key,
-                        base_url=llm_config.base_url,
                     )
                     comment_message = response.choices[0].message.content.strip()
 
         except (json.JSONDecodeError, TypeError):
-            comment_message = 'New OpenHands update'
+            comment_message = f'A new OpenHands update is available, but failed to parse or summarize the changes:\n{additional_message}'
 
     # Post a comment on the PR
     if comment_message:
@@ -514,7 +566,6 @@ def process_single_issue(
             github_username=github_username,
             patch_dir=patched_repo_dir,
             pr_type=pr_type,
-            llm_config=llm_config,
             fork_owner=fork_owner,
             additional_message=resolver_output.success_explanation,
             target_branch=target_branch,

+ 9 - 9
tests/unit/resolver/test_pr_handler_guess_success.py

@@ -16,7 +16,7 @@ def mock_llm_response(content):
 
 
 def test_guess_success_review_threads_litellm_call():
-    """Test that the litellm.completion() call for review threads contains the expected content."""
+    """Test that the completion() call for review threads contains the expected content."""
     # Create a PR handler instance
     llm_config = LLMConfig(model='test', api_key='test')
     handler = PRHandler('test-owner', 'test-repo', 'test-token', llm_config)
@@ -77,7 +77,7 @@ The changes successfully address the feedback."""
         mock_completion.return_value = mock_response
         success, success_list, explanation = handler.guess_success(issue, history)
 
-        # Verify the litellm.completion() calls
+        # Verify the completion() calls
         assert mock_completion.call_count == 2  # One call per review thread
 
         # Check first call
@@ -121,7 +121,7 @@ The changes successfully address the feedback."""
 
 
 def test_guess_success_thread_comments_litellm_call():
-    """Test that the litellm.completion() call for thread comments contains the expected content."""
+    """Test that the completion() call for thread comments contains the expected content."""
     # Create a PR handler instance
     llm_config = LLMConfig(model='test', api_key='test')
     handler = PRHandler('test-owner', 'test-repo', 'test-token', llm_config)
@@ -176,7 +176,7 @@ The changes successfully address the feedback."""
         mock_completion.return_value = mock_response
         success, success_list, explanation = handler.guess_success(issue, history)
 
-        # Verify the litellm.completion() call
+        # Verify the completion() call
         mock_completion.assert_called_once()
         call_args = mock_completion.call_args
         prompt = call_args[1]['messages'][0]['content']
@@ -270,7 +270,7 @@ Changes look good"""
             review_thread, issues_context, last_message
         )
 
-        # Verify the litellm.completion() call
+        # Verify the completion() call
         mock_completion.assert_called_once()
         call_args = mock_completion.call_args
         prompt = call_args[1]['messages'][0]['content']
@@ -326,7 +326,7 @@ Changes look good"""
             thread_comments, issues_context, last_message
         )
 
-        # Verify the litellm.completion() call
+        # Verify the completion() call
         mock_completion.assert_called_once()
         call_args = mock_completion.call_args
         prompt = call_args[1]['messages'][0]['content']
@@ -379,7 +379,7 @@ Changes look good"""
             review_comments, issues_context, last_message
         )
 
-        # Verify the litellm.completion() call
+        # Verify the completion() call
         mock_completion.assert_called_once()
         call_args = mock_completion.call_args
         prompt = call_args[1]['messages'][0]['content']
@@ -395,7 +395,7 @@ Changes look good"""
 
 
 def test_guess_success_review_comments_litellm_call():
-    """Test that the litellm.completion() call for review comments contains the expected content."""
+    """Test that the completion() call for review comments contains the expected content."""
     # Create a PR handler instance
     llm_config = LLMConfig(model='test', api_key='test')
     handler = PRHandler('test-owner', 'test-repo', 'test-token', llm_config)
@@ -447,7 +447,7 @@ The changes successfully address the feedback."""
         mock_completion.return_value = mock_response
         success, success_list, explanation = handler.guess_success(issue, history)
 
-        # Verify the litellm.completion() call
+        # Verify the completion() call
         mock_completion.assert_called_once()
         call_args = mock_completion.call_args
         prompt = call_args[1]['messages'][0]['content']

+ 0 - 2
tests/unit/resolver/test_pr_title_escaping.py

@@ -153,7 +153,6 @@ def test_pr_title_with_quotes(monkeypatch):
 
         # Try to send a PR - this will fail if the title is incorrectly escaped
         print('Sending PR...')
-        from openhands.core.config import LLMConfig
         from openhands.resolver.send_pull_request import send_pull_request
 
         send_pull_request(
@@ -161,6 +160,5 @@ def test_pr_title_with_quotes(monkeypatch):
             github_token='dummy-token',
             github_username='test-user',
             patch_dir=temp_dir,
-            llm_config=LLMConfig(model='test-model', api_key='test-key'),
             pr_type='ready',
         )

+ 24 - 22
tests/unit/resolver/test_send_pull_request.py

@@ -244,8 +244,12 @@ def test_initialize_repo(mock_output_dir):
 @patch('openhands.resolver.send_pull_request.reply_to_comment')
 @patch('requests.post')
 @patch('subprocess.run')
+@patch('openhands.resolver.send_pull_request.LLM')
 def test_update_existing_pull_request(
-    mock_subprocess_run, mock_requests_post, mock_reply_to_comment
+    mock_llm_class,
+    mock_subprocess_run,
+    mock_requests_post,
+    mock_reply_to_comment,
 ):
     # Arrange: Set up test data
     github_issue = GithubIssue(
@@ -267,23 +271,28 @@ def test_update_existing_pull_request(
 
     # Mock the requests.post call for adding a PR comment
     mock_requests_post.return_value.status_code = 201
+
+    # Mock LLM instance and completion call
+    mock_llm_instance = MagicMock()
     mock_completion_response = MagicMock()
     mock_completion_response.choices = [
         MagicMock(message=MagicMock(content='This is an issue resolution.'))
     ]
+    mock_llm_instance.completion.return_value = mock_completion_response
+    mock_llm_class.return_value = mock_llm_instance
+
     llm_config = LLMConfig()
 
     # Act: Call the function without comment_message to test auto-generation
-    with patch('litellm.completion', MagicMock(return_value=mock_completion_response)):
-        result = update_existing_pull_request(
-            github_issue,
-            github_token,
-            github_username,
-            patch_dir,
-            llm_config,
-            comment_message=None,
-            additional_message=additional_message,
-        )
+    result = update_existing_pull_request(
+        github_issue,
+        github_token,
+        github_username,
+        patch_dir,
+        llm_config,
+        comment_message=None,
+        additional_message=additional_message,
+    )
 
     # Assert: Check if the git push command was executed
     push_command = (
@@ -342,7 +351,6 @@ def test_send_pull_request(
     mock_run,
     mock_github_issue,
     mock_output_dir,
-    mock_llm_config,
     pr_type,
     target_branch,
 ):
@@ -377,7 +385,6 @@ def test_send_pull_request(
         github_username='test-user',
         patch_dir=repo_path,
         pr_type=pr_type,
-        llm_config=mock_llm_config,
         target_branch=target_branch,
     )
 
@@ -427,7 +434,7 @@ def test_send_pull_request(
 
 @patch('requests.get')
 def test_send_pull_request_invalid_target_branch(
-    mock_get, mock_github_issue, mock_output_dir, mock_llm_config
+    mock_get, mock_github_issue, mock_output_dir
 ):
     """Test that an error is raised when specifying a non-existent target branch"""
     repo_path = os.path.join(mock_output_dir, 'repo')
@@ -448,7 +455,6 @@ def test_send_pull_request_invalid_target_branch(
             github_username='test-user',
             patch_dir=repo_path,
             pr_type='ready',
-            llm_config=mock_llm_config,
             target_branch='nonexistent-branch',
         )
 
@@ -460,7 +466,7 @@ def test_send_pull_request_invalid_target_branch(
 @patch('requests.post')
 @patch('requests.get')
 def test_send_pull_request_git_push_failure(
-    mock_get, mock_post, mock_run, mock_github_issue, mock_output_dir, mock_llm_config
+    mock_get, mock_post, mock_run, mock_github_issue, mock_output_dir
 ):
     repo_path = os.path.join(mock_output_dir, 'repo')
 
@@ -483,7 +489,6 @@ def test_send_pull_request_git_push_failure(
             github_username='test-user',
             patch_dir=repo_path,
             pr_type='ready',
-            llm_config=mock_llm_config,
         )
 
     # Assert that subprocess.run was called twice
@@ -519,7 +524,7 @@ def test_send_pull_request_git_push_failure(
 @patch('requests.post')
 @patch('requests.get')
 def test_send_pull_request_permission_error(
-    mock_get, mock_post, mock_run, mock_github_issue, mock_output_dir, mock_llm_config
+    mock_get, mock_post, mock_run, mock_github_issue, mock_output_dir
 ):
     repo_path = os.path.join(mock_output_dir, 'repo')
 
@@ -543,7 +548,6 @@ def test_send_pull_request_permission_error(
             github_username='test-user',
             patch_dir=repo_path,
             pr_type='ready',
-            llm_config=mock_llm_config,
         )
 
     # Assert that the branch was created and pushed
@@ -757,7 +761,6 @@ def test_process_single_issue(
         github_username=github_username,
         patch_dir=f'{mock_output_dir}/patches/issue_1',
         pr_type=pr_type,
-        llm_config=mock_llm_config,
         fork_owner=None,
         additional_message=resolver_output.success_explanation,
         target_branch=None,
@@ -940,7 +943,7 @@ def test_process_all_successful_issues(
 @patch('requests.get')
 @patch('subprocess.run')
 def test_send_pull_request_branch_naming(
-    mock_run, mock_get, mock_github_issue, mock_output_dir, mock_llm_config
+    mock_run, mock_get, mock_github_issue, mock_output_dir
 ):
     repo_path = os.path.join(mock_output_dir, 'repo')
 
@@ -965,7 +968,6 @@ def test_send_pull_request_branch_naming(
         github_username='test-user',
         patch_dir=repo_path,
         pr_type='branch',
-        llm_config=mock_llm_config,
     )
 
     # Assert API calls