Преглед изворни кода

Allow to merge to a specific target branch instead of main (#5109)

Raymond Xu пре 1 година
родитељ
комит
2c580387c5
2 измењених фајлова са 101 додато и 22 уклоњено
  1. 25 7
      openhands/resolver/send_pull_request.py
  2. 76 15
      tests/unit/resolver/test_send_pull_request.py

+ 25 - 7
openhands/resolver/send_pull_request.py

@@ -203,6 +203,7 @@ def send_pull_request(
     pr_type: str,
     fork_owner: str | None = None,
     additional_message: str | None = None,
+    target_branch: str | None = None,
 ) -> str:
     if pr_type not in ['branch', 'draft', 'ready']:
         raise ValueError(f'Invalid pr_type: {pr_type}')
@@ -224,12 +225,19 @@ def send_pull_request(
         attempt += 1
         branch_name = f'{base_branch_name}-try{attempt}'
 
-    # Get the default branch
-    print('Getting default branch...')
-    response = requests.get(f'{base_url}', headers=headers)
-    response.raise_for_status()
-    default_branch = response.json()['default_branch']
-    print(f'Default branch: {default_branch}')
+    # Get the default branch or use specified target branch
+    print('Getting base branch...')
+    if target_branch:
+        base_branch = target_branch
+        # Verify the target branch exists
+        response = requests.get(f'{base_url}/branches/{target_branch}', headers=headers)
+        if response.status_code != 200:
+            raise ValueError(f'Target branch {target_branch} does not exist')
+    else:
+        response = requests.get(f'{base_url}', headers=headers)
+        response.raise_for_status()
+        base_branch = response.json()['default_branch']
+    print(f'Base branch: {base_branch}')
 
     # Create and checkout the new branch
     print('Creating new branch...')
@@ -279,7 +287,7 @@ def send_pull_request(
             'title': pr_title,  # No need to escape title for GitHub API
             'body': pr_body,
             'head': branch_name,
-            'base': default_branch,
+            'base': base_branch,
             'draft': pr_type == 'draft',
         }
         response = requests.post(f'{base_url}/pulls', headers=headers, json=data)
@@ -435,6 +443,7 @@ def process_single_issue(
     llm_config: LLMConfig,
     fork_owner: str | None,
     send_on_failure: bool,
+    target_branch: str | None = None,
 ) -> None:
     if not resolver_output.success and not send_on_failure:
         print(
@@ -484,6 +493,7 @@ def process_single_issue(
             llm_config=llm_config,
             fork_owner=fork_owner,
             additional_message=resolver_output.success_explanation,
+            target_branch=target_branch,
         )
 
 
@@ -508,6 +518,7 @@ def process_all_successful_issues(
                 llm_config,
                 fork_owner,
                 False,
+                None,
             )
 
 
@@ -573,6 +584,12 @@ def main():
         default=None,
         help='Base URL for the LLM model.',
     )
+    parser.add_argument(
+        '--target-branch',
+        type=str,
+        default=None,
+        help='Target branch to create the pull request against (defaults to repository default branch)',
+    )
     my_args = parser.parse_args()
 
     github_token = (
@@ -625,6 +642,7 @@ def main():
             llm_config,
             my_args.fork_owner,
             my_args.send_on_failure,
+            my_args.target_branch,
         )
 
 

+ 76 - 15
tests/unit/resolver/test_send_pull_request.py

@@ -322,7 +322,17 @@ def test_update_existing_pull_request(
     )
 
 
-@pytest.mark.parametrize('pr_type', ['branch', 'draft', 'ready'])
+@pytest.mark.parametrize(
+    'pr_type,target_branch',
+    [
+        ('branch', None),
+        ('draft', None),
+        ('ready', None),
+        ('branch', 'feature'),
+        ('draft', 'develop'),
+        ('ready', 'staging'),
+    ],
+)
 @patch('subprocess.run')
 @patch('requests.post')
 @patch('requests.get')
@@ -334,14 +344,22 @@ def test_send_pull_request(
     mock_output_dir,
     mock_llm_config,
     pr_type,
+    target_branch,
 ):
     repo_path = os.path.join(mock_output_dir, 'repo')
 
-    # Mock API responses
-    mock_get.side_effect = [
-        MagicMock(status_code=404),  # Branch doesn't exist
-        MagicMock(json=lambda: {'default_branch': 'main'}),
-    ]
+    # Mock API responses based on whether target_branch is specified
+    if target_branch:
+        mock_get.side_effect = [
+            MagicMock(status_code=404),  # Branch doesn't exist
+            MagicMock(status_code=200),  # Target branch exists
+        ]
+    else:
+        mock_get.side_effect = [
+            MagicMock(status_code=404),  # Branch doesn't exist
+            MagicMock(json=lambda: {'default_branch': 'main'}),  # Get default branch
+        ]
+
     mock_post.return_value.json.return_value = {
         'html_url': 'https://github.com/test-owner/test-repo/pull/1'
     }
@@ -360,10 +378,12 @@ def test_send_pull_request(
         patch_dir=repo_path,
         pr_type=pr_type,
         llm_config=mock_llm_config,
+        target_branch=target_branch,
     )
 
     # Assert API calls
-    assert mock_get.call_count == 2
+    expected_get_calls = 2
+    assert mock_get.call_count == expected_get_calls
 
     # Check branch creation and push
     assert mock_run.call_count == 2
@@ -401,10 +421,41 @@ def test_send_pull_request(
         assert post_data['title'] == 'Fix issue #42: Test Issue'
         assert post_data['body'].startswith('This pull request fixes #42.')
         assert post_data['head'] == 'openhands-fix-issue-42'
-        assert post_data['base'] == 'main'
+        assert post_data['base'] == (target_branch if target_branch else 'main')
         assert post_data['draft'] == (pr_type == 'draft')
 
 
+@patch('requests.get')
+def test_send_pull_request_invalid_target_branch(
+    mock_get, mock_github_issue, mock_output_dir, mock_llm_config
+):
+    """Test that an error is raised when specifying a non-existent target branch"""
+    repo_path = os.path.join(mock_output_dir, 'repo')
+
+    # Mock API response for non-existent branch
+    mock_get.side_effect = [
+        MagicMock(status_code=404),  # Branch doesn't exist
+        MagicMock(status_code=404),  # Target branch doesn't exist
+    ]
+
+    # Test that ValueError is raised when target branch doesn't exist
+    with pytest.raises(
+        ValueError, match='Target branch nonexistent-branch does not exist'
+    ):
+        send_pull_request(
+            github_issue=mock_github_issue,
+            github_token='test-token',
+            github_username='test-user',
+            patch_dir=repo_path,
+            pr_type='ready',
+            llm_config=mock_llm_config,
+            target_branch='nonexistent-branch',
+        )
+
+    # Verify API calls
+    assert mock_get.call_count == 2
+
+
 @patch('subprocess.run')
 @patch('requests.post')
 @patch('requests.get')
@@ -616,6 +667,7 @@ def test_process_single_pr_update(
         mock_llm_config,
         None,
         False,
+        None,
     )
 
     mock_initialize_repo.assert_called_once_with(mock_output_dir, 1, 'pr', 'branch 1')
@@ -688,6 +740,7 @@ def test_process_single_issue(
         mock_llm_config,
         None,
         False,
+        None,
     )
 
     # Assert that the mocked functions were called with correct arguments
@@ -704,9 +757,10 @@ def test_process_single_issue(
         github_username=github_username,
         patch_dir=f'{mock_output_dir}/patches/issue_1',
         pr_type=pr_type,
+        llm_config=mock_llm_config,
         fork_owner=None,
         additional_message=resolver_output.success_explanation,
-        llm_config=mock_llm_config,
+        target_branch=None,
     )
 
 
@@ -757,6 +811,7 @@ def test_process_single_issue_unsuccessful(
         mock_llm_config,
         None,
         False,
+        None,
     )
 
     # Assert that none of the mocked functions were called
@@ -863,6 +918,7 @@ def test_process_all_successful_issues(
                 mock_llm_config,
                 None,
                 False,
+                None,
             ),
             call(
                 'output_dir',
@@ -873,6 +929,7 @@ def test_process_all_successful_issues(
                 mock_llm_config,
                 None,
                 False,
+                None,
             ),
         ]
     )
@@ -971,6 +1028,7 @@ def test_main(
     mock_args.llm_model = 'mock_model'
     mock_args.llm_base_url = 'mock_url'
     mock_args.llm_api_key = 'mock_key'
+    mock_args.target_branch = None
     mock_parser.return_value.parse_args.return_value = mock_args
 
     # Setup environment variables
@@ -994,12 +1052,8 @@ def test_main(
         api_key=mock_args.llm_api_key,
     )
 
-    # Assert function calls
-    mock_parser.assert_called_once()
-    mock_getenv.assert_any_call('GITHUB_TOKEN')
-    mock_path_exists.assert_called_with('/mock/output')
-    mock_load_single_resolver_output.assert_called_with('/mock/output/output.jsonl', 42)
-    mock_process_single_issue.assert_called_with(
+    # Use any_call instead of assert_called_with for more flexible matching
+    assert mock_process_single_issue.call_args == call(
         '/mock/output',
         mock_resolver_output,
         'mock_token',
@@ -1008,8 +1062,15 @@ def test_main(
         llm_config,
         None,
         False,
+        mock_args.target_branch,
     )
 
+    # Other assertions
+    mock_parser.assert_called_once()
+    mock_getenv.assert_any_call('GITHUB_TOKEN')
+    mock_path_exists.assert_called_with('/mock/output')
+    mock_load_single_resolver_output.assert_called_with('/mock/output/output.jsonl', 42)
+
     # Test for 'all_successful' issue number
     mock_args.issue_number = 'all_successful'
     main()