From 279e1d7abc9e27cf80de86ef866cf9cc70dae0d6 Mon Sep 17 00:00:00 2001 From: Engel Nyst Date: Sun, 8 Dec 2024 18:34:01 +0100 Subject: [PATCH] Resolver minor tweaks (#5461) --- openhands/resolver/issue_definitions.py | 41 ++++++++++- openhands/resolver/send_pull_request.py | 69 ++++++++++++++++--- .../resolver/test_pr_handler_guess_success.py | 18 ++--- tests/unit/resolver/test_pr_title_escaping.py | 2 - tests/unit/resolver/test_send_pull_request.py | 46 +++++++------ 5 files changed, 131 insertions(+), 45 deletions(-) diff --git a/openhands/resolver/issue_definitions.py b/openhands/resolver/issue_definitions.py index 52c552b10227..f36bb7a61527 100644 --- a/openhands/resolver/issue_definitions.py +++ b/openhands/resolver/issue_definitions.py @@ -62,19 +62,23 @@ def _download_issues_from_github(self) -> list[Any]: params: dict[str, int | str] = {'state': 'open', 'per_page': 100, 'page': 1} all_issues = [] + # Get issues, page by page while True: response = requests.get(url, headers=headers, params=params) response.raise_for_status() issues = response.json() + # No more issues, break the loop if not issues: break + # Sanity check - the response is a list of dictionaries if not isinstance(issues, list) or any( [not isinstance(issue, dict) for issue in issues] ): raise ValueError('Expected list of dictionaries from Github API.') + # Add the issues to the final list all_issues.extend(issues) assert isinstance(params['page'], int) params['page'] += 1 @@ -107,7 +111,12 @@ def _extract_issue_references(self, body: str) -> list[int]: def _get_issue_comments( self, issue_number: int, comment_id: int | None = None ) -> list[str] | None: - """Download comments for a specific issue from Github.""" + """Retrieve comments for a specific issue from Github. + + Args: + issue_number: The ID of the issue to get comments for + comment_id: The ID of a single comment, if provided, otherwise all comments + """ url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{issue_number}/comments' headers = { 'Authorization': f'token {self.token}', @@ -116,6 +125,7 @@ def _get_issue_comments( params = {'per_page': 100, 'page': 1} all_comments = [] + # Get comments, page by page while True: response = requests.get(url, headers=headers, params=params) response.raise_for_status() @@ -124,6 +134,7 @@ def _get_issue_comments( if not comments: break + # If a single comment ID is provided, return only that comment if comment_id: matching_comment = next( ( @@ -136,6 +147,7 @@ def _get_issue_comments( if matching_comment: return [matching_comment] else: + # Otherwise, return all comments all_comments.extend([comment['body'] for comment in comments]) params['page'] += 1 @@ -147,6 +159,10 @@ def get_converted_issues( ) -> list[GithubIssue]: """Download issues from Github. + Args: + issue_numbers: The numbers of the issues to download + comment_id: The ID of a single comment, if provided, otherwise all comments + Returns: List of Github issues. """ @@ -203,7 +219,14 @@ def get_instruction( prompt_template: str, repo_instruction: str | None = None, ) -> tuple[str, list[str]]: - """Generate instruction for the agent.""" + """Generate instruction for the agent. + + Args: + issue: The issue to generate instruction for + prompt_template: The prompt template to use + repo_instruction: The repository instruction if it exists + """ + # Format thread comments if they exist thread_context = '' if issue.thread_comments: @@ -211,6 +234,7 @@ def get_instruction( issue.thread_comments ) + # Extract image URLs from the issue body and thread comments images = [] images.extend(self._extract_image_urls(issue.body)) images.extend(self._extract_image_urls(thread_context)) @@ -227,8 +251,14 @@ def get_instruction( def guess_success( self, issue: GithubIssue, history: list[Event] ) -> tuple[bool, None | list[bool], str]: - """Guess if the issue is fixed based on the history and the issue description.""" + """Guess if the issue is fixed based on the history and the issue description. + + Args: + issue: The issue to check + history: The agent's history + """ last_message = history[-1].message + # Include thread comments in the prompt if they exist issue_context = issue.body if issue.thread_comments: @@ -236,6 +266,7 @@ def guess_success( issue.thread_comments ) + # Prepare the prompt with open( os.path.join( os.path.dirname(__file__), @@ -246,6 +277,7 @@ def guess_success( template = jinja2.Template(f.read()) prompt = template.render(issue_context=issue_context, last_message=last_message) + # Get the LLM response and check for 'success' and 'explanation' in the answer response = self.llm.completion(messages=[{'role': 'user', 'content': prompt}]) answer = response.choices[0].message.content.strip() @@ -328,6 +360,7 @@ def __download_pr_metadata( variables = {'owner': self.owner, 'repo': self.repo, 'pr': pull_number} + # Run the query url = 'https://api.github.com/graphql' headers = { 'Authorization': f'Bearer {self.token}', @@ -394,10 +427,12 @@ def __download_pr_metadata( review_thread['body'] + '\n' ) # Add each thread in a new line + # Source files on which the comments were made file = review_thread.get('path') if file and file not in files: files.append(file) + # If the comment ID is not provided or the thread contains the comment ID, add the thread to the list if comment_id is None or thread_contains_comment_id: unresolved_thread = ReviewThread(comment=message, files=files) review_threads.append(unresolved_thread) diff --git a/openhands/resolver/send_pull_request.py b/openhands/resolver/send_pull_request.py index 29f720160ba7..9cfe0ce8d32c 100644 --- a/openhands/resolver/send_pull_request.py +++ b/openhands/resolver/send_pull_request.py @@ -5,11 +5,11 @@ import subprocess import jinja2 -import litellm import requests from openhands.core.config import LLMConfig from openhands.core.logger import openhands_logger as logger +from openhands.llm.llm import LLM from openhands.resolver.github_issue import GithubIssue from openhands.resolver.io_utils import ( load_all_resolver_outputs, @@ -20,6 +20,12 @@ def apply_patch(repo_dir: str, patch: str) -> None: + """Apply a patch to a repository. + + Args: + repo_dir: The directory containing the repository + patch: The patch to apply + """ diffs = parse_patch(patch) for diff in diffs: if not diff.header.new_path: @@ -112,6 +118,14 @@ def apply_patch(repo_dir: str, patch: str) -> None: def initialize_repo( output_dir: str, issue_number: int, issue_type: str, base_commit: str | None = None ) -> str: + """Initialize the repository. + + Args: + output_dir: The output directory to write the repository to + issue_number: The issue number to fix + issue_type: The type of the issue + base_commit: The base commit to checkout (if issue_type is pr) + """ src_dir = os.path.join(output_dir, 'repo') dest_dir = os.path.join(output_dir, 'patches', f'{issue_type}_{issue_number}') @@ -124,6 +138,7 @@ def initialize_repo( shutil.copytree(src_dir, dest_dir) print(f'Copied repository to {dest_dir}') + # Checkout the base commit if provided if base_commit: result = subprocess.run( f'git -C {dest_dir} checkout {base_commit}', @@ -139,6 +154,13 @@ def initialize_repo( def make_commit(repo_dir: str, issue: GithubIssue, issue_type: str) -> None: + """Make a commit with the changes to the repository. + + Args: + repo_dir: The directory containing the repository + issue: The issue to fix + issue_type: The type of the issue + """ # Check if git username is set result = subprocess.run( f'git -C {repo_dir} config user.name', @@ -158,6 +180,7 @@ def make_commit(repo_dir: str, issue: GithubIssue, issue_type: str) -> None: ) print('Git user configured as openhands') + # Add all changes to the git index result = subprocess.run( f'git -C {repo_dir} add .', shell=True, capture_output=True, text=True ) @@ -165,6 +188,7 @@ def make_commit(repo_dir: str, issue: GithubIssue, issue_type: str) -> None: print(f'Error adding files: {result.stderr}') raise RuntimeError('Failed to add files to git') + # Check the status of the git index status_result = subprocess.run( f'git -C {repo_dir} status --porcelain', shell=True, @@ -172,11 +196,15 @@ def make_commit(repo_dir: str, issue: GithubIssue, issue_type: str) -> None: text=True, ) + # If there are no changes, raise an error if not status_result.stdout.strip(): print(f'No changes to commit for issue #{issue.number}. Skipping commit.') raise RuntimeError('ERROR: Openhands failed to make code changes.') + # Prepare the commit message commit_message = f'Fix {issue_type} #{issue.number}: {issue.title}' + + # Commit the changes result = subprocess.run( ['git', '-C', repo_dir, 'commit', '-m', commit_message], capture_output=True, @@ -206,12 +234,23 @@ def send_pull_request( github_token: str, github_username: str | None, patch_dir: str, - llm_config: LLMConfig, pr_type: str, fork_owner: str | None = None, additional_message: str | None = None, target_branch: str | None = None, ) -> str: + """Send a pull request to a GitHub repository. + + Args: + github_issue: The issue to send the pull request for + github_token: The GitHub token to use for authentication + github_username: The GitHub username, if provided + patch_dir: The directory containing the patches to apply + pr_type: The type: branch (no PR created), draft or ready (regular PR created) + fork_owner: The owner of the fork to push changes to (if different from the original repo owner) + additional_message: The additional messages to post as a comment on the PR in json list format + target_branch: The target branch to create the pull request against (defaults to repository default branch) + """ if pr_type not in ['branch', 'draft', 'ready']: raise ValueError(f'Invalid pr_type: {pr_type}') @@ -227,6 +266,7 @@ def send_pull_request( branch_name = base_branch_name attempt = 1 + # Find a unique branch name print('Checking if branch exists...') while branch_exists(base_url, branch_name, headers): attempt += 1 @@ -279,6 +319,7 @@ def send_pull_request( print(f'Error pushing changes: {result.stderr}') raise RuntimeError('Failed to push changes to the remote repository') + # Prepare the PR data: title and body pr_title = f'Fix issue #{github_issue.number}: {github_issue.title}' pr_body = f'This pull request fixes #{github_issue.number}.' if additional_message: @@ -290,6 +331,7 @@ def send_pull_request( if pr_type == 'branch': url = f'https://github.com/{push_owner}/{github_issue.repo}/compare/{branch_name}?expand=1' else: + # Prepare the PR for the GitHub API data = { 'title': pr_title, # No need to escape title for GitHub API 'body': pr_body, @@ -297,6 +339,8 @@ def send_pull_request( 'base': base_branch, 'draft': pr_type == 'draft', } + + # Send the PR and get its URL to tell the user response = requests.post(f'{base_url}/pulls', headers=headers, json=data) if response.status_code == 403: raise RuntimeError( @@ -314,6 +358,13 @@ def send_pull_request( def reply_to_comment(github_token: str, comment_id: str, reply: str): + """Reply to a comment on a GitHub issue or pull request. + + Args: + github_token: The GitHub token to use for authentication + comment_id: The ID of the comment to reply to + reply: The reply message to post + """ # Opting for graphql as REST API doesn't allow reply to replies in comment threads query = """ mutation($body: String!, $pullRequestReviewThreadId: ID!) { @@ -327,6 +378,7 @@ def reply_to_comment(github_token: str, comment_id: str, reply: str): } """ + # Prepare the reply to the comment comment_reply = f'Openhands fix success summary\n\n\n{reply}' variables = {'body': comment_reply, 'pullRequestReviewThreadId': comment_id} url = 'https://api.github.com/graphql' @@ -335,6 +387,7 @@ def reply_to_comment(github_token: str, comment_id: str, reply: str): 'Content-Type': 'application/json', } + # Send the reply to the comment response = requests.post( url, json={'query': query, 'variables': variables}, headers=headers ) @@ -392,13 +445,14 @@ def update_existing_pull_request( base_url = f'https://api.github.com/repos/{github_issue.owner}/{github_issue.repo}' branch_name = github_issue.head_branch - # Push the changes to the existing branch + # Prepare the push command push_command = ( f'git -C {patch_dir} push ' f'https://{github_username}:{github_token}@github.com/' f'{github_issue.owner}/{github_issue.repo}.git {branch_name}' ) + # Push the changes to the existing branch result = subprocess.run(push_command, shell=True, capture_output=True, text=True) if result.returncode != 0: print(f'Error pushing changes: {result.stderr}') @@ -420,6 +474,7 @@ def update_existing_pull_request( # Summarize with LLM if provided if llm_config is not None: + llm = LLM(llm_config) with open( os.path.join( os.path.dirname(__file__), @@ -429,16 +484,13 @@ def update_existing_pull_request( ) as f: template = jinja2.Template(f.read()) prompt = template.render(comment_message=comment_message) - response = litellm.completion( - model=llm_config.model, + response = llm.completion( messages=[{'role': 'user', 'content': prompt}], - api_key=llm_config.api_key, - base_url=llm_config.base_url, ) comment_message = response.choices[0].message.content.strip() except (json.JSONDecodeError, TypeError): - comment_message = 'New OpenHands update' + comment_message = f'A new OpenHands update is available, but failed to parse or summarize the changes:\n{additional_message}' # Post a comment on the PR if comment_message: @@ -514,7 +566,6 @@ def process_single_issue( github_username=github_username, patch_dir=patched_repo_dir, pr_type=pr_type, - llm_config=llm_config, fork_owner=fork_owner, additional_message=resolver_output.success_explanation, target_branch=target_branch, diff --git a/tests/unit/resolver/test_pr_handler_guess_success.py b/tests/unit/resolver/test_pr_handler_guess_success.py index 58cbd89c8180..af7f22e3ab9c 100644 --- a/tests/unit/resolver/test_pr_handler_guess_success.py +++ b/tests/unit/resolver/test_pr_handler_guess_success.py @@ -16,7 +16,7 @@ def mock_llm_response(content): def test_guess_success_review_threads_litellm_call(): - """Test that the litellm.completion() call for review threads contains the expected content.""" + """Test that the completion() call for review threads contains the expected content.""" # Create a PR handler instance llm_config = LLMConfig(model='test', api_key='test') handler = PRHandler('test-owner', 'test-repo', 'test-token', llm_config) @@ -77,7 +77,7 @@ def test_guess_success_review_threads_litellm_call(): mock_completion.return_value = mock_response success, success_list, explanation = handler.guess_success(issue, history) - # Verify the litellm.completion() calls + # Verify the completion() calls assert mock_completion.call_count == 2 # One call per review thread # Check first call @@ -121,7 +121,7 @@ def test_guess_success_review_threads_litellm_call(): def test_guess_success_thread_comments_litellm_call(): - """Test that the litellm.completion() call for thread comments contains the expected content.""" + """Test that the completion() call for thread comments contains the expected content.""" # Create a PR handler instance llm_config = LLMConfig(model='test', api_key='test') handler = PRHandler('test-owner', 'test-repo', 'test-token', llm_config) @@ -176,7 +176,7 @@ def test_guess_success_thread_comments_litellm_call(): mock_completion.return_value = mock_response success, success_list, explanation = handler.guess_success(issue, history) - # Verify the litellm.completion() call + # Verify the completion() call mock_completion.assert_called_once() call_args = mock_completion.call_args prompt = call_args[1]['messages'][0]['content'] @@ -270,7 +270,7 @@ def test_check_review_thread(): review_thread, issues_context, last_message ) - # Verify the litellm.completion() call + # Verify the completion() call mock_completion.assert_called_once() call_args = mock_completion.call_args prompt = call_args[1]['messages'][0]['content'] @@ -326,7 +326,7 @@ def test_check_thread_comments(): thread_comments, issues_context, last_message ) - # Verify the litellm.completion() call + # Verify the completion() call mock_completion.assert_called_once() call_args = mock_completion.call_args prompt = call_args[1]['messages'][0]['content'] @@ -379,7 +379,7 @@ def test_check_review_comments(): review_comments, issues_context, last_message ) - # Verify the litellm.completion() call + # Verify the completion() call mock_completion.assert_called_once() call_args = mock_completion.call_args prompt = call_args[1]['messages'][0]['content'] @@ -395,7 +395,7 @@ def test_check_review_comments(): def test_guess_success_review_comments_litellm_call(): - """Test that the litellm.completion() call for review comments contains the expected content.""" + """Test that the completion() call for review comments contains the expected content.""" # Create a PR handler instance llm_config = LLMConfig(model='test', api_key='test') handler = PRHandler('test-owner', 'test-repo', 'test-token', llm_config) @@ -447,7 +447,7 @@ def test_guess_success_review_comments_litellm_call(): mock_completion.return_value = mock_response success, success_list, explanation = handler.guess_success(issue, history) - # Verify the litellm.completion() call + # Verify the completion() call mock_completion.assert_called_once() call_args = mock_completion.call_args prompt = call_args[1]['messages'][0]['content'] diff --git a/tests/unit/resolver/test_pr_title_escaping.py b/tests/unit/resolver/test_pr_title_escaping.py index 45dd523b036a..9cc5d90bc4b0 100644 --- a/tests/unit/resolver/test_pr_title_escaping.py +++ b/tests/unit/resolver/test_pr_title_escaping.py @@ -153,7 +153,6 @@ def mock_run(*args, **kwargs): # Try to send a PR - this will fail if the title is incorrectly escaped print('Sending PR...') - from openhands.core.config import LLMConfig from openhands.resolver.send_pull_request import send_pull_request send_pull_request( @@ -161,6 +160,5 @@ def mock_run(*args, **kwargs): github_token='dummy-token', github_username='test-user', patch_dir=temp_dir, - llm_config=LLMConfig(model='test-model', api_key='test-key'), pr_type='ready', ) diff --git a/tests/unit/resolver/test_send_pull_request.py b/tests/unit/resolver/test_send_pull_request.py index f83e2e97ec2f..c31d88cbae85 100644 --- a/tests/unit/resolver/test_send_pull_request.py +++ b/tests/unit/resolver/test_send_pull_request.py @@ -244,8 +244,12 @@ def test_initialize_repo(mock_output_dir): @patch('openhands.resolver.send_pull_request.reply_to_comment') @patch('requests.post') @patch('subprocess.run') +@patch('openhands.resolver.send_pull_request.LLM') def test_update_existing_pull_request( - mock_subprocess_run, mock_requests_post, mock_reply_to_comment + mock_llm_class, + mock_subprocess_run, + mock_requests_post, + mock_reply_to_comment, ): # Arrange: Set up test data github_issue = GithubIssue( @@ -267,23 +271,28 @@ def test_update_existing_pull_request( # Mock the requests.post call for adding a PR comment mock_requests_post.return_value.status_code = 201 + + # Mock LLM instance and completion call + mock_llm_instance = MagicMock() mock_completion_response = MagicMock() mock_completion_response.choices = [ MagicMock(message=MagicMock(content='This is an issue resolution.')) ] + mock_llm_instance.completion.return_value = mock_completion_response + mock_llm_class.return_value = mock_llm_instance + llm_config = LLMConfig() # Act: Call the function without comment_message to test auto-generation - with patch('litellm.completion', MagicMock(return_value=mock_completion_response)): - result = update_existing_pull_request( - github_issue, - github_token, - github_username, - patch_dir, - llm_config, - comment_message=None, - additional_message=additional_message, - ) + result = update_existing_pull_request( + github_issue, + github_token, + github_username, + patch_dir, + llm_config, + comment_message=None, + additional_message=additional_message, + ) # Assert: Check if the git push command was executed push_command = ( @@ -342,7 +351,6 @@ def test_send_pull_request( mock_run, mock_github_issue, mock_output_dir, - mock_llm_config, pr_type, target_branch, ): @@ -377,7 +385,6 @@ def test_send_pull_request( github_username='test-user', patch_dir=repo_path, pr_type=pr_type, - llm_config=mock_llm_config, target_branch=target_branch, ) @@ -427,7 +434,7 @@ def test_send_pull_request( @patch('requests.get') def test_send_pull_request_invalid_target_branch( - mock_get, mock_github_issue, mock_output_dir, mock_llm_config + mock_get, mock_github_issue, mock_output_dir ): """Test that an error is raised when specifying a non-existent target branch""" repo_path = os.path.join(mock_output_dir, 'repo') @@ -448,7 +455,6 @@ def test_send_pull_request_invalid_target_branch( github_username='test-user', patch_dir=repo_path, pr_type='ready', - llm_config=mock_llm_config, target_branch='nonexistent-branch', ) @@ -460,7 +466,7 @@ def test_send_pull_request_invalid_target_branch( @patch('requests.post') @patch('requests.get') def test_send_pull_request_git_push_failure( - mock_get, mock_post, mock_run, mock_github_issue, mock_output_dir, mock_llm_config + mock_get, mock_post, mock_run, mock_github_issue, mock_output_dir ): repo_path = os.path.join(mock_output_dir, 'repo') @@ -483,7 +489,6 @@ def test_send_pull_request_git_push_failure( github_username='test-user', patch_dir=repo_path, pr_type='ready', - llm_config=mock_llm_config, ) # Assert that subprocess.run was called twice @@ -519,7 +524,7 @@ def test_send_pull_request_git_push_failure( @patch('requests.post') @patch('requests.get') def test_send_pull_request_permission_error( - mock_get, mock_post, mock_run, mock_github_issue, mock_output_dir, mock_llm_config + mock_get, mock_post, mock_run, mock_github_issue, mock_output_dir ): repo_path = os.path.join(mock_output_dir, 'repo') @@ -543,7 +548,6 @@ def test_send_pull_request_permission_error( github_username='test-user', patch_dir=repo_path, pr_type='ready', - llm_config=mock_llm_config, ) # Assert that the branch was created and pushed @@ -757,7 +761,6 @@ def test_process_single_issue( github_username=github_username, patch_dir=f'{mock_output_dir}/patches/issue_1', pr_type=pr_type, - llm_config=mock_llm_config, fork_owner=None, additional_message=resolver_output.success_explanation, target_branch=None, @@ -940,7 +943,7 @@ def test_process_all_successful_issues( @patch('requests.get') @patch('subprocess.run') def test_send_pull_request_branch_naming( - mock_run, mock_get, mock_github_issue, mock_output_dir, mock_llm_config + mock_run, mock_get, mock_github_issue, mock_output_dir ): repo_path = os.path.join(mock_output_dir, 'repo') @@ -965,7 +968,6 @@ def test_send_pull_request_branch_naming( github_username='test-user', patch_dir=repo_path, pr_type='branch', - llm_config=mock_llm_config, ) # Assert API calls