瀏覽代碼

feat: add support for custom PR titles (#5706)

Co-authored-by: David Walsh <walsha@gmail.com>
d-walsh 1 年之前
父節點
當前提交
5ad361623d
共有 2 個文件被更改,包括 37 次插入11 次删除
  1. 21 3
      openhands/resolver/send_pull_request.py
  2. 16 8
      tests/unit/resolver/test_send_pull_request.py

+ 21 - 3
openhands/resolver/send_pull_request.py

@@ -239,6 +239,7 @@ def send_pull_request(
     additional_message: str | None = None,
     target_branch: str | None = None,
     reviewer: str | None = None,
+    pr_title: str | None = None,
 ) -> str:
     """Send a pull request to a GitHub repository.
 
@@ -251,6 +252,8 @@ def send_pull_request(
         fork_owner: The owner of the fork to push changes to (if different from the original repo owner)
         additional_message: The additional messages to post as a comment on the PR in json list format
         target_branch: The target branch to create the pull request against (defaults to repository default branch)
+        reviewer: The GitHub username of the reviewer to assign
+        pr_title: Custom title for the pull request (optional)
     """
     if pr_type not in ['branch', 'draft', 'ready']:
         raise ValueError(f'Invalid pr_type: {pr_type}')
@@ -321,7 +324,11 @@ def send_pull_request(
         raise RuntimeError('Failed to push changes to the remote repository')
 
     # Prepare the PR data: title and body
-    pr_title = f'Fix issue #{github_issue.number}: {github_issue.title}'
+    final_pr_title = (
+        pr_title
+        if pr_title
+        else f'Fix issue #{github_issue.number}: {github_issue.title}'
+    )
     pr_body = f'This pull request fixes #{github_issue.number}.'
     if additional_message:
         pr_body += f'\n\n{additional_message}'
@@ -334,7 +341,7 @@ def send_pull_request(
     else:
         # Prepare the PR for the GitHub API
         data = {
-            'title': pr_title,  # No need to escape title for GitHub API
+            'title': final_pr_title,  # No need to escape title for GitHub API
             'body': pr_body,
             'head': branch_name,
             'base': base_branch,
@@ -366,7 +373,9 @@ def send_pull_request(
 
         url = pr_data['html_url']
 
-    print(f'{pr_type} created: {url}\n\n--- Title: {pr_title}\n\n--- Body:\n{pr_body}')
+    print(
+        f'{pr_type} created: {url}\n\n--- Title: {final_pr_title}\n\n--- Body:\n{pr_body}'
+    )
 
     return url
 
@@ -535,6 +544,7 @@ def process_single_issue(
     send_on_failure: bool,
     target_branch: str | None = None,
     reviewer: str | None = None,
+    pr_title: str | None = None,
 ) -> None:
     if not resolver_output.success and not send_on_failure:
         print(
@@ -585,6 +595,7 @@ def process_single_issue(
             additional_message=resolver_output.success_explanation,
             target_branch=target_branch,
             reviewer=reviewer,
+            pr_title=pr_title,
         )
 
 
@@ -687,6 +698,12 @@ def main():
         help='GitHub username of the person to request review from',
         default=None,
     )
+    parser.add_argument(
+        '--pr-title',
+        type=str,
+        help='Custom title for the pull request',
+        default=None,
+    )
     my_args = parser.parse_args()
 
     github_token = (
@@ -741,6 +758,7 @@ def main():
             my_args.send_on_failure,
             my_args.target_branch,
             my_args.reviewer,
+            my_args.pr_title,
         )
 
 

+ 16 - 8
tests/unit/resolver/test_send_pull_request.py

@@ -332,14 +332,16 @@ def test_update_existing_pull_request(
 
 
 @pytest.mark.parametrize(
-    'pr_type,target_branch',
+    'pr_type,target_branch,pr_title',
     [
-        ('branch', None),
-        ('draft', None),
-        ('ready', None),
-        ('branch', 'feature'),
-        ('draft', 'develop'),
-        ('ready', 'staging'),
+        ('branch', None, None),
+        ('draft', None, None),
+        ('ready', None, None),
+        ('branch', 'feature', None),
+        ('draft', 'develop', None),
+        ('ready', 'staging', None),
+        ('ready', None, 'Custom PR Title'),
+        ('draft', 'develop', 'Another Custom Title'),
     ],
 )
 @patch('subprocess.run')
@@ -353,6 +355,7 @@ def test_send_pull_request(
     mock_output_dir,
     pr_type,
     target_branch,
+    pr_title,
 ):
     repo_path = os.path.join(mock_output_dir, 'repo')
 
@@ -386,6 +389,7 @@ def test_send_pull_request(
         patch_dir=repo_path,
         pr_type=pr_type,
         target_branch=target_branch,
+        pr_title=pr_title,
     )
 
     # Assert API calls
@@ -425,7 +429,8 @@ def test_send_pull_request(
         assert result == 'https://github.com/test-owner/test-repo/pull/1'
         mock_post.assert_called_once()
         post_data = mock_post.call_args[1]['json']
-        assert post_data['title'] == 'Fix issue #42: Test Issue'
+        expected_title = pr_title if pr_title else 'Fix issue #42: Test Issue'
+        assert post_data['title'] == expected_title
         assert post_data['body'].startswith('This pull request fixes #42.')
         assert post_data['head'] == 'openhands-fix-issue-42'
         assert post_data['base'] == (target_branch if target_branch else 'main')
@@ -828,6 +833,7 @@ def test_process_single_issue(
         additional_message=resolver_output.success_explanation,
         target_branch=None,
         reviewer=None,
+        pr_title=None,
     )
 
 
@@ -1096,6 +1102,7 @@ def test_main(
     mock_args.llm_api_key = 'mock_key'
     mock_args.target_branch = None
     mock_args.reviewer = None
+    mock_args.pr_title = None
     mock_parser.return_value.parse_args.return_value = mock_args
 
     # Setup environment variables
@@ -1131,6 +1138,7 @@ def test_main(
         False,
         mock_args.target_branch,
         mock_args.reviewer,
+        mock_args.pr_title,
     )
 
     # Other assertions