From ae89361f275086a16c11fa4f03438e59f85f4327 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Mon, 28 Jul 2025 15:41:42 +0100 Subject: [PATCH 1/6] initial impl of pull request draft state update --- .../__toolsnaps__/update_pull_request.snap | 4 + pkg/github/pullrequests.go | 146 ++++++++++++++---- pkg/github/pullrequests_test.go | 13 +- pkg/github/tools.go | 2 +- 4 files changed, 136 insertions(+), 29 deletions(-) diff --git a/pkg/github/__toolsnaps__/update_pull_request.snap b/pkg/github/__toolsnaps__/update_pull_request.snap index 765983afd..c44d8afa0 100644 --- a/pkg/github/__toolsnaps__/update_pull_request.snap +++ b/pkg/github/__toolsnaps__/update_pull_request.snap @@ -14,6 +14,10 @@ "description": "New description", "type": "string" }, + "draft": { + "description": "Mark pull request as draft (true) or ready for review (false)", + "type": "boolean" + }, "maintainer_can_modify": { "description": "Allow maintainer edits", "type": "boolean" diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 47b7c6bd2..aeb0551c9 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -203,7 +203,7 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu } // UpdatePullRequest creates a tool to update an existing pull request. -func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { +func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { return mcp.NewTool("update_pull_request", mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ @@ -232,6 +232,9 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu mcp.Description("New state"), mcp.Enum("open", "closed"), ), + mcp.WithBoolean("draft", + mcp.Description("Mark pull request as draft (true) or ready for review (false)"), + ), mcp.WithString("base", mcp.Description("New base branch name"), ), @@ -253,74 +256,165 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu return mcp.NewToolResultError(err.Error()), nil } - // Build the update struct only with provided fields + draftProvided := request.GetArguments()["draft"] != nil + var draftValue bool + if draftProvided { + draftValue, err = OptionalParam[bool](request, "draft") + if err != nil { + return nil, err + } + } + update := &github.PullRequest{} - updateNeeded := false + restUpdateNeeded := false if title, ok, err := OptionalParamOK[string](request, "title"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.Title = github.Ptr(title) - updateNeeded = true + restUpdateNeeded = true } if body, ok, err := OptionalParamOK[string](request, "body"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.Body = github.Ptr(body) - updateNeeded = true + restUpdateNeeded = true } if state, ok, err := OptionalParamOK[string](request, "state"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.State = github.Ptr(state) - updateNeeded = true + restUpdateNeeded = true } if base, ok, err := OptionalParamOK[string](request, "base"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)} - updateNeeded = true + restUpdateNeeded = true } if maintainerCanModify, ok, err := OptionalParamOK[bool](request, "maintainer_can_modify"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.MaintainerCanModify = github.Ptr(maintainerCanModify) - updateNeeded = true + restUpdateNeeded = true } - if !updateNeeded { + if !restUpdateNeeded && !draftProvided { return mcp.NewToolResultError("No update parameters provided."), nil } + if restUpdateNeeded { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + _, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update pull request", + resp, + err, + ), nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil + } + } + + if draftProvided { + gqlClient, err := getGQLClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub GraphQL client: %w", err) + } + + var prQuery struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + + err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "prNum": githubv4.Int(pullNumber), + }) + if err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil + } + + currentIsDraft := bool(prQuery.Repository.PullRequest.IsDraft) + + if currentIsDraft != draftValue { + if draftValue { + // Convert to draft + var mutation struct { + ConvertPullRequestToDraft struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"convertPullRequestToDraft(input: $input)"` + } + + err = gqlClient.Mutate(ctx, &mutation, githubv4.ConvertPullRequestToDraftInput{ + PullRequestID: prQuery.Repository.PullRequest.ID, + }, nil) + if err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to convert pull request to draft", err), nil + } + } else { + // Mark as ready for review + var mutation struct { + MarkPullRequestReadyForReview struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"markPullRequestReadyForReview(input: $input)"` + } + + err = gqlClient.Mutate(ctx, &mutation, githubv4.MarkPullRequestReadyForReviewInput{ + PullRequestID: prQuery.Repository.PullRequest.ID, + }, nil) + if err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to mark pull request ready for review", err), nil + } + } + } + } + client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, err } - pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) + + finalPR, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to update pull request", - resp, - err, - ), nil + return ghErrors.NewGitHubAPIErrorResponse(ctx, "Failed to get pull request", resp, err), nil } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() } - return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil - } + }() - r, err := json.Marshal(pr) + r, err := json.Marshal(finalPR) if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to marshal response: %v", err), nil } return mcp.NewToolResultText(string(r)), nil diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 42fd5bf03..823a87525 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -137,7 +137,7 @@ func Test_GetPullRequest(t *testing.T) { func Test_UpdatePullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := UpdatePullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool, _ := UpdatePullRequest(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "update_pull_request", tool.Name) @@ -145,6 +145,7 @@ func Test_UpdatePullRequest(t *testing.T) { assert.Contains(t, tool.InputSchema.Properties, "owner") assert.Contains(t, tool.InputSchema.Properties, "repo") assert.Contains(t, tool.InputSchema.Properties, "pullNumber") + assert.Contains(t, tool.InputSchema.Properties, "draft") assert.Contains(t, tool.InputSchema.Properties, "title") assert.Contains(t, tool.InputSchema.Properties, "body") assert.Contains(t, tool.InputSchema.Properties, "state") @@ -194,6 +195,10 @@ func Test_UpdatePullRequest(t *testing.T) { mockResponse(t, http.StatusOK, mockUpdatedPR), ), ), + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockUpdatedPR, + ), ), requestArgs: map[string]interface{}{ "owner": "owner", @@ -218,6 +223,10 @@ func Test_UpdatePullRequest(t *testing.T) { mockResponse(t, http.StatusOK, mockClosedPR), ), ), + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockClosedPR, + ), ), requestArgs: map[string]interface{}{ "owner": "owner", @@ -266,7 +275,7 @@ func Test_UpdatePullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UpdatePullRequest(stubGetClientFn(client), translations.NullTranslationHelper) + _, handler := UpdatePullRequest(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/tools.go b/pkg/github/tools.go index e01b7cc40..caa4f9cfe 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -87,7 +87,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG toolsets.NewServerTool(MergePullRequest(getClient, t)), toolsets.NewServerTool(UpdatePullRequestBranch(getClient, t)), toolsets.NewServerTool(CreatePullRequest(getClient, t)), - toolsets.NewServerTool(UpdatePullRequest(getClient, t)), + toolsets.NewServerTool(UpdatePullRequest(getClient, getGQLClient, t)), toolsets.NewServerTool(RequestCopilotReview(getClient, t)), // Reviews From ffd5798e7aaa2d3566136b29002ec4a2732deb6c Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Mon, 28 Jul 2025 16:12:03 +0100 Subject: [PATCH 2/6] appease linter --- pkg/github/pullrequests.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index aeb0551c9..080887f35 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -350,7 +350,7 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{ "owner": githubv4.String(owner), "repo": githubv4.String(repo), - "prNum": githubv4.Int(pullNumber), + "prNum": githubv4.Int(int32(pullNumber)), }) if err != nil { return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil From e0d67eb7ac8ea4c23124651a40fa61d56c5bdbe4 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Mon, 28 Jul 2025 16:13:15 +0100 Subject: [PATCH 3/6] update README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index be9288e40..34edc7b33 100644 --- a/README.md +++ b/README.md @@ -736,6 +736,7 @@ The following sets of tools are available (all are on by default): - **update_pull_request** - Edit pull request - `base`: New base branch name (string, optional) - `body`: New description (string, optional) + - `draft`: Mark pull request as draft (true) or ready for review (false) (boolean, optional) - `maintainer_can_modify`: Allow maintainer edits (boolean, optional) - `owner`: Repository owner (string, required) - `pullNumber`: Pull request number to update (number, required) From 93fa25b9039cb7198c2d8baac22f0fa2ce4158d7 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Mon, 28 Jul 2025 16:19:44 +0100 Subject: [PATCH 4/6] add nosec --- pkg/github/pullrequests.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 080887f35..e30f856d2 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -350,7 +350,7 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{ "owner": githubv4.String(owner), "repo": githubv4.String(repo), - "prNum": githubv4.Int(int32(pullNumber)), + "prNum": githubv4.Int(pullNumber), // #nosec G115 - pull request numbers are always small positive integers }) if err != nil { return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil From ec521d58096db7620bfae61deef3eac0bda1e8a0 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Mon, 28 Jul 2025 16:27:07 +0100 Subject: [PATCH 5/6] fixed err return type for json marshalling --- pkg/github/pullrequests.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index e30f856d2..f380843b0 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -414,7 +414,7 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra r, err := json.Marshal(finalPR) if err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to marshal response: %v", err), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to marshal response: %v", err)), nil } return mcp.NewToolResultText(string(r)), nil From 812048bfe2fafd56adfab13db20e5806d73f6804 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Tue, 29 Jul 2025 10:31:22 +0100 Subject: [PATCH 6/6] add gql test --- pkg/github/pullrequests_test.go | 211 ++++++++++++++++++++++++++++++++ 1 file changed, 211 insertions(+) diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 823a87525..179458115 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -161,6 +161,7 @@ func Test_UpdatePullRequest(t *testing.T) { HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), Body: github.Ptr("Updated test PR body."), MaintainerCanModify: github.Ptr(false), + Draft: github.Ptr(false), Base: &github.PullRequestBranch{ Ref: github.Ptr("develop"), }, @@ -237,6 +238,31 @@ func Test_UpdatePullRequest(t *testing.T) { expectError: false, expectedPR: mockClosedPR, }, + { + name: "successful PR update (title only)", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchReposPullsByOwnerByRepoByPullNumber, + expectRequestBody(t, map[string]interface{}{ + "title": "Updated Test PR Title", + }).andThen( + mockResponse(t, http.StatusOK, mockUpdatedPR), + ), + ), + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockUpdatedPR, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "title": "Updated Test PR Title", + }, + expectError: false, + expectedPR: mockUpdatedPR, + }, { name: "no update parameters provided", mockedClient: mock.NewMockedHTTPClient(), // No API call expected @@ -325,6 +351,191 @@ func Test_UpdatePullRequest(t *testing.T) { } } +func Test_UpdatePullRequest_Draft(t *testing.T) { + // Setup mock PR for success case + mockUpdatedPR := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR Title"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), + Body: github.Ptr("Test PR body."), + MaintainerCanModify: github.Ptr(false), + Draft: github.Ptr(false), // Updated to ready for review + Base: &github.PullRequestBranch{ + Ref: github.Ptr("main"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedPR *github.PullRequest + expectedErrMsg string + }{ + { + name: "successful draft update to ready for review", + mockedClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(42), + }, + githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": true, // Current state is draft + }, + }, + }), + ), + githubv4mock.NewMutationMatcher( + struct { + MarkPullRequestReadyForReview struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"markPullRequestReadyForReview(input: $input)"` + }{}, + githubv4.MarkPullRequestReadyForReviewInput{ + PullRequestID: "PR_kwDOA0xdyM50BPaO", + }, + nil, + githubv4mock.DataResponse(map[string]any{ + "markPullRequestReadyForReview": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": false, + }, + }, + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "draft": false, + }, + expectError: false, + expectedPR: mockUpdatedPR, + }, + { + name: "successful convert pull request to draft", + mockedClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(42), + }, + githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": false, // Current state is draft + }, + }, + }), + ), + githubv4mock.NewMutationMatcher( + struct { + ConvertPullRequestToDraft struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"convertPullRequestToDraft(input: $input)"` + }{}, + githubv4.ConvertPullRequestToDraftInput{ + PullRequestID: "PR_kwDOA0xdyM50BPaO", + }, + nil, + githubv4mock.DataResponse(map[string]any{ + "convertPullRequestToDraft": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": true, + }, + }, + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "draft": true, + }, + expectError: false, + expectedPR: mockUpdatedPR, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // For draft-only tests, we need to mock both GraphQL and the final REST GET call + restClient := github.NewClient(mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockUpdatedPR, + ), + )) + gqlClient := githubv4.NewClient(tc.mockedClient) + + _, handler := UpdatePullRequest(stubGetClientFn(restClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + + request := createMCPRequest(tc.requestArgs) + + result, err := handler(context.Background(), request) + + if tc.expectError || tc.expectedErrMsg != "" { + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + if tc.expectedErrMsg != "" { + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) + } + return + } + + require.NoError(t, err) + require.False(t, result.IsError) + + textContent := getTextResult(t, result) + + // Unmarshal and verify the successful result + var returnedPR github.PullRequest + err = json.Unmarshal([]byte(textContent.Text), &returnedPR) + require.NoError(t, err) + assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number) + if tc.expectedPR.Draft != nil { + assert.Equal(t, *tc.expectedPR.Draft, *returnedPR.Draft) + } + }) + } +} + func Test_ListPullRequests(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy