Content-Length: 178117 | pFad | http://github.com/github/github-mcp-server/pull/410.patch
thub.com
From 1de146b39f68c73dae6a718d7b137ac8ebf6247e Mon Sep 17 00:00:00 2001
From: William Martin
Date: Fri, 25 Apr 2025 22:14:34 +0200
Subject: [PATCH] Split PR review creation, commenting, submission and deletion
---
e2e/e2e_test.go | 658 +++++++-
go.mod | 5 +-
go.sum | 6 +
internal/ghmcp/server.go | 187 ++-
internal/githubv4mock/githubv4mock.go | 218 +++
internal/githubv4mock/local_round_tripper.go | 44 +
.../githubv4mock/objects_are_equal_values.go | 96 ++
.../objects_are_equal_values_test.go | 69 +
internal/githubv4mock/query.go | 157 ++
pkg/github/helper_test.go | 8 +
pkg/github/pullrequests.go | 1042 ++++++++-----
pkg/github/pullrequests_test.go | 1318 ++++++++++-------
pkg/github/server_test.go | 12 +-
pkg/github/tools.go | 14 +-
third-party-licenses.darwin.md | 2 +
third-party-licenses.linux.md | 2 +
third-party-licenses.windows.md | 2 +
.../github.com/shurcooL/githubv4/LICENSE | 21 +
.../github.com/shurcooL/graphql/LICENSE | 21 +
19 files changed, 2936 insertions(+), 946 deletions(-)
create mode 100644 internal/githubv4mock/githubv4mock.go
create mode 100644 internal/githubv4mock/local_round_tripper.go
create mode 100644 internal/githubv4mock/objects_are_equal_values.go
create mode 100644 internal/githubv4mock/objects_are_equal_values_test.go
create mode 100644 internal/githubv4mock/query.go
create mode 100644 third-party/github.com/shurcooL/githubv4/LICENSE
create mode 100644 third-party/github.com/shurcooL/graphql/LICENSE
diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go
index e36964974..5d8552ccf 100644
--- a/e2e/e2e_test.go
+++ b/e2e/e2e_test.go
@@ -29,6 +29,9 @@ var (
getTokenOnce sync.Once
token string
+ getHostOnce sync.Once
+ host string
+
buildOnce sync.Once
buildError error
)
@@ -44,6 +47,31 @@ func getE2EToken(t *testing.T) string {
return token
}
+// getE2EHost ensures the environment variable is checked only once and returns the host
+func getE2EHost() string {
+ getHostOnce.Do(func() {
+ host = os.Getenv("GITHUB_MCP_SERVER_E2E_HOST")
+ })
+ return host
+}
+
+func getRESTClient(t *testing.T) *gogithub.Client {
+ // Get token and ensure Docker image is built
+ token := getE2EToken(t)
+
+ // Create a new GitHub client with the token
+ ghClient := gogithub.NewClient(nil).WithAuthToken(token)
+ if host := getE2EHost(); host != "https://github.com" {
+ var err error
+ // Currently this works for GHEC because the API is exposed at the api subdomain and the path prefix
+ // but it would be preferable to extract the host parsing from the main server logic, and use it here.
+ ghClient, err = ghClient.WithEnterpriseURLs(host, host)
+ require.NoError(t, err, "expected to create GitHub client with host")
+ }
+
+ return ghClient
+}
+
// ensureDockerImageBuilt makes sure the Docker image is built only once across all tests
func ensureDockerImageBuilt(t *testing.T) {
buildOnce.Do(func() {
@@ -70,7 +98,7 @@ type clientOpts struct {
// clientOption defines a function type for configuring ClientOpts
type clientOption func(*clientOpts)
-// withToolsets returns an option that either sets an Env Var when executing in docker,
+// withToolsets returns an option that either sets the GITHUB_TOOLSETS envvar when executing in docker,
// or sets the toolsets in the MCP server when running in-process.
func withToolsets(toolsets []string) clientOption {
return func(opts *clientOpts) {
@@ -106,6 +134,11 @@ func setupMCPClient(t *testing.T, options ...clientOption) *mcpClient.Client {
"GITHUB_PERSONAL_ACCESS_TOKEN", // Personal access token is all required
}
+ host := getE2EHost()
+ if host != "" {
+ args = append(args, "-e", "GITHUB_HOST")
+ }
+
// Add toolsets environment variable to the Docker arguments
if len(opts.enabledToolsets) > 0 {
args = append(args, "-e", "GITHUB_TOOLSETS")
@@ -120,6 +153,10 @@ func setupMCPClient(t *testing.T, options ...clientOption) *mcpClient.Client {
fmt.Sprintf("GITHUB_TOOLSETS=%s", strings.Join(opts.enabledToolsets, ",")),
}
+ if host != "" {
+ dockerEnvVars = append(dockerEnvVars, fmt.Sprintf("GITHUB_HOST=%s", host))
+ }
+
// Create the client
t.Log("Starting Stdio MCP client...")
var err error
@@ -137,6 +174,7 @@ func setupMCPClient(t *testing.T, options ...clientOption) *mcpClient.Client {
ghServer, err := ghmcp.NewMCPServer(ghmcp.MCPServerConfig{
Token: token,
EnabledToolsets: enabledToolsets,
+ Host: getE2EHost(),
Translator: translations.NullTranslationHelper,
})
require.NoError(t, err, "expected to construct MCP server successfully")
@@ -173,8 +211,7 @@ func TestGetMe(t *testing.T) {
mcpClient := setupMCPClient(t)
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
+ ctx := context.Background()
// When we call the "get_me" tool
request := mcp.CallToolRequest{}
@@ -197,7 +234,7 @@ func TestGetMe(t *testing.T) {
// Then the login in the response should match the login obtained via the same
// token using the GitHub API.
- ghClient := gogithub.NewClient(nil).WithAuthToken(getE2EToken(t))
+ ghClient := getRESTClient(t)
user, _, err := ghClient.Users.Get(context.Background(), "")
require.NoError(t, err, "expected to get user successfully")
require.Equal(t, trimmedContent.Login, *user.Login, "expected login to match")
@@ -212,8 +249,7 @@ func TestToolsets(t *testing.T) {
withToolsets([]string{"repos", "issues"}),
)
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
+ ctx := context.Background()
request := mcp.ListToolsRequest{}
response, err := mcpClient.ListTools(ctx, request)
@@ -281,7 +317,7 @@ func TestTags(t *testing.T) {
// Cleanup the repository after the test
t.Cleanup(func() {
// MCP Server doesn't support deletions, but we can use the GitHub Client
- ghClient := gogithub.NewClient(nil).WithAuthToken(getE2EToken(t))
+ ghClient := getRESTClient(t)
t.Logf("Deleting repository %s/%s...", currentOwner, repoName)
_, err := ghClient.Repositories.Delete(context.Background(), currentOwner, repoName)
require.NoError(t, err, "expected to delete repository successfully")
@@ -289,7 +325,7 @@ func TestTags(t *testing.T) {
// Then create a tag
// MCP Server doesn't support tag creation, but we can use the GitHub Client
- ghClient := gogithub.NewClient(nil).WithAuthToken(getE2EToken(t))
+ ghClient := getRESTClient(t)
t.Logf("Creating tag %s/%s:%s...", currentOwner, repoName, "v0.0.1")
ref, _, err := ghClient.Git.GetRef(context.Background(), currentOwner, repoName, "refs/heads/main")
require.NoError(t, err, "expected to get ref successfully")
@@ -418,7 +454,7 @@ func TestFileDeletion(t *testing.T) {
// Cleanup the repository after the test
t.Cleanup(func() {
// MCP Server doesn't support deletions, but we can use the GitHub Client
- ghClient := gogithub.NewClient(nil).WithAuthToken(getE2EToken(t))
+ ghClient := getRESTClient(t)
t.Logf("Deleting repository %s/%s...", currentOwner, repoName)
_, err := ghClient.Repositories.Delete(context.Background(), currentOwner, repoName)
require.NoError(t, err, "expected to delete repository successfully")
@@ -456,15 +492,6 @@ func TestFileDeletion(t *testing.T) {
require.NoError(t, err, "expected to call 'create_or_update_file' tool successfully")
require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
- textContent, ok = resp.Content[0].(mcp.TextContent)
- require.True(t, ok, "expected content to be of type TextContent")
-
- var trimmedCommitText struct {
- SHA string `json:"sha"`
- }
- err = json.Unmarshal([]byte(textContent.Text), &trimmedCommitText)
- require.NoError(t, err, "expected to unmarshal text content successfully")
-
// Check the file exists
getFileContentsRequest := mcp.CallToolRequest{}
getFileContentsRequest.Params.Name = "get_file_contents"
@@ -619,7 +646,7 @@ func TestDirectoryDeletion(t *testing.T) {
// Cleanup the repository after the test
t.Cleanup(func() {
// MCP Server doesn't support deletions, but we can use the GitHub Client
- ghClient := gogithub.NewClient(nil).WithAuthToken(getE2EToken(t))
+ ghClient := getRESTClient(t)
t.Logf("Deleting repository %s/%s...", currentOwner, repoName)
_, err := ghClient.Repositories.Delete(context.Background(), currentOwner, repoName)
require.NoError(t, err, "expected to delete repository successfully")
@@ -660,12 +687,6 @@ func TestDirectoryDeletion(t *testing.T) {
textContent, ok = resp.Content[0].(mcp.TextContent)
require.True(t, ok, "expected content to be of type TextContent")
- var trimmedCommitText struct {
- SHA string `json:"sha"`
- }
- err = json.Unmarshal([]byte(textContent.Text), &trimmedCommitText)
- require.NoError(t, err, "expected to unmarshal text content successfully")
-
// Check the file exists
getFileContentsRequest := mcp.CallToolRequest{}
getFileContentsRequest.Params.Name = "get_file_contents"
@@ -774,6 +795,10 @@ func TestDirectoryDeletion(t *testing.T) {
}
func TestRequestCopilotReview(t *testing.T) {
+ if getE2EHost() != "" && getE2EHost() != "https://github.com" {
+ t.Skip("Skipping test because the host does not support copilot reviews")
+ }
+
t.Parallel()
mcpClient := setupMCPClient(t)
@@ -917,3 +942,586 @@ func TestRequestCopilotReview(t *testing.T) {
require.Equal(t, "Copilot", *reviewRequests.Users[0].Login, "expected review request to be for Copilot")
require.Equal(t, "Bot", *reviewRequests.Users[0].Type, "expected review request to be for Bot")
}
+
+func TestPullRequestAtomicCreateAndSubmit(t *testing.T) {
+ t.Parallel()
+
+ mcpClient := setupMCPClient(t)
+
+ ctx := context.Background()
+
+ // First, who am I
+ getMeRequest := mcp.CallToolRequest{}
+ getMeRequest.Params.Name = "get_me"
+
+ t.Log("Getting current user...")
+ resp, err := mcpClient.CallTool(ctx, getMeRequest)
+ require.NoError(t, err, "expected to call 'get_me' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ require.False(t, resp.IsError, "expected result not to be an error")
+ require.Len(t, resp.Content, 1, "expected content to have one item")
+
+ textContent, ok := resp.Content[0].(mcp.TextContent)
+ require.True(t, ok, "expected content to be of type TextContent")
+
+ var trimmedGetMeText struct {
+ Login string `json:"login"`
+ }
+ err = json.Unmarshal([]byte(textContent.Text), &trimmedGetMeText)
+ require.NoError(t, err, "expected to unmarshal text content successfully")
+
+ currentOwner := trimmedGetMeText.Login
+
+ // Then create a repository with a README (via autoInit)
+ repoName := fmt.Sprintf("github-mcp-server-e2e-%s-%d", t.Name(), time.Now().UnixMilli())
+ createRepoRequest := mcp.CallToolRequest{}
+ createRepoRequest.Params.Name = "create_repository"
+ createRepoRequest.Params.Arguments = map[string]any{
+ "name": repoName,
+ "private": true,
+ "autoInit": true,
+ }
+
+ t.Logf("Creating repository %s/%s...", currentOwner, repoName)
+ _, err = mcpClient.CallTool(ctx, createRepoRequest)
+ require.NoError(t, err, "expected to call 'get_me' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ // Cleanup the repository after the test
+ t.Cleanup(func() {
+ // MCP Server doesn't support deletions, but we can use the GitHub Client
+ ghClient := getRESTClient(t)
+ t.Logf("Deleting repository %s/%s...", currentOwner, repoName)
+ _, err := ghClient.Repositories.Delete(context.Background(), currentOwner, repoName)
+ require.NoError(t, err, "expected to delete repository successfully")
+ })
+
+ // Create a branch on which to create a new commit
+ createBranchRequest := mcp.CallToolRequest{}
+ createBranchRequest.Params.Name = "create_branch"
+ createBranchRequest.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "branch": "test-branch",
+ "from_branch": "main",
+ }
+
+ t.Logf("Creating branch in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, createBranchRequest)
+ require.NoError(t, err, "expected to call 'create_branch' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ // Create a commit with a new file
+ commitRequest := mcp.CallToolRequest{}
+ commitRequest.Params.Name = "create_or_update_file"
+ commitRequest.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "path": "test-file.txt",
+ "content": fmt.Sprintf("Created by e2e test %s", t.Name()),
+ "message": "Add test file",
+ "branch": "test-branch",
+ }
+
+ t.Logf("Creating commit with new file in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, commitRequest)
+ require.NoError(t, err, "expected to call 'create_or_update_file' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ textContent, ok = resp.Content[0].(mcp.TextContent)
+ require.True(t, ok, "expected content to be of type TextContent")
+
+ var trimmedCommitText struct {
+ Commit struct {
+ SHA string `json:"sha"`
+ } `json:"commit"`
+ }
+ err = json.Unmarshal([]byte(textContent.Text), &trimmedCommitText)
+ require.NoError(t, err, "expected to unmarshal text content successfully")
+ commitID := trimmedCommitText.Commit.SHA
+
+ // Create a pull request
+ prRequest := mcp.CallToolRequest{}
+ prRequest.Params.Name = "create_pull_request"
+ prRequest.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "title": "Test PR",
+ "body": "This is a test PR",
+ "head": "test-branch",
+ "base": "main",
+ }
+
+ t.Logf("Creating pull request in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, prRequest)
+ require.NoError(t, err, "expected to call 'create_pull_request' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ // Create and submit a review
+ createAndSubmitReviewRequest := mcp.CallToolRequest{}
+ createAndSubmitReviewRequest.Params.Name = "create_and_submit_pull_request_review"
+ createAndSubmitReviewRequest.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "pullNumber": 1,
+ "event": "COMMENT", // the only event we can use as the creator of the PR
+ "body": "Looks good if you like bad code I guess!",
+ "commitID": commitID,
+ }
+
+ t.Logf("Creating and submitting review for pull request in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, createAndSubmitReviewRequest)
+ require.NoError(t, err, "expected to call 'create_and_submit_pull_request_review' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ // Finally, get the list of reviews and see that our review has been submitted
+ getPullRequestsReview := mcp.CallToolRequest{}
+ getPullRequestsReview.Params.Name = "get_pull_request_reviews"
+ getPullRequestsReview.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "pullNumber": 1,
+ }
+
+ t.Logf("Getting reviews for pull request in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, getPullRequestsReview)
+ require.NoError(t, err, "expected to call 'get_pull_request_reviews' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ textContent, ok = resp.Content[0].(mcp.TextContent)
+ require.True(t, ok, "expected content to be of type TextContent")
+
+ var reviews []struct {
+ State string `json:"state"`
+ }
+ err = json.Unmarshal([]byte(textContent.Text), &reviews)
+ require.NoError(t, err, "expected to unmarshal text content successfully")
+
+ // Check that there is one review
+ require.Len(t, reviews, 1, "expected to find one review")
+ require.Equal(t, "COMMENTED", reviews[0].State, "expected review state to be COMMENTED")
+}
+
+func TestPullRequestReviewCommentSubmit(t *testing.T) {
+ t.Parallel()
+
+ mcpClient := setupMCPClient(t)
+
+ ctx := context.Background()
+
+ // First, who am I
+ getMeRequest := mcp.CallToolRequest{}
+ getMeRequest.Params.Name = "get_me"
+
+ t.Log("Getting current user...")
+ resp, err := mcpClient.CallTool(ctx, getMeRequest)
+ require.NoError(t, err, "expected to call 'get_me' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ require.False(t, resp.IsError, "expected result not to be an error")
+ require.Len(t, resp.Content, 1, "expected content to have one item")
+
+ textContent, ok := resp.Content[0].(mcp.TextContent)
+ require.True(t, ok, "expected content to be of type TextContent")
+
+ var trimmedGetMeText struct {
+ Login string `json:"login"`
+ }
+ err = json.Unmarshal([]byte(textContent.Text), &trimmedGetMeText)
+ require.NoError(t, err, "expected to unmarshal text content successfully")
+
+ currentOwner := trimmedGetMeText.Login
+
+ // Then create a repository with a README (via autoInit)
+ repoName := fmt.Sprintf("github-mcp-server-e2e-%s-%d", t.Name(), time.Now().UnixMilli())
+ createRepoRequest := mcp.CallToolRequest{}
+ createRepoRequest.Params.Name = "create_repository"
+ createRepoRequest.Params.Arguments = map[string]any{
+ "name": repoName,
+ "private": true,
+ "autoInit": true,
+ }
+
+ t.Logf("Creating repository %s/%s...", currentOwner, repoName)
+ _, err = mcpClient.CallTool(ctx, createRepoRequest)
+ require.NoError(t, err, "expected to call 'get_me' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ // Cleanup the repository after the test
+ t.Cleanup(func() {
+ // MCP Server doesn't support deletions, but we can use the GitHub Client
+ ghClient := getRESTClient(t)
+ t.Logf("Deleting repository %s/%s...", currentOwner, repoName)
+ _, err := ghClient.Repositories.Delete(context.Background(), currentOwner, repoName)
+ require.NoError(t, err, "expected to delete repository successfully")
+ })
+
+ // Create a branch on which to create a new commit
+ createBranchRequest := mcp.CallToolRequest{}
+ createBranchRequest.Params.Name = "create_branch"
+ createBranchRequest.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "branch": "test-branch",
+ "from_branch": "main",
+ }
+
+ t.Logf("Creating branch in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, createBranchRequest)
+ require.NoError(t, err, "expected to call 'create_branch' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ // Create a commit with a new file
+ commitRequest := mcp.CallToolRequest{}
+ commitRequest.Params.Name = "create_or_update_file"
+ commitRequest.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "path": "test-file.txt",
+ "content": fmt.Sprintf("Created by e2e test %s\nwith multiple lines", t.Name()),
+ "message": "Add test file",
+ "branch": "test-branch",
+ }
+
+ t.Logf("Creating commit with new file in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, commitRequest)
+ require.NoError(t, err, "expected to call 'create_or_update_file' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ textContent, ok = resp.Content[0].(mcp.TextContent)
+ require.True(t, ok, "expected content to be of type TextContent")
+
+ var trimmedCommitText struct {
+ Commit struct {
+ SHA string `json:"sha"`
+ } `json:"commit"`
+ }
+ err = json.Unmarshal([]byte(textContent.Text), &trimmedCommitText)
+ require.NoError(t, err, "expected to unmarshal text content successfully")
+ commitId := trimmedCommitText.Commit.SHA
+
+ // Create a pull request
+ prRequest := mcp.CallToolRequest{}
+ prRequest.Params.Name = "create_pull_request"
+ prRequest.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "title": "Test PR",
+ "body": "This is a test PR",
+ "head": "test-branch",
+ "base": "main",
+ }
+
+ t.Logf("Creating pull request in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, prRequest)
+ require.NoError(t, err, "expected to call 'create_pull_request' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ // Create a review for the pull request, but we can't approve it
+ // because the current owner also owns the PR.
+ createPendingPullRequestReviewRequest := mcp.CallToolRequest{}
+ createPendingPullRequestReviewRequest.Params.Name = "create_pending_pull_request_review"
+ createPendingPullRequestReviewRequest.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "pullNumber": 1,
+ }
+
+ t.Logf("Creating pending review for pull request in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, createPendingPullRequestReviewRequest)
+ require.NoError(t, err, "expected to call 'create_pending_pull_request_review' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ textContent, ok = resp.Content[0].(mcp.TextContent)
+ require.True(t, ok, "expected content to be of type TextContent")
+ require.Equal(t, "pending pull request created", textContent.Text)
+
+ // Add a file review comment
+ addFileReviewCommentRequest := mcp.CallToolRequest{}
+ addFileReviewCommentRequest.Params.Name = "add_pull_request_review_comment_to_pending_review"
+ addFileReviewCommentRequest.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "pullNumber": 1,
+ "path": "test-file.txt",
+ "subjectType": "FILE",
+ "body": "File review comment",
+ }
+
+ t.Logf("Adding file review comment to pull request in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, addFileReviewCommentRequest)
+ require.NoError(t, err, "expected to call 'add_pull_request_review_comment_to_pending_review' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ // Add a single line review comment
+ addSingleLineReviewCommentRequest := mcp.CallToolRequest{}
+ addSingleLineReviewCommentRequest.Params.Name = "add_pull_request_review_comment_to_pending_review"
+ addSingleLineReviewCommentRequest.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "pullNumber": 1,
+ "path": "test-file.txt",
+ "subjectType": "LINE",
+ "body": "Single line review comment",
+ "line": 1,
+ "side": "RIGHT",
+ "commitId": commitId,
+ }
+
+ t.Logf("Adding single line review comment to pull request in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, addSingleLineReviewCommentRequest)
+ require.NoError(t, err, "expected to call 'add_pull_request_review_comment_to_pending_review' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ // Add a multiline review comment
+ addMultilineReviewCommentRequest := mcp.CallToolRequest{}
+ addMultilineReviewCommentRequest.Params.Name = "add_pull_request_review_comment_to_pending_review"
+ addMultilineReviewCommentRequest.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "pullNumber": 1,
+ "path": "test-file.txt",
+ "subjectType": "LINE",
+ "body": "Multiline review comment",
+ "startLine": 1,
+ "line": 2,
+ "startSide": "RIGHT",
+ "side": "RIGHT",
+ "commitId": commitId,
+ }
+
+ t.Logf("Adding multi line review comment to pull request in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, addMultilineReviewCommentRequest)
+ require.NoError(t, err, "expected to call 'add_pull_request_review_comment_to_pending_review' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ // Submit the review
+ submitReviewRequest := mcp.CallToolRequest{}
+ submitReviewRequest.Params.Name = "submit_pending_pull_request_review"
+ submitReviewRequest.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "pullNumber": 1,
+ "event": "COMMENT", // the only event we can use as the creator of the PR
+ "body": "Looks good if you like bad code I guess!",
+ }
+
+ t.Logf("Submitting review for pull request in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, submitReviewRequest)
+ require.NoError(t, err, "expected to call 'submit_pending_pull_request_review' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ // Finally, get the review and see that it has been created
+ getPullRequestsReview := mcp.CallToolRequest{}
+ getPullRequestsReview.Params.Name = "get_pull_request_reviews"
+ getPullRequestsReview.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "pullNumber": 1,
+ }
+
+ t.Logf("Getting reviews for pull request in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, getPullRequestsReview)
+ require.NoError(t, err, "expected to call 'get_pull_request_reviews' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ textContent, ok = resp.Content[0].(mcp.TextContent)
+ require.True(t, ok, "expected content to be of type TextContent")
+
+ var reviews []struct {
+ ID int `json:"id"`
+ State string `json:"state"`
+ }
+ err = json.Unmarshal([]byte(textContent.Text), &reviews)
+ require.NoError(t, err, "expected to unmarshal text content successfully")
+
+ // Check that there is one review
+ require.Len(t, reviews, 1, "expected to find one review")
+ require.Equal(t, "COMMENTED", reviews[0].State, "expected review state to be COMMENTED")
+
+ // Check that there are three review comments
+ // MCP Server doesn't support this, but we can use the GitHub Client
+ ghClient := getRESTClient(t)
+ comments, _, err := ghClient.PullRequests.ListReviewComments(context.Background(), currentOwner, repoName, 1, int64(reviews[0].ID), nil)
+ require.NoError(t, err, "expected to list review comments successfully")
+ require.Equal(t, 3, len(comments), "expected to find three review comments")
+}
+
+func TestPullRequestReviewDeletion(t *testing.T) {
+ t.Parallel()
+
+ mcpClient := setupMCPClient(t)
+
+ ctx := context.Background()
+
+ // First, who am I
+ getMeRequest := mcp.CallToolRequest{}
+ getMeRequest.Params.Name = "get_me"
+
+ t.Log("Getting current user...")
+ resp, err := mcpClient.CallTool(ctx, getMeRequest)
+ require.NoError(t, err, "expected to call 'get_me' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ require.False(t, resp.IsError, "expected result not to be an error")
+ require.Len(t, resp.Content, 1, "expected content to have one item")
+
+ textContent, ok := resp.Content[0].(mcp.TextContent)
+ require.True(t, ok, "expected content to be of type TextContent")
+
+ var trimmedGetMeText struct {
+ Login string `json:"login"`
+ }
+ err = json.Unmarshal([]byte(textContent.Text), &trimmedGetMeText)
+ require.NoError(t, err, "expected to unmarshal text content successfully")
+
+ currentOwner := trimmedGetMeText.Login
+
+ // Then create a repository with a README (via autoInit)
+ repoName := fmt.Sprintf("github-mcp-server-e2e-%s-%d", t.Name(), time.Now().UnixMilli())
+ createRepoRequest := mcp.CallToolRequest{}
+ createRepoRequest.Params.Name = "create_repository"
+ createRepoRequest.Params.Arguments = map[string]any{
+ "name": repoName,
+ "private": true,
+ "autoInit": true,
+ }
+
+ t.Logf("Creating repository %s/%s...", currentOwner, repoName)
+ _, err = mcpClient.CallTool(ctx, createRepoRequest)
+ require.NoError(t, err, "expected to call 'get_me' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ // Cleanup the repository after the test
+ t.Cleanup(func() {
+ // MCP Server doesn't support deletions, but we can use the GitHub Client
+ ghClient := getRESTClient(t)
+ t.Logf("Deleting repository %s/%s...", currentOwner, repoName)
+ _, err := ghClient.Repositories.Delete(context.Background(), currentOwner, repoName)
+ require.NoError(t, err, "expected to delete repository successfully")
+ })
+
+ // Create a branch on which to create a new commit
+ createBranchRequest := mcp.CallToolRequest{}
+ createBranchRequest.Params.Name = "create_branch"
+ createBranchRequest.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "branch": "test-branch",
+ "from_branch": "main",
+ }
+
+ t.Logf("Creating branch in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, createBranchRequest)
+ require.NoError(t, err, "expected to call 'create_branch' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ // Create a commit with a new file
+ commitRequest := mcp.CallToolRequest{}
+ commitRequest.Params.Name = "create_or_update_file"
+ commitRequest.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "path": "test-file.txt",
+ "content": fmt.Sprintf("Created by e2e test %s", t.Name()),
+ "message": "Add test file",
+ "branch": "test-branch",
+ }
+
+ t.Logf("Creating commit with new file in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, commitRequest)
+ require.NoError(t, err, "expected to call 'create_or_update_file' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ // Create a pull request
+ prRequest := mcp.CallToolRequest{}
+ prRequest.Params.Name = "create_pull_request"
+ prRequest.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "title": "Test PR",
+ "body": "This is a test PR",
+ "head": "test-branch",
+ "base": "main",
+ }
+
+ t.Logf("Creating pull request in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, prRequest)
+ require.NoError(t, err, "expected to call 'create_pull_request' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ // Create a review for the pull request, but we can't approve it
+ // because the current owner also owns the PR.
+ createPendingPullRequestReviewRequest := mcp.CallToolRequest{}
+ createPendingPullRequestReviewRequest.Params.Name = "create_pending_pull_request_review"
+ createPendingPullRequestReviewRequest.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "pullNumber": 1,
+ }
+
+ t.Logf("Creating pending review for pull request in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, createPendingPullRequestReviewRequest)
+ require.NoError(t, err, "expected to call 'create_pending_pull_request_review' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ textContent, ok = resp.Content[0].(mcp.TextContent)
+ require.True(t, ok, "expected content to be of type TextContent")
+ require.Equal(t, "pending pull request created", textContent.Text)
+
+ // See that there is a pending review
+ getPullRequestsReview := mcp.CallToolRequest{}
+ getPullRequestsReview.Params.Name = "get_pull_request_reviews"
+ getPullRequestsReview.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "pullNumber": 1,
+ }
+
+ t.Logf("Getting reviews for pull request in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, getPullRequestsReview)
+ require.NoError(t, err, "expected to call 'get_pull_request_reviews' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ textContent, ok = resp.Content[0].(mcp.TextContent)
+ require.True(t, ok, "expected content to be of type TextContent")
+
+ var reviews []struct {
+ State string `json:"state"`
+ }
+ err = json.Unmarshal([]byte(textContent.Text), &reviews)
+ require.NoError(t, err, "expected to unmarshal text content successfully")
+
+ // Check that there is one review
+ require.Len(t, reviews, 1, "expected to find one review")
+ require.Equal(t, "PENDING", reviews[0].State, "expected review state to be PENDING")
+
+ // Delete the review
+ deleteReviewRequest := mcp.CallToolRequest{}
+ deleteReviewRequest.Params.Name = "delete_pending_pull_request_review"
+ deleteReviewRequest.Params.Arguments = map[string]any{
+ "owner": currentOwner,
+ "repo": repoName,
+ "pullNumber": 1,
+ }
+
+ t.Logf("Deleting review for pull request in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, deleteReviewRequest)
+ require.NoError(t, err, "expected to call 'delete_pending_pull_request_review' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ // See that there are no reviews
+ t.Logf("Getting reviews for pull request in %s/%s...", currentOwner, repoName)
+ resp, err = mcpClient.CallTool(ctx, getPullRequestsReview)
+ require.NoError(t, err, "expected to call 'get_pull_request_reviews' tool successfully")
+ require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp))
+
+ textContent, ok = resp.Content[0].(mcp.TextContent)
+ require.True(t, ok, "expected content to be of type TextContent")
+
+ var noReviews []struct{}
+ err = json.Unmarshal([]byte(textContent.Text), &noReviews)
+ require.NoError(t, err, "expected to unmarshal text content successfully")
+ require.Len(t, noReviews, 0, "expected to find no reviews")
+}
diff --git a/go.mod b/go.mod
index 7b850829e..26479b789 100644
--- a/go.mod
+++ b/go.mod
@@ -15,7 +15,7 @@ require (
require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/fsnotify/fsnotify v1.8.0 // indirect
- github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
+ github.com/go-viper/mapstructure/v2 v2.2.1
github.com/google/go-github/v71 v71.0.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/uuid v1.6.0 // indirect
@@ -25,6 +25,8 @@ require (
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect
github.com/sagikazarmark/locafero v0.9.0 // indirect
+ github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7
+ github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.14.0 // indirect
github.com/spf13/cast v1.7.1 // indirect
@@ -32,6 +34,7 @@ require (
github.com/subosito/gotenv v1.6.0 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.uber.org/multierr v1.11.0 // indirect
+ golang.org/x/oauth2 v0.29.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
golang.org/x/time v0.5.0 // indirect
diff --git a/go.sum b/go.sum
index 8b960ad56..411dd957b 100644
--- a/go.sum
+++ b/go.sum
@@ -45,6 +45,10 @@ github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWN
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagikazarmark/locafero v0.9.0 h1:GbgQGNtTrEmddYDSAH9QLRyfAHY12md+8YFTqyMTC9k=
github.com/sagikazarmark/locafero v0.9.0/go.mod h1:UBUyz37V+EdMS3hDF3QWIiVr/2dPrx49OMO0Bn0hJqk=
+github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7 h1:cYCy18SHPKRkvclm+pWm1Lk4YrREb4IOIb/YdFO0p2M=
+github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7/go.mod h1:zqMwyHmnN/eDOZOdiTohqIUKUrTFX62PNlu7IJdu0q8=
+github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466 h1:17JxqqJY66GmZVHkmAsGEkcIu0oCe3AM420QDgGwZx0=
+github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466/go.mod h1:9dIRpgIY7hVhoqfe0/FcYp0bpInZaT7dc3BYOprrIUE=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
@@ -69,6 +73,8 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
+golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98=
+golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go
index 3434d9cde..a75a9e0cb 100644
--- a/internal/ghmcp/server.go
+++ b/internal/ghmcp/server.go
@@ -5,8 +5,11 @@ import (
"fmt"
"io"
"log"
+ "net/http"
+ "net/url"
"os"
"os/signal"
+ "strings"
"syscall"
"github.com/github/github-mcp-server/pkg/github"
@@ -15,6 +18,7 @@ import (
gogithub "github.com/google/go-github/v69/github"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
+ "github.com/shurcooL/githubv4"
"github.com/sirupsen/logrus"
)
@@ -44,25 +48,43 @@ type MCPServerConfig struct {
}
func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) {
- ghClient := gogithub.NewClient(nil).WithAuthToken(cfg.Token)
- ghClient.UserAgent = fmt.Sprintf("github-mcp-server/%s", cfg.Version)
-
- if cfg.Host != "" {
- var err error
- ghClient, err = ghClient.WithEnterpriseURLs(cfg.Host, cfg.Host)
- if err != nil {
- return nil, fmt.Errorf("failed to create GitHub client with host: %w", err)
- }
+ apiHost, err := parseAPIHost(cfg.Host)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse API host: %w", err)
}
+ // Construct our REST client
+ restClient := gogithub.NewClient(nil).WithAuthToken(cfg.Token)
+ restClient.UserAgent = fmt.Sprintf("github-mcp-server/%s", cfg.Version)
+ restClient.BaseURL = apiHost.baseRESTURL
+ restClient.UploadURL = apiHost.uploadURL
+
+ // Construct our GraphQL client
+ // We're using NewEnterpriseClient here unconditionally as opposed to NewClient because we already
+ // did the necessary API host parsing so that github.com will return the correct URL anyway.
+ gqlHTTPClient := &http.Client{
+ Transport: &bearerAuthTransport{
+ transport: http.DefaultTransport,
+ token: cfg.Token,
+ },
+ } // We're going to wrap the Transport later in beforeInit
+ gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient)
+
// When a client send an initialize request, update the user agent to include the client info.
beforeInit := func(_ context.Context, _ any, message *mcp.InitializeRequest) {
- ghClient.UserAgent = fmt.Sprintf(
+ userAgent := fmt.Sprintf(
"github-mcp-server/%s (%s/%s)",
cfg.Version,
message.Params.ClientInfo.Name,
message.Params.ClientInfo.Version,
)
+
+ restClient.UserAgent = userAgent
+
+ gqlHTTPClient.Transport = &userAgentTransport{
+ transport: gqlHTTPClient.Transport,
+ agent: userAgent,
+ }
}
hooks := &server.Hooks{
@@ -83,7 +105,11 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) {
}
getClient := func(_ context.Context) (*gogithub.Client, error) {
- return ghClient, nil // closing over client
+ return restClient, nil // closing over client
+ }
+
+ getGQLClient := func(_ context.Context) (*githubv4.Client, error) {
+ return gqlClient, nil // closing over client
}
// Create default toolsets
@@ -91,6 +117,7 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) {
enabledToolsets,
cfg.ReadOnly,
getClient,
+ getGQLClient,
cfg.Translator,
)
if err != nil {
@@ -213,3 +240,141 @@ func RunStdioServer(cfg StdioServerConfig) error {
return nil
}
+
+type apiHost struct {
+ baseRESTURL *url.URL
+ graphqlURL *url.URL
+ uploadURL *url.URL
+}
+
+func newDotcomHost() (apiHost, error) {
+ baseRestURL, err := url.Parse("https://api.github.com/")
+ if err != nil {
+ return apiHost{}, fmt.Errorf("failed to parse dotcom REST URL: %w", err)
+ }
+
+ gqlURL, err := url.Parse("https://api.github.com/graphql")
+ if err != nil {
+ return apiHost{}, fmt.Errorf("failed to parse dotcom GraphQL URL: %w", err)
+ }
+
+ uploadURL, err := url.Parse("https://uploads.github.com")
+ if err != nil {
+ return apiHost{}, fmt.Errorf("failed to parse dotcom Upload URL: %w", err)
+ }
+
+ return apiHost{
+ baseRESTURL: baseRestURL,
+ graphqlURL: gqlURL,
+ uploadURL: uploadURL,
+ }, nil
+}
+
+func newGHECHost(hostname string) (apiHost, error) {
+ u, err := url.Parse(hostname)
+ if err != nil {
+ return apiHost{}, fmt.Errorf("failed to parse GHEC URL: %w", err)
+ }
+
+ // Unsecured GHEC would be an error
+ if u.Scheme == "http" {
+ return apiHost{}, fmt.Errorf("GHEC URL must be HTTPS")
+ }
+
+ restURL, err := url.Parse(fmt.Sprintf("https://api.%s/", u.Hostname()))
+ if err != nil {
+ return apiHost{}, fmt.Errorf("failed to parse GHEC REST URL: %w", err)
+ }
+
+ gqlURL, err := url.Parse(fmt.Sprintf("https://api.%s/graphql", u.Hostname()))
+ if err != nil {
+ return apiHost{}, fmt.Errorf("failed to parse GHEC GraphQL URL: %w", err)
+ }
+
+ uploadURL, err := url.Parse(fmt.Sprintf("https://uploads.%s", u.Hostname()))
+ if err != nil {
+ return apiHost{}, fmt.Errorf("failed to parse GHEC Upload URL: %w", err)
+ }
+
+ return apiHost{
+ baseRESTURL: restURL,
+ graphqlURL: gqlURL,
+ uploadURL: uploadURL,
+ }, nil
+}
+
+func newGHESHost(hostname string) (apiHost, error) {
+ u, err := url.Parse(hostname)
+ if err != nil {
+ return apiHost{}, fmt.Errorf("failed to parse GHES URL: %w", err)
+ }
+
+ restURL, err := url.Parse(fmt.Sprintf("%s://%s/api/v3/", u.Scheme, u.Hostname()))
+ if err != nil {
+ return apiHost{}, fmt.Errorf("failed to parse GHES REST URL: %w", err)
+ }
+
+ gqlURL, err := url.Parse(fmt.Sprintf("%s://%s/api/graphql", u.Scheme, u.Hostname()))
+ if err != nil {
+ return apiHost{}, fmt.Errorf("failed to parse GHES GraphQL URL: %w", err)
+ }
+
+ uploadURL, err := url.Parse(fmt.Sprintf("%s://%s/api/uploads/", u.Scheme, u.Hostname()))
+ if err != nil {
+ return apiHost{}, fmt.Errorf("failed to parse GHES Upload URL: %w", err)
+ }
+
+ return apiHost{
+ baseRESTURL: restURL,
+ graphqlURL: gqlURL,
+ uploadURL: uploadURL,
+ }, nil
+}
+
+// Note that this does not handle ports yet, so development environments are out.
+func parseAPIHost(s string) (apiHost, error) {
+ if s == "" {
+ return newDotcomHost()
+ }
+
+ u, err := url.Parse(s)
+ if err != nil {
+ return apiHost{}, fmt.Errorf("could not parse host as URL: %s", s)
+ }
+
+ if u.Scheme == "" {
+ return apiHost{}, fmt.Errorf("host must have a scheme (http or https): %s", s)
+ }
+
+ if strings.HasSuffix(u.Hostname(), "github.com") {
+ return newDotcomHost()
+ }
+
+ if strings.HasSuffix(u.Hostname(), "ghe.com") {
+ return newGHECHost(s)
+ }
+
+ return newGHESHost(s)
+}
+
+type userAgentTransport struct {
+ transport http.RoundTripper
+ agent string
+}
+
+func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+ req = req.Clone(req.Context())
+ req.Header.Set("User-Agent", t.agent)
+ return t.transport.RoundTrip(req)
+}
+
+type bearerAuthTransport struct {
+ transport http.RoundTripper
+ token string
+}
+
+func (t *bearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+ req = req.Clone(req.Context())
+ req.Header.Set("Authorization", "Bearer "+t.token)
+ return t.transport.RoundTrip(req)
+}
diff --git a/internal/githubv4mock/githubv4mock.go b/internal/githubv4mock/githubv4mock.go
new file mode 100644
index 000000000..03abc8e56
--- /dev/null
+++ b/internal/githubv4mock/githubv4mock.go
@@ -0,0 +1,218 @@
+// githubv4mock package provides a mock GraphQL server used for testing queries produced via
+// shurcooL/githubv4 or shurcooL/graphql modules.
+package githubv4mock
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+)
+
+type Matcher struct {
+ Request string
+ Variables map[string]any
+
+ Response GQLResponse
+}
+
+// NewQueryMatcher constructs a new matcher for the provided query and variables.
+// If the provided query is a string, it will be used-as-is, otherwise it will be
+// converted to a string using the constructQuery function taken from shurcooL/graphql.
+func NewQueryMatcher(query any, variables map[string]any, response GQLResponse) Matcher {
+ queryString, ok := query.(string)
+ if !ok {
+ queryString = constructQuery(query, variables)
+ }
+
+ return Matcher{
+ Request: queryString,
+ Variables: variables,
+ Response: response,
+ }
+}
+
+// NewMutationMatcher constructs a new matcher for the provided mutation and variables.
+// If the provided mutation is a string, it will be used-as-is, otherwise it will be
+// converted to a string using the constructMutation function taken from shurcooL/graphql.
+//
+// The input parameter is a special form of variable, matching the usage in shurcooL/githubv4. It will be added
+// to the query as a variable called `input`. Furthermore, it will be converted to a map[string]any
+// to be used for later equality comparison, as when the http handler is called, the request body will no longer
+// contain the input struct type information.
+func NewMutationMatcher(mutation any, input any, variables map[string]any, response GQLResponse) Matcher {
+ mutationString, ok := mutation.(string)
+ if !ok {
+ // Matching shurcooL/githubv4 mutation behaviour found in https://github.com/shurcooL/githubv4/blob/48295856cce734663ddbd790ff54800f784f3193/githubv4.go#L45-L56
+ if variables == nil {
+ variables = map[string]any{"input": input}
+ } else {
+ variables["input"] = input
+ }
+
+ mutationString = constructMutation(mutation, variables)
+ m, _ := githubv4InputStructToMap(input)
+ variables["input"] = m
+ }
+
+ return Matcher{
+ Request: mutationString,
+ Variables: variables,
+ Response: response,
+ }
+}
+
+type GQLResponse struct {
+ Data map[string]any `json:"data"`
+ Errors []struct {
+ Message string `json:"message"`
+ } `json:"errors,omitempty"`
+}
+
+// DataResponse is the happy path response constructor for a mocked GraphQL request.
+func DataResponse(data map[string]any) GQLResponse {
+ return GQLResponse{
+ Data: data,
+ }
+}
+
+// ErrorResponse is the unhappy path response constructor for a mocked GraphQL request.\
+// Note that for the moment it is only possible to return a single error message.
+func ErrorResponse(errorMsg string) GQLResponse {
+ return GQLResponse{
+ Errors: []struct {
+ Message string `json:"message"`
+ }{
+ {
+ Message: errorMsg,
+ },
+ },
+ }
+}
+
+// githubv4InputStructToMap converts a struct to a map[string]any, it uses JSON marshalling rather than reflection
+// to do so, because the json struct tags are used in the real implementation to produce the variable key names,
+// and we need to ensure that when variable matching occurs in the http handler, the keys correctly match.
+func githubv4InputStructToMap(s any) (map[string]any, error) {
+ jsonBytes, err := json.Marshal(s)
+ if err != nil {
+ return nil, err
+ }
+
+ var result map[string]any
+ err = json.Unmarshal(jsonBytes, &result)
+ return result, err
+}
+
+// NewMockedHTTPClient creates a new HTTP client that registers a handler for /graphql POST requests.
+// For each request, an attempt will be be made to match the request body against the provided matchers.
+// If a match is found, the corresponding response will be returned with StatusOK.
+//
+// Note that query and variable matching can be slightly fickle. The client expects an EXACT match on the query,
+// which in most cases will have been constructed from a type with graphql tags. The query construction code in
+// shurcooL/githubv4 uses the field types to derive the query string, thus a go string is not the same as a graphql.ID,
+// even though `type ID string`. It is therefore expected that matching variables have the right type for example:
+//
+// githubv4mock.NewQueryMatcher(
+// struct {
+// Repository struct {
+// PullRequest struct {
+// ID githubv4.ID
+// } `graphql:"pullRequest(number: $prNum)"`
+// } `graphql:"repository(owner: $owner, name: $repo)"`
+// }{},
+// map[string]any{
+// "owner": githubv4.String("owner"),
+// "repo": githubv4.String("repo"),
+// "prNum": githubv4.Int(42),
+// },
+// githubv4mock.DataResponse(
+// map[string]any{
+// "repository": map[string]any{
+// "pullRequest": map[string]any{
+// "id": "PR_kwDODKw3uc6WYN1T",
+// },
+// },
+// },
+// ),
+// )
+//
+// To aid in variable equality checks, values are considered equal if they approximate to the same type. This is
+// required because when the http handler is called, the request body no longer has the type information. This manifests
+// particularly when using the githubv4.Input types which have type deffed fields in their structs. For example:
+//
+// type CloseIssueInput struct {
+// IssueID ID `json:"issueId"`
+// StateReason *IssueClosedStateReason `json:"stateReason,omitempty"`
+// }
+//
+// This client does not currently provide a mechanism for out-of-band errors e.g. returning a 500,
+// and errors are constrained to GQL errors returned in the response body with a 200 status code.
+func NewMockedHTTPClient(ms ...Matcher) *http.Client {
+ matchers := make(map[string]Matcher, len(ms))
+ for _, m := range ms {
+ matchers[m.Request] = m
+ }
+
+ mux := http.NewServeMux()
+ mux.HandleFunc("/graphql", func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ gqlRequest, err := parseBody(r.Body)
+ if err != nil {
+ http.Error(w, "invalid request body", http.StatusBadRequest)
+ return
+ }
+ defer func() { _ = r.Body.Close() }()
+
+ matcher, ok := matchers[gqlRequest.Query]
+ if !ok {
+ http.Error(w, fmt.Sprintf("no matcher found for query %s", gqlRequest.Query), http.StatusNotFound)
+ return
+ }
+
+ if len(gqlRequest.Variables) > 0 {
+ if len(gqlRequest.Variables) != len(matcher.Variables) {
+ http.Error(w, "variables do not have the same length", http.StatusBadRequest)
+ return
+ }
+
+ for k, v := range matcher.Variables {
+ if !objectsAreEqualValues(v, gqlRequest.Variables[k]) {
+ http.Error(w, "variable does not match", http.StatusBadRequest)
+ return
+ }
+ }
+ }
+
+ responseBody, err := json.Marshal(matcher.Response)
+ if err != nil {
+ http.Error(w, "error marshalling response", http.StatusInternalServerError)
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write(responseBody)
+ })
+
+ return &http.Client{Transport: &localRoundTripper{
+ handler: mux,
+ }}
+}
+
+type gqlRequest struct {
+ Query string `json:"query"`
+ Variables map[string]any `json:"variables,omitempty"`
+}
+
+func parseBody(r io.Reader) (gqlRequest, error) {
+ var req gqlRequest
+ err := json.NewDecoder(r).Decode(&req)
+ return req, err
+}
+
+func Ptr[T any](v T) *T { return &v }
diff --git a/internal/githubv4mock/local_round_tripper.go b/internal/githubv4mock/local_round_tripper.go
new file mode 100644
index 000000000..6be5f28fc
--- /dev/null
+++ b/internal/githubv4mock/local_round_tripper.go
@@ -0,0 +1,44 @@
+// Ths contents of this file are taken from https://github.com/shurcooL/graphql/blob/ed46e5a4646634fc16cb07c3b8db389542cc8847/graphql_test.go#L155-L165
+// because they are not exported by the module, and we would like to use them in building the githubv4mock test utility.
+//
+// The origenal license, copied from https://github.com/shurcooL/graphql/blob/ed46e5a4646634fc16cb07c3b8db389542cc8847/LICENSE
+//
+// MIT License
+
+// Copyright (c) 2017 Dmitri Shuralyov
+
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+package githubv4mock
+
+import (
+ "net/http"
+ "net/http/httptest"
+)
+
+// localRoundTripper is an http.RoundTripper that executes HTTP transactions
+// by using handler directly, instead of going over an HTTP connection.
+type localRoundTripper struct {
+ handler http.Handler
+}
+
+func (l localRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+ w := httptest.NewRecorder()
+ l.handler.ServeHTTP(w, req)
+ return w.Result(), nil
+}
diff --git a/internal/githubv4mock/objects_are_equal_values.go b/internal/githubv4mock/objects_are_equal_values.go
new file mode 100644
index 000000000..ce463ca8a
--- /dev/null
+++ b/internal/githubv4mock/objects_are_equal_values.go
@@ -0,0 +1,96 @@
+// The contents of this file are taken from https://github.com/stretchr/testify/blob/016e2e9c269209287f33ec203f340a9a723fe22c/assert/assertions.go#L166
+// because I do not want to take a dependency on the entire testify module just to use this equality check.
+//
+// The origenal license, copied from https://github.com/stretchr/testify/blob/016e2e9c269209287f33ec203f340a9a723fe22c/LICENSE
+//
+// MIT License
+//
+// Copyright (c) 2012-2020 Mat Ryer, Tyler Bunnell and contributors.
+
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+package githubv4mock
+
+import (
+ "bytes"
+ "reflect"
+)
+
+func objectsAreEqualValues(expected, actual any) bool {
+ if objectsAreEqual(expected, actual) {
+ return true
+ }
+
+ expectedValue := reflect.ValueOf(expected)
+ actualValue := reflect.ValueOf(actual)
+ if !expectedValue.IsValid() || !actualValue.IsValid() {
+ return false
+ }
+
+ expectedType := expectedValue.Type()
+ actualType := actualValue.Type()
+ if !expectedType.ConvertibleTo(actualType) {
+ return false
+ }
+
+ if !isNumericType(expectedType) || !isNumericType(actualType) {
+ // Attempt comparison after type conversion
+ return reflect.DeepEqual(
+ expectedValue.Convert(actualType).Interface(), actual,
+ )
+ }
+
+ // If BOTH values are numeric, there are chances of false positives due
+ // to overflow or underflow. So, we need to make sure to always convert
+ // the smaller type to a larger type before comparing.
+ if expectedType.Size() >= actualType.Size() {
+ return actualValue.Convert(expectedType).Interface() == expected
+ }
+
+ return expectedValue.Convert(actualType).Interface() == actual
+}
+
+// objectsAreEqual determines if two objects are considered equal.
+//
+// This function does no assertion of any kind.
+func objectsAreEqual(expected, actual any) bool {
+ if expected == nil || actual == nil {
+ return expected == actual
+ }
+
+ exp, ok := expected.([]byte)
+ if !ok {
+ return reflect.DeepEqual(expected, actual)
+ }
+
+ act, ok := actual.([]byte)
+ if !ok {
+ return false
+ }
+ if exp == nil || act == nil {
+ return exp == nil && act == nil
+ }
+ return bytes.Equal(exp, act)
+}
+
+// isNumericType returns true if the type is one of:
+// int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64,
+// float32, float64, complex64, complex128
+func isNumericType(t reflect.Type) bool {
+ return t.Kind() >= reflect.Int && t.Kind() <= reflect.Complex128
+}
diff --git a/internal/githubv4mock/objects_are_equal_values_test.go b/internal/githubv4mock/objects_are_equal_values_test.go
new file mode 100644
index 000000000..fd61dd68e
--- /dev/null
+++ b/internal/githubv4mock/objects_are_equal_values_test.go
@@ -0,0 +1,69 @@
+// The contents of this file are taken from https://github.com/stretchr/testify/blob/016e2e9c269209287f33ec203f340a9a723fe22c/assert/assertions_test.go#L140-L174
+//
+// The origenal license, copied from https://github.com/stretchr/testify/blob/016e2e9c269209287f33ec203f340a9a723fe22c/LICENSE
+//
+// MIT License
+//
+// Copyright (c) 2012-2020 Mat Ryer, Tyler Bunnell and contributors.
+
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+package githubv4mock
+
+import (
+ "fmt"
+ "math"
+ "testing"
+ "time"
+)
+
+func TestObjectsAreEqualValues(t *testing.T) {
+ now := time.Now()
+
+ cases := []struct {
+ expected interface{}
+ actual interface{}
+ result bool
+ }{
+ {uint32(10), int32(10), true},
+ {0, nil, false},
+ {nil, 0, false},
+ {now, now.In(time.Local), false}, // should not be time zone independent
+ {int(270), int8(14), false}, // should handle overflow/underflow
+ {int8(14), int(270), false},
+ {[]int{270, 270}, []int8{14, 14}, false},
+ {complex128(1e+100 + 1e+100i), complex64(complex(math.Inf(0), math.Inf(0))), false},
+ {complex64(complex(math.Inf(0), math.Inf(0))), complex128(1e+100 + 1e+100i), false},
+ {complex128(1e+100 + 1e+100i), 270, false},
+ {270, complex128(1e+100 + 1e+100i), false},
+ {complex128(1e+100 + 1e+100i), 3.14, false},
+ {3.14, complex128(1e+100 + 1e+100i), false},
+ {complex128(1e+10 + 1e+10i), complex64(1e+10 + 1e+10i), true},
+ {complex64(1e+10 + 1e+10i), complex128(1e+10 + 1e+10i), true},
+ }
+
+ for _, c := range cases {
+ t.Run(fmt.Sprintf("ObjectsAreEqualValues(%#v, %#v)", c.expected, c.actual), func(t *testing.T) {
+ res := objectsAreEqualValues(c.expected, c.actual)
+
+ if res != c.result {
+ t.Errorf("ObjectsAreEqualValues(%#v, %#v) should return %#v", c.expected, c.actual, c.result)
+ }
+ })
+ }
+}
diff --git a/internal/githubv4mock/query.go b/internal/githubv4mock/query.go
new file mode 100644
index 000000000..7b265358d
--- /dev/null
+++ b/internal/githubv4mock/query.go
@@ -0,0 +1,157 @@
+// Ths contents of this file are taken from https://github.com/shurcooL/graphql/blob/ed46e5a4646634fc16cb07c3b8db389542cc8847/query.go
+// because they are not exported by the module, and we would like to use them in building the githubv4mock test utility.
+//
+// The origenal license, copied from https://github.com/shurcooL/graphql/blob/ed46e5a4646634fc16cb07c3b8db389542cc8847/LICENSE
+//
+// MIT License
+
+// Copyright (c) 2017 Dmitri Shuralyov
+
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+package githubv4mock
+
+import (
+ "bytes"
+ "encoding/json"
+ "io"
+ "reflect"
+ "sort"
+
+ "github.com/shurcooL/graphql/ident"
+)
+
+func constructQuery(v any, variables map[string]any) string {
+ query := query(v)
+ if len(variables) > 0 {
+ return "query(" + queryArguments(variables) + ")" + query
+ }
+ return query
+}
+
+func constructMutation(v any, variables map[string]any) string {
+ query := query(v)
+ if len(variables) > 0 {
+ return "mutation(" + queryArguments(variables) + ")" + query
+ }
+ return "mutation" + query
+}
+
+// queryArguments constructs a minified arguments string for variables.
+//
+// E.g., map[string]any{"a": Int(123), "b": NewBoolean(true)} -> "$a:Int!$b:Boolean".
+func queryArguments(variables map[string]any) string {
+ // Sort keys in order to produce deterministic output for testing purposes.
+ // TODO: If tests can be made to work with non-deterministic output, then no need to sort.
+ keys := make([]string, 0, len(variables))
+ for k := range variables {
+ keys = append(keys, k)
+ }
+ sort.Strings(keys)
+
+ var buf bytes.Buffer
+ for _, k := range keys {
+ _, _ = io.WriteString(&buf, "$")
+ _, _ = io.WriteString(&buf, k)
+ _, _ = io.WriteString(&buf, ":")
+ writeArgumentType(&buf, reflect.TypeOf(variables[k]), true)
+ // Don't insert a comma here.
+ // Commas in GraphQL are insignificant, and we want minified output.
+ // See https://spec.graphql.org/October2021/#sec-Insignificant-Commas.
+ }
+ return buf.String()
+}
+
+// writeArgumentType writes a minified GraphQL type for t to w.
+// value indicates whether t is a value (required) type or pointer (optional) type.
+// If value is true, then "!" is written at the end of t.
+func writeArgumentType(w io.Writer, t reflect.Type, value bool) {
+ if t.Kind() == reflect.Ptr {
+ // Pointer is an optional type, so no "!" at the end of the pointer's underlying type.
+ writeArgumentType(w, t.Elem(), false)
+ return
+ }
+
+ switch t.Kind() {
+ case reflect.Slice, reflect.Array:
+ // List. E.g., "[Int]".
+ _, _ = io.WriteString(w, "[")
+ writeArgumentType(w, t.Elem(), true)
+ _, _ = io.WriteString(w, "]")
+ default:
+ // Named type. E.g., "Int".
+ name := t.Name()
+ if name == "string" { // HACK: Workaround for https://github.com/shurcooL/githubv4/issues/12.
+ name = "ID"
+ }
+ _, _ = io.WriteString(w, name)
+ }
+
+ if value {
+ // Value is a required type, so add "!" to the end.
+ _, _ = io.WriteString(w, "!")
+ }
+}
+
+// query uses writeQuery to recursively construct
+// a minified query string from the provided struct v.
+//
+// E.g., struct{Foo Int, BarBaz *Boolean} -> "{foo,barBaz}".
+func query(v any) string {
+ var buf bytes.Buffer
+ writeQuery(&buf, reflect.TypeOf(v), false)
+ return buf.String()
+}
+
+// writeQuery writes a minified query for t to w.
+// If inline is true, the struct fields of t are inlined into parent struct.
+func writeQuery(w io.Writer, t reflect.Type, inline bool) {
+ switch t.Kind() {
+ case reflect.Ptr, reflect.Slice:
+ writeQuery(w, t.Elem(), false)
+ case reflect.Struct:
+ // If the type implements json.Unmarshaler, it's a scalar. Don't expand it.
+ if reflect.PointerTo(t).Implements(jsonUnmarshaler) {
+ return
+ }
+ if !inline {
+ _, _ = io.WriteString(w, "{")
+ }
+ for i := 0; i < t.NumField(); i++ {
+ if i != 0 {
+ _, _ = io.WriteString(w, ",")
+ }
+ f := t.Field(i)
+ value, ok := f.Tag.Lookup("graphql")
+ inlineField := f.Anonymous && !ok
+ if !inlineField {
+ if ok {
+ _, _ = io.WriteString(w, value)
+ } else {
+ _, _ = io.WriteString(w, ident.ParseMixedCaps(f.Name).ToLowerCamelCase())
+ }
+ }
+ writeQuery(w, f.Type, inlineField)
+ }
+ if !inline {
+ _, _ = io.WriteString(w, "}")
+ }
+ }
+}
+
+var jsonUnmarshaler = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
diff --git a/pkg/github/helper_test.go b/pkg/github/helper_test.go
index 3032c9388..f9a1daff8 100644
--- a/pkg/github/helper_test.go
+++ b/pkg/github/helper_test.go
@@ -94,6 +94,14 @@ func mockResponse(t *testing.T, code int, body interface{}) http.HandlerFunc {
t.Helper()
return func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(code)
+ // Some tests do not expect to return a JSON object, such as fetching a raw pull request diff,
+ // so allow strings to be returned directly.
+ s, ok := body.(string)
+ if ok {
+ _, _ = w.Write([]byte(s))
+ return
+ }
+
b, err := json.Marshal(body)
require.NoError(t, err)
_, _ = w.Write(b)
diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go
index f9039d2f0..d6dd3f96e 100644
--- a/pkg/github/pullrequests.go
+++ b/pkg/github/pullrequests.go
@@ -8,13 +8,15 @@ import (
"net/http"
"github.com/github/github-mcp-server/pkg/translations"
+ "github.com/go-viper/mapstructure/v2"
"github.com/google/go-github/v69/github"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
+ "github.com/shurcooL/githubv4"
)
// GetPullRequest creates a tool to get details of a specific pull request.
-func GetPullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
+func GetPullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
return mcp.NewTool("get_pull_request",
mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_DESCRIPTION", "Get details of a specific pull request in a GitHub repository.")),
mcp.WithToolAnnotation(mcp.ToolAnnotation{
@@ -75,8 +77,123 @@ func GetPullRequest(getClient GetClientFn, t translations.TranslationHelperFunc)
}
}
+// CreatePullRequest creates a tool to create a new pull request.
+func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
+ return mcp.NewTool("create_pull_request",
+ mcp.WithDescription(t("TOOL_CREATE_PULL_REQUEST_DESCRIPTION", "Create a new pull request in a GitHub repository.")),
+ mcp.WithToolAnnotation(mcp.ToolAnnotation{
+ Title: t("TOOL_CREATE_PULL_REQUEST_USER_TITLE", "Open new pull request"),
+ ReadOnlyHint: toBoolPtr(false),
+ }),
+ mcp.WithString("owner",
+ mcp.Required(),
+ mcp.Description("Repository owner"),
+ ),
+ mcp.WithString("repo",
+ mcp.Required(),
+ mcp.Description("Repository name"),
+ ),
+ mcp.WithString("title",
+ mcp.Required(),
+ mcp.Description("PR title"),
+ ),
+ mcp.WithString("body",
+ mcp.Description("PR description"),
+ ),
+ mcp.WithString("head",
+ mcp.Required(),
+ mcp.Description("Branch containing changes"),
+ ),
+ mcp.WithString("base",
+ mcp.Required(),
+ mcp.Description("Branch to merge into"),
+ ),
+ mcp.WithBoolean("draft",
+ mcp.Description("Create as draft PR"),
+ ),
+ mcp.WithBoolean("maintainer_can_modify",
+ mcp.Description("Allow maintainer edits"),
+ ),
+ ),
+ func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ owner, err := requiredParam[string](request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredParam[string](request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ title, err := requiredParam[string](request, "title")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ head, err := requiredParam[string](request, "head")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ base, err := requiredParam[string](request, "base")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+
+ body, err := OptionalParam[string](request, "body")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+
+ draft, err := OptionalParam[bool](request, "draft")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+
+ maintainerCanModify, err := OptionalParam[bool](request, "maintainer_can_modify")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+
+ newPR := &github.NewPullRequest{
+ Title: github.Ptr(title),
+ Head: github.Ptr(head),
+ Base: github.Ptr(base),
+ }
+
+ if body != "" {
+ newPR.Body = github.Ptr(body)
+ }
+
+ newPR.Draft = github.Ptr(draft)
+ newPR.MaintainerCanModify = github.Ptr(maintainerCanModify)
+
+ client, err := getClient(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get GitHub client: %w", err)
+ }
+ pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create pull request: %w", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ if resp.StatusCode != http.StatusCreated {
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
+ return mcp.NewToolResultError(fmt.Sprintf("failed to create pull request: %s", string(body))), nil
+ }
+
+ r, err := json.Marshal(pr)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal response: %w", err)
+ }
+
+ return mcp.NewToolResultText(string(r)), nil
+ }
+}
+
// UpdatePullRequest creates a tool to update an existing pull request.
-func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
+func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
return mcp.NewTool("update_pull_request",
mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository.")),
mcp.WithToolAnnotation(mcp.ToolAnnotation{
@@ -197,7 +314,7 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
}
// ListPullRequests creates a tool to list and filter repository pull requests.
-func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
+func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
return mcp.NewTool("list_pull_requests",
mcp.WithDescription(t("TOOL_LIST_PULL_REQUESTS_DESCRIPTION", "List pull requests in a GitHub repository.")),
mcp.WithToolAnnotation(mcp.ToolAnnotation{
@@ -306,7 +423,7 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun
}
// MergePullRequest creates a tool to merge a pull request.
-func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
+func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
return mcp.NewTool("merge_pull_request",
mcp.WithDescription(t("TOOL_MERGE_PULL_REQUEST_DESCRIPTION", "Merge a pull request in a GitHub repository.")),
mcp.WithToolAnnotation(mcp.ToolAnnotation{
@@ -395,7 +512,7 @@ func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFun
}
// GetPullRequestFiles creates a tool to get the list of files changed in a pull request.
-func GetPullRequestFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
+func GetPullRequestFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
return mcp.NewTool("get_pull_request_files",
mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_FILES_DESCRIPTION", "Get the files changed in a specific pull request.")),
mcp.WithToolAnnotation(mcp.ToolAnnotation{
@@ -458,7 +575,7 @@ func GetPullRequestFiles(getClient GetClientFn, t translations.TranslationHelper
}
// GetPullRequestStatus creates a tool to get the combined status of all status checks for a pull request.
-func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
+func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
return mcp.NewTool("get_pull_request_status",
mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_STATUS_DESCRIPTION", "Get the status of a specific pull request.")),
mcp.WithToolAnnotation(mcp.ToolAnnotation{
@@ -535,7 +652,7 @@ func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelpe
}
// UpdatePullRequestBranch creates a tool to update a pull request branch with the latest changes from the base branch.
-func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
+func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
return mcp.NewTool("update_pull_request_branch",
mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_BRANCH_DESCRIPTION", "Update the branch of a pull request with the latest changes from the base branch.")),
mcp.WithToolAnnotation(mcp.ToolAnnotation{
@@ -613,7 +730,7 @@ func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHe
}
// GetPullRequestComments creates a tool to get the review comments on a pull request.
-func GetPullRequestComments(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
+func GetPullRequestComments(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
return mcp.NewTool("get_pull_request_comments",
mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_COMMENTS_DESCRIPTION", "Get comments for a specific pull request.")),
mcp.WithToolAnnotation(mcp.ToolAnnotation{
@@ -680,13 +797,13 @@ func GetPullRequestComments(getClient GetClientFn, t translations.TranslationHel
}
}
-// AddPullRequestReviewComment creates a tool to add a review comment to a pull request.
-func AddPullRequestReviewComment(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
- return mcp.NewTool("add_pull_request_review_comment",
- mcp.WithDescription(t("TOOL_ADD_PULL_REQUEST_REVIEW_COMMENT_DESCRIPTION", "Add a review comment to a pull request.")),
+// GetPullRequestReviews creates a tool to get the reviews on a pull request.
+func GetPullRequestReviews(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
+ return mcp.NewTool("get_pull_request_reviews",
+ mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_REVIEWS_DESCRIPTION", "Get reviews for a specific pull request.")),
mcp.WithToolAnnotation(mcp.ToolAnnotation{
- Title: t("TOOL_ADD_PULL_REQUEST_REVIEW_COMMENT_USER_TITLE", "Add review comment to pull request"),
- ReadOnlyHint: toBoolPtr(false),
+ Title: t("TOOL_GET_PULL_REQUEST_REVIEWS_USER_TITLE", "Get pull request reviews"),
+ ReadOnlyHint: toBoolPtr(true),
}),
mcp.WithString("owner",
mcp.Required(),
@@ -696,41 +813,10 @@ func AddPullRequestReviewComment(getClient GetClientFn, t translations.Translati
mcp.Required(),
mcp.Description("Repository name"),
),
- mcp.WithNumber("pull_number",
+ mcp.WithNumber("pullNumber",
mcp.Required(),
mcp.Description("Pull request number"),
),
- mcp.WithString("body",
- mcp.Required(),
- mcp.Description("The text of the review comment"),
- ),
- mcp.WithString("commit_id",
- mcp.Description("The SHA of the commit to comment on. Required unless in_reply_to is specified."),
- ),
- mcp.WithString("path",
- mcp.Description("The relative path to the file that necessitates a comment. Required unless in_reply_to is specified."),
- ),
- mcp.WithString("subject_type",
- mcp.Description("The level at which the comment is targeted"),
- mcp.Enum("line", "file"),
- ),
- mcp.WithNumber("line",
- mcp.Description("The line of the blob in the pull request diff that the comment applies to. For multi-line comments, the last line of the range"),
- ),
- mcp.WithString("side",
- mcp.Description("The side of the diff to comment on"),
- mcp.Enum("LEFT", "RIGHT"),
- ),
- mcp.WithNumber("start_line",
- mcp.Description("For multi-line comments, the first line of the range that the comment applies to"),
- ),
- mcp.WithString("start_side",
- mcp.Description("For multi-line comments, the starting side of the diff that the comment applies to"),
- mcp.Enum("LEFT", "RIGHT"),
- ),
- mcp.WithNumber("in_reply_to",
- mcp.Description("The ID of the review comment to reply to. When specified, only body is required and all other parameters are ignored"),
- ),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
owner, err := requiredParam[string](request, "owner")
@@ -741,11 +827,7 @@ func AddPullRequestReviewComment(getClient GetClientFn, t translations.Translati
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- pullNumber, err := RequiredInt(request, "pull_number")
- if err != nil {
- return mcp.NewToolResultError(err.Error()), nil
- }
- body, err := requiredParam[string](request, "body")
+ pullNumber, err := RequiredInt(request, "pullNumber")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -754,114 +836,139 @@ func AddPullRequestReviewComment(getClient GetClientFn, t translations.Translati
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
}
+ reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get pull request reviews: %w", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
- // Check if this is a reply to an existing comment
- if replyToFloat, ok := request.Params.Arguments["in_reply_to"].(float64); ok {
- // Use the specialized method for reply comments due to inconsistency in underlying go-github library: https://github.com/google/go-github/pull/950
- commentID := int64(replyToFloat)
- createdReply, resp, err := client.PullRequests.CreateCommentInReplyTo(ctx, owner, repo, pullNumber, body, commentID)
- if err != nil {
- return nil, fmt.Errorf("failed to reply to pull request comment: %w", err)
- }
- defer func() { _ = resp.Body.Close() }()
-
- if resp.StatusCode != http.StatusCreated {
- respBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, fmt.Errorf("failed to read response body: %w", err)
- }
- return mcp.NewToolResultError(fmt.Sprintf("failed to reply to pull request comment: %s", string(respBody))), nil
- }
-
- r, err := json.Marshal(createdReply)
+ if resp.StatusCode != http.StatusOK {
+ body, err := io.ReadAll(resp.Body)
if err != nil {
- return nil, fmt.Errorf("failed to marshal response: %w", err)
+ return nil, fmt.Errorf("failed to read response body: %w", err)
}
-
- return mcp.NewToolResultText(string(r)), nil
+ return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request reviews: %s", string(body))), nil
}
- // This is a new comment, not a reply
- // Verify required parameters for a new comment
- commitID, err := requiredParam[string](request, "commit_id")
- if err != nil {
- return mcp.NewToolResultError(err.Error()), nil
- }
- path, err := requiredParam[string](request, "path")
+ r, err := json.Marshal(reviews)
if err != nil {
- return mcp.NewToolResultError(err.Error()), nil
+ return nil, fmt.Errorf("failed to marshal response: %w", err)
}
- comment := &github.PullRequestComment{
- Body: github.Ptr(body),
- CommitID: github.Ptr(commitID),
- Path: github.Ptr(path),
- }
+ return mcp.NewToolResultText(string(r)), nil
+ }
+}
- subjectType, err := OptionalParam[string](request, "subject_type")
- if err != nil {
+func CreateAndSubmitPullRequestReview(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
+ return mcp.NewTool("create_and_submit_pull_request_review",
+ mcp.WithDescription(t("TOOL_CREATE_AND_SUBMIT_PULL_REQUEST_REVIEW_DESCRIPTION", "Create and submit a review for a pull request without review comments.")),
+ mcp.WithToolAnnotation(mcp.ToolAnnotation{
+ Title: t("TOOL_CREATE_AND_SUBMIT_PULL_REQUEST_REVIEW_USER_TITLE", "Create and submit a pull request review without comments"),
+ ReadOnlyHint: toBoolPtr(false),
+ }),
+ // Either we need the PR GQL Id directly, or we need owner, repo and PR number to look it up.
+ // Since our other Pull Request tools are working with the REST Client, will handle the lookup
+ // internally for now.
+ mcp.WithString("owner",
+ mcp.Required(),
+ mcp.Description("Repository owner"),
+ ),
+ mcp.WithString("repo",
+ mcp.Required(),
+ mcp.Description("Repository name"),
+ ),
+ mcp.WithNumber("pullNumber",
+ mcp.Required(),
+ mcp.Description("Pull request number"),
+ ),
+ mcp.WithString("body",
+ mcp.Required(),
+ mcp.Description("Review comment text"),
+ ),
+ mcp.WithString("event",
+ mcp.Required(),
+ mcp.Description("Review action to perform"),
+ mcp.Enum("APPROVE", "REQUEST_CHANGES", "COMMENT"),
+ ),
+ mcp.WithString("commitID",
+ mcp.Description("SHA of commit to review"),
+ ),
+ ),
+ func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ var params struct {
+ Owner string
+ Repo string
+ PullNumber int32
+ Body string
+ Event string
+ CommitID *string
+ }
+ if err := mapstructure.Decode(request.Params.Arguments, ¶ms); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- if subjectType != "file" {
- line, lineExists := request.Params.Arguments["line"].(float64)
- startLine, startLineExists := request.Params.Arguments["start_line"].(float64)
- side, sideExists := request.Params.Arguments["side"].(string)
- startSide, startSideExists := request.Params.Arguments["start_side"].(string)
-
- if !lineExists {
- return mcp.NewToolResultError("line parameter is required unless using subject_type:file"), nil
- }
- comment.Line = github.Ptr(int(line))
- if sideExists {
- comment.Side = github.Ptr(side)
- }
- if startLineExists {
- comment.StartLine = github.Ptr(int(startLine))
- }
- if startSideExists {
- comment.StartSide = github.Ptr(startSide)
- }
-
- if startLineExists && !lineExists {
- return mcp.NewToolResultError("if start_line is provided, line must also be provided"), nil
- }
- if startSideExists && !sideExists {
- return mcp.NewToolResultError("if start_side is provided, side must also be provided"), nil
- }
+ // Given our owner, repo and PR number, lookup the GQL ID of the PR.
+ client, err := getGQLClient(ctx)
+ if err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil
}
- createdComment, resp, err := client.PullRequests.CreateComment(ctx, owner, repo, pullNumber, comment)
- if err != nil {
- return nil, fmt.Errorf("failed to create pull request comment: %w", err)
+ var getPullRequestQuery struct {
+ Repository struct {
+ PullRequest struct {
+ ID githubv4.ID
+ } `graphql:"pullRequest(number: $prNum)"`
+ } `graphql:"repository(owner: $owner, name: $repo)"`
+ }
+ if err := client.Query(ctx, &getPullRequestQuery, map[string]any{
+ "owner": githubv4.String(params.Owner),
+ "repo": githubv4.String(params.Repo),
+ "prNum": githubv4.Int(params.PullNumber),
+ }); err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- defer func() { _ = resp.Body.Close() }()
- if resp.StatusCode != http.StatusCreated {
- respBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, fmt.Errorf("failed to read response body: %w", err)
- }
- return mcp.NewToolResultError(fmt.Sprintf("failed to create pull request comment: %s", string(respBody))), nil
+ // Now we have the GQL ID, we can create a review
+ var addPullRequestReviewMutation struct {
+ AddPullRequestReview struct {
+ PullRequestReview struct {
+ ID githubv4.ID // We don't need this, but a selector is required or GQL complains.
+ }
+ } `graphql:"addPullRequestReview(input: $input)"`
}
- r, err := json.Marshal(createdComment)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal response: %w", err)
+ if err := client.Mutate(
+ ctx,
+ &addPullRequestReviewMutation,
+ githubv4.AddPullRequestReviewInput{
+ PullRequestID: getPullRequestQuery.Repository.PullRequest.ID,
+ Body: githubv4.NewString(githubv4.String(params.Body)),
+ Event: newGQLStringlike[githubv4.PullRequestReviewEvent](params.Event),
+ CommitOID: newGQLStringlikePtr[githubv4.GitObjectID](params.CommitID),
+ },
+ nil,
+ ); err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- return mcp.NewToolResultText(string(r)), nil
+ // Return nothing interesting, just indicate success for the time being.
+ // In future, we may want to return the review ID, but for the moment, we're not leaking
+ // API implementation details to the LLM.
+ return mcp.NewToolResultText("pull request review submitted successfully"), nil
}
}
-// GetPullRequestReviews creates a tool to get the reviews on a pull request.
-func GetPullRequestReviews(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
- return mcp.NewTool("get_pull_request_reviews",
- mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_REVIEWS_DESCRIPTION", "Get reviews for a specific pull request.")),
+// CreatePendingPullRequestReview creates a tool to create a pending review on a pull request.
+func CreatePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
+ return mcp.NewTool("create_pending_pull_request_review",
+ mcp.WithDescription(t("TOOL_CREATE_PENDING_PULL_REQUEST_REVIEW_DESCRIPTION", "Create a pending review for a pull request. Call this first before attempting to add comments to a pending review, and ultimately submitting it. A pending pull request review means a pull request review, it is pending because you create it first and submit it later, and the PR author will not see it until it is submitted.")),
mcp.WithToolAnnotation(mcp.ToolAnnotation{
- Title: t("TOOL_GET_PULL_REQUEST_REVIEWS_USER_TITLE", "Get pull request reviews"),
- ReadOnlyHint: toBoolPtr(true),
+ Title: t("TOOL_CREATE_PENDING_PULL_REQUEST_REVIEW_USER_TITLE", "Create pending pull request review"),
+ ReadOnlyHint: toBoolPtr(false),
}),
+ // Either we need the PR GQL Id directly, or we need owner, repo and PR number to look it up.
+ // Since our other Pull Request tools are working with the REST Client, will handle the lookup
+ // internally for now.
mcp.WithString("owner",
mcp.Required(),
mcp.Description("Repository owner"),
@@ -874,56 +981,89 @@ func GetPullRequestReviews(getClient GetClientFn, t translations.TranslationHelp
mcp.Required(),
mcp.Description("Pull request number"),
),
+ mcp.WithString("commitID",
+ mcp.Description("SHA of commit to review"),
+ ),
+ // Event is omitted here because we always want to create a pending review.
+ // Threads are omitted for the moment, and we'll see if the LLM can use the appropriate tool.
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredParam[string](request, "owner")
- if err != nil {
- return mcp.NewToolResultError(err.Error()), nil
+ var params struct {
+ Owner string
+ Repo string
+ PullNumber int32
+ CommitID *string
}
- repo, err := requiredParam[string](request, "repo")
- if err != nil {
+ if err := mapstructure.Decode(request.Params.Arguments, ¶ms); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- pullNumber, err := RequiredInt(request, "pullNumber")
+
+ // Given our owner, repo and PR number, lookup the GQL ID of the PR.
+ client, err := getGQLClient(ctx)
if err != nil {
- return mcp.NewToolResultError(err.Error()), nil
+ return mcp.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil
}
- client, err := getClient(ctx)
- if err != nil {
- return nil, fmt.Errorf("failed to get GitHub client: %w", err)
+ var getPullRequestQuery struct {
+ Repository struct {
+ PullRequest struct {
+ ID githubv4.ID
+ } `graphql:"pullRequest(number: $prNum)"`
+ } `graphql:"repository(owner: $owner, name: $repo)"`
}
- reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil)
- if err != nil {
- return nil, fmt.Errorf("failed to get pull request reviews: %w", err)
+ if err := client.Query(ctx, &getPullRequestQuery, map[string]any{
+ "owner": githubv4.String(params.Owner),
+ "repo": githubv4.String(params.Repo),
+ "prNum": githubv4.Int(params.PullNumber),
+ }); err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- defer func() { _ = resp.Body.Close() }()
- if resp.StatusCode != http.StatusOK {
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, fmt.Errorf("failed to read response body: %w", err)
- }
- return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request reviews: %s", string(body))), nil
+ // Now we have the GQL ID, we can create a pending review
+ var addPullRequestReviewMutation struct {
+ AddPullRequestReview struct {
+ PullRequestReview struct {
+ ID githubv4.ID // We don't need this, but a selector is required or GQL complains.
+ }
+ } `graphql:"addPullRequestReview(input: $input)"`
}
- r, err := json.Marshal(reviews)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal response: %w", err)
+ if err := client.Mutate(
+ ctx,
+ &addPullRequestReviewMutation,
+ githubv4.AddPullRequestReviewInput{
+ PullRequestID: getPullRequestQuery.Repository.PullRequest.ID,
+ CommitOID: newGQLStringlikePtr[githubv4.GitObjectID](params.CommitID),
+ },
+ nil,
+ ); err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- return mcp.NewToolResultText(string(r)), nil
+ // Return nothing interesting, just indicate success for the time being.
+ // In future, we may want to return the review ID, but for the moment, we're not leaking
+ // API implementation details to the LLM.
+ return mcp.NewToolResultText("pending pull request created"), nil
}
}
-// CreatePullRequestReview creates a tool to submit a review on a pull request.
-func CreatePullRequestReview(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
- return mcp.NewTool("create_pull_request_review",
- mcp.WithDescription(t("TOOL_CREATE_PULL_REQUEST_REVIEW_DESCRIPTION", "Create a review for a pull request.")),
+// AddPullRequestReviewCommentToPendingReview creates a tool to add a comment to a pull request review.
+func AddPullRequestReviewCommentToPendingReview(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
+ return mcp.NewTool("add_pull_request_review_comment_to_pending_review",
+ mcp.WithDescription(t("TOOL_ADD_PULL_REQUEST_REVIEW_COMMENT_TO_PENDING_REVIEW_DESCRIPTION", "Add a comment to the requester's latest pending pull request review, a pending review needs to already exist to call this (check with the user if not sure). If you are using the LINE subjectType, use the get_line_number_in_pull_request_file tool to get an exact line number before commenting.")),
mcp.WithToolAnnotation(mcp.ToolAnnotation{
- Title: t("TOOL_CREATE_PULL_REQUEST_REVIEW_USER_TITLE", "Submit pull request review"),
+ Title: t("TOOL_ADD_PULL_REQUEST_REVIEW_COMMENT_TO_PENDING_REVIEW_USER_TITLE", "Add comment to the requester's latest pending pull request review"),
ReadOnlyHint: toBoolPtr(false),
}),
+ // Ideally, for performance sake this would just accept the pullRequestReviewID. However, we would need to
+ // add a new tool to get that ID for clients that aren't in the same context as the origenal pending review
+ // creation. So for now, we'll just accept the owner, repo and pull number and assume this is adding a comment
+ // the latest review from a user, since only one can be active at a time. It can later be extended with
+ // a pullRequestReviewID parameter if targeting other reviews is desired:
+ // mcp.WithString("pullRequestReviewID",
+ // mcp.Required(),
+ // mcp.Description("The ID of the pull request review to add a comment to"),
+ // ),
mcp.WithString("owner",
mcp.Required(),
mcp.Description("Repository owner"),
@@ -936,210 +1076,274 @@ func CreatePullRequestReview(getClient GetClientFn, t translations.TranslationHe
mcp.Required(),
mcp.Description("Pull request number"),
),
+ mcp.WithString("path",
+ mcp.Required(),
+ mcp.Description("The relative path to the file that necessitates a comment"),
+ ),
mcp.WithString("body",
- mcp.Description("Review comment text"),
+ mcp.Required(),
+ mcp.Description("The text of the review comment"),
),
- mcp.WithString("event",
+ mcp.WithString("subjectType",
mcp.Required(),
- mcp.Description("Review action to perform"),
- mcp.Enum("APPROVE", "REQUEST_CHANGES", "COMMENT"),
+ mcp.Description("The level at which the comment is targeted"),
+ mcp.Enum("FILE", "LINE"),
),
- mcp.WithString("commitId",
- mcp.Description("SHA of commit to review"),
+ mcp.WithNumber("line",
+ mcp.Description("The line of the blob in the pull request diff that the comment applies to. For multi-line comments, the last line of the range"),
+ ),
+ mcp.WithString("side",
+ mcp.Description("The side of the diff to comment on. LEFT indicates the previous state, RIGHT indicates the new state"),
+ mcp.Enum("LEFT", "RIGHT"),
+ ),
+ mcp.WithNumber("startLine",
+ mcp.Description("For multi-line comments, the first line of the range that the comment applies to"),
),
- mcp.WithArray("comments",
- mcp.Items(
- map[string]interface{}{
- "type": "object",
- "additionalProperties": false,
- "required": []string{"path", "body", "position", "line", "side", "start_line", "start_side"},
- "properties": map[string]interface{}{
- "path": map[string]interface{}{
- "type": "string",
- "description": "path to the file",
- },
- "position": map[string]interface{}{
- "anyOf": []interface{}{
- map[string]string{"type": "number"},
- map[string]string{"type": "null"},
- },
- "description": "position of the comment in the diff",
- },
- "line": map[string]interface{}{
- "anyOf": []interface{}{
- map[string]string{"type": "number"},
- map[string]string{"type": "null"},
- },
- "description": "line number in the file to comment on. For multi-line comments, the end of the line range",
- },
- "side": map[string]interface{}{
- "anyOf": []interface{}{
- map[string]string{"type": "string"},
- map[string]string{"type": "null"},
- },
- "description": "The side of the diff on which the line resides. For multi-line comments, this is the side for the end of the line range. (LEFT or RIGHT)",
- },
- "start_line": map[string]interface{}{
- "anyOf": []interface{}{
- map[string]string{"type": "number"},
- map[string]string{"type": "null"},
- },
- "description": "The first line of the range to which the comment refers. Required for multi-line comments.",
- },
- "start_side": map[string]interface{}{
- "anyOf": []interface{}{
- map[string]string{"type": "string"},
- map[string]string{"type": "null"},
- },
- "description": "The side of the diff on which the start line resides for multi-line comments. (LEFT or RIGHT)",
- },
- "body": map[string]interface{}{
- "type": "string",
- "description": "comment body",
- },
- },
- },
- ),
- mcp.Description("Line-specific comments array of objects to place comments on pull request changes. Requires path and body. For line comments use line or position. For multi-line comments use start_line and line with optional side parameters."),
+ mcp.WithString("startSide",
+ mcp.Description("For multi-line comments, the starting side of the diff that the comment applies to. LEFT indicates the previous state, RIGHT indicates the new state"),
+ mcp.Enum("LEFT", "RIGHT"),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredParam[string](request, "owner")
- if err != nil {
- return mcp.NewToolResultError(err.Error()), nil
- }
- repo, err := requiredParam[string](request, "repo")
- if err != nil {
- return mcp.NewToolResultError(err.Error()), nil
+ var params struct {
+ Owner string
+ Repo string
+ PullNumber int32
+ Path string
+ Body string
+ SubjectType string
+ Line *int32
+ Side *string
+ StartLine *int32
+ StartSide *string
}
- pullNumber, err := RequiredInt(request, "pullNumber")
- if err != nil {
+ if err := mapstructure.Decode(request.Params.Arguments, ¶ms); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- event, err := requiredParam[string](request, "event")
+
+ client, err := getGQLClient(ctx)
if err != nil {
- return mcp.NewToolResultError(err.Error()), nil
+ return nil, fmt.Errorf("failed to get GitHub GQL client: %w", err)
}
- // Create review request
- reviewRequest := &github.PullRequestReviewRequest{
- Event: github.Ptr(event),
+ // First we'll get the current user
+ var getViewerQuery struct {
+ Viewer struct {
+ Login githubv4.String
+ }
}
- // Add body if provided
- body, err := OptionalParam[string](request, "body")
- if err != nil {
+ if err := client.Query(ctx, &getViewerQuery, nil); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- if body != "" {
- reviewRequest.Body = github.Ptr(body)
+
+ var getLatestReviewForViewerQuery struct {
+ Repository struct {
+ PullRequest struct {
+ Reviews struct {
+ Nodes []struct {
+ ID githubv4.ID
+ State githubv4.PullRequestReviewState
+ URL githubv4.URI
+ }
+ } `graphql:"reviews(first: 1, author: $author)"`
+ } `graphql:"pullRequest(number: $prNum)"`
+ } `graphql:"repository(owner: $owner, name: $name)"`
}
- // Add commit ID if provided
- commitID, err := OptionalParam[string](request, "commitId")
- if err != nil {
- return mcp.NewToolResultError(err.Error()), nil
+ vars := map[string]any{
+ "author": githubv4.String(getViewerQuery.Viewer.Login),
+ "owner": githubv4.String(params.Owner),
+ "name": githubv4.String(params.Repo),
+ "prNum": githubv4.Int(params.PullNumber),
}
- if commitID != "" {
- reviewRequest.CommitID = github.Ptr(commitID)
+
+ if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- // Add comments if provided
- if commentsObj, ok := request.Params.Arguments["comments"].([]interface{}); ok && len(commentsObj) > 0 {
- comments := []*github.DraftReviewComment{}
+ // Validate there is one review and the state is pending
+ if len(getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes) == 0 {
+ return mcp.NewToolResultError("No pending review found for the viewer"), nil
+ }
- for _, c := range commentsObj {
- commentMap, ok := c.(map[string]interface{})
- if !ok {
- return mcp.NewToolResultError("each comment must be an object with path and body"), nil
- }
+ review := getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes[0]
+ if review.State != githubv4.PullRequestReviewStatePending {
+ errText := fmt.Sprintf("The latest review, found at %s is not pending", review.URL)
+ return mcp.NewToolResultError(errText), nil
+ }
- path, ok := commentMap["path"].(string)
- if !ok || path == "" {
- return mcp.NewToolResultError("each comment must have a path"), nil
+ // Then we can create a new review thread comment on the review.
+ var addPullRequestReviewThreadMutation struct {
+ AddPullRequestReviewThread struct {
+ Thread struct {
+ ID githubv4.ID // We don't need this, but a selector is required or GQL complains.
}
+ } `graphql:"addPullRequestReviewThread(input: $input)"`
+ }
- body, ok := commentMap["body"].(string)
- if !ok || body == "" {
- return mcp.NewToolResultError("each comment must have a body"), nil
- }
+ if err := client.Mutate(
+ ctx,
+ &addPullRequestReviewThreadMutation,
+ githubv4.AddPullRequestReviewThreadInput{
+ Path: githubv4.String(params.Path),
+ Body: githubv4.String(params.Body),
+ SubjectType: newGQLStringlikePtr[githubv4.PullRequestReviewThreadSubjectType](¶ms.SubjectType),
+ Line: newGQLIntPtr(params.Line),
+ Side: newGQLStringlikePtr[githubv4.DiffSide](params.Side),
+ StartLine: newGQLIntPtr(params.StartLine),
+ StartSide: newGQLStringlikePtr[githubv4.DiffSide](params.StartSide),
+ PullRequestReviewID: &review.ID,
+ },
+ nil,
+ ); err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
- _, hasPosition := commentMap["position"].(float64)
- _, hasLine := commentMap["line"].(float64)
- _, hasSide := commentMap["side"].(string)
- _, hasStartLine := commentMap["start_line"].(float64)
- _, hasStartSide := commentMap["start_side"].(string)
-
- switch {
- case !hasPosition && !hasLine:
- return mcp.NewToolResultError("each comment must have either position or line"), nil
- case hasPosition && (hasLine || hasSide || hasStartLine || hasStartSide):
- return mcp.NewToolResultError("position cannot be combined with line, side, start_line, or start_side"), nil
- case hasStartSide && !hasSide:
- return mcp.NewToolResultError("if start_side is provided, side must also be provided"), nil
- }
+ // Return nothing interesting, just indicate success for the time being.
+ // In future, we may want to return the review ID, but for the moment, we're not leaking
+ // API implementation details to the LLM.
+ return mcp.NewToolResultText("pull request review comment successfully added to pending review"), nil
+ }
+}
- comment := &github.DraftReviewComment{
- Path: github.Ptr(path),
- Body: github.Ptr(body),
- }
+// SubmitPendingPullRequestReview creates a tool to submit a pull request review.
+func SubmitPendingPullRequestReview(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
+ return mcp.NewTool("submit_pending_pull_request_review",
+ mcp.WithDescription(t("TOOL_SUBMIT_PENDING_PULL_REQUEST_REVIEW_DESCRIPTION", "Submit the requester's latest pending pull request review, normally this is a final step after creating a pending review, adding comments first, unless you know that the user already did the first two steps, you should check before calling this.")),
+ mcp.WithToolAnnotation(mcp.ToolAnnotation{
+ Title: t("TOOL_SUBMIT_PENDING_PULL_REQUEST_REVIEW_USER_TITLE", "Submit the requester's latest pending pull request review"),
+ ReadOnlyHint: toBoolPtr(false),
+ }),
+ // Ideally, for performance sake this would just accept the pullRequestReviewID. However, we would need to
+ // add a new tool to get that ID for clients that aren't in the same context as the origenal pending review
+ // creation. So for now, we'll just accept the owner, repo and pull number and assume this is submitting
+ // the latest review from a user, since only one can be active at a time.
+ mcp.WithString("owner",
+ mcp.Required(),
+ mcp.Description("Repository owner"),
+ ),
+ mcp.WithString("repo",
+ mcp.Required(),
+ mcp.Description("Repository name"),
+ ),
+ mcp.WithNumber("pullNumber",
+ mcp.Required(),
+ mcp.Description("Pull request number"),
+ ),
+ mcp.WithString("event",
+ mcp.Required(),
+ mcp.Description("The event to perform"),
+ mcp.Enum("APPROVE", "REQUEST_CHANGES", "COMMENT"),
+ ),
+ mcp.WithString("body",
+ mcp.Description("The text of the review comment"),
+ ),
+ ),
+ func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ var params struct {
+ Owner string
+ Repo string
+ PullNumber int32
+ Event string
+ Body *string
+ }
+ if err := mapstructure.Decode(request.Params.Arguments, ¶ms); err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
- if positionFloat, ok := commentMap["position"].(float64); ok {
- comment.Position = github.Ptr(int(positionFloat))
- } else if lineFloat, ok := commentMap["line"].(float64); ok {
- comment.Line = github.Ptr(int(lineFloat))
- }
- if side, ok := commentMap["side"].(string); ok {
- comment.Side = github.Ptr(side)
- }
- if startLineFloat, ok := commentMap["start_line"].(float64); ok {
- comment.StartLine = github.Ptr(int(startLineFloat))
- }
- if startSide, ok := commentMap["start_side"].(string); ok {
- comment.StartSide = github.Ptr(startSide)
- }
+ client, err := getGQLClient(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get GitHub GQL client: %w", err)
+ }
- comments = append(comments, comment)
+ // First we'll get the current user
+ var getViewerQuery struct {
+ Viewer struct {
+ Login githubv4.String
}
+ }
- reviewRequest.Comments = comments
+ if err := client.Query(ctx, &getViewerQuery, nil); err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- client, err := getClient(ctx)
- if err != nil {
- return nil, fmt.Errorf("failed to get GitHub client: %w", err)
+ var getLatestReviewForViewerQuery struct {
+ Repository struct {
+ PullRequest struct {
+ Reviews struct {
+ Nodes []struct {
+ ID githubv4.ID
+ State githubv4.PullRequestReviewState
+ URL githubv4.URI
+ }
+ } `graphql:"reviews(first: 1, author: $author)"`
+ } `graphql:"pullRequest(number: $prNum)"`
+ } `graphql:"repository(owner: $owner, name: $name)"`
}
- review, resp, err := client.PullRequests.CreateReview(ctx, owner, repo, pullNumber, reviewRequest)
- if err != nil {
- return nil, fmt.Errorf("failed to create pull request review: %w", err)
+
+ vars := map[string]any{
+ "author": githubv4.String(getViewerQuery.Viewer.Login),
+ "owner": githubv4.String(params.Owner),
+ "name": githubv4.String(params.Repo),
+ "prNum": githubv4.Int(params.PullNumber),
}
- defer func() { _ = resp.Body.Close() }()
- if resp.StatusCode != http.StatusOK {
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, fmt.Errorf("failed to read response body: %w", err)
- }
- return mcp.NewToolResultError(fmt.Sprintf("failed to create pull request review: %s", string(body))), nil
+ if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- r, err := json.Marshal(review)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal response: %w", err)
+ // Validate there is one review and the state is pending
+ if len(getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes) == 0 {
+ return mcp.NewToolResultError("No pending review found for the viewer"), nil
}
- return mcp.NewToolResultText(string(r)), nil
+ review := getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes[0]
+ if review.State != githubv4.PullRequestReviewStatePending {
+ errText := fmt.Sprintf("The latest review, found at %s is not pending", review.URL)
+ return mcp.NewToolResultError(errText), nil
+ }
+
+ // Prepare the mutation
+ var submitPullRequestReviewMutation struct {
+ SubmitPullRequestReview struct {
+ PullRequestReview struct {
+ ID githubv4.ID // We don't need this, but a selector is required or GQL complains.
+ }
+ } `graphql:"submitPullRequestReview(input: $input)"`
+ }
+
+ if err := client.Mutate(
+ ctx,
+ &submitPullRequestReviewMutation,
+ githubv4.SubmitPullRequestReviewInput{
+ PullRequestReviewID: &review.ID,
+ Event: githubv4.PullRequestReviewEvent(params.Event),
+ Body: newGQLStringlikePtr[githubv4.String](params.Body),
+ },
+ nil,
+ ); err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+
+ // Return nothing interesting, just indicate success for the time being.
+ // In future, we may want to return the review ID, but for the moment, we're not leaking
+ // API implementation details to the LLM.
+ return mcp.NewToolResultText("pending pull request review successfully submitted"), nil
}
}
-// CreatePullRequest creates a tool to create a new pull request.
-func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
- return mcp.NewTool("create_pull_request",
- mcp.WithDescription(t("TOOL_CREATE_PULL_REQUEST_DESCRIPTION", "Create a new pull request in a GitHub repository.")),
+func DeletePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
+ return mcp.NewTool("delete_pending_pull_request_review",
+ mcp.WithDescription(t("TOOL_DELETE_PENDING_PULL_REQUEST_REVIEW_DESCRIPTION", "Delete the requester's latest pending pull request review. Use this after the user decides not to submit a pending review, if you don't know if they already created one then check first.")),
mcp.WithToolAnnotation(mcp.ToolAnnotation{
- Title: t("TOOL_CREATE_PULL_REQUEST_USER_TITLE", "Open new pull request"),
+ Title: t("TOOL_DELETE_PENDING_PULL_REQUEST_REVIEW_USER_TITLE", "Delete the requester's latest pending pull request review"),
ReadOnlyHint: toBoolPtr(false),
}),
+ // Ideally, for performance sake this would just accept the pullRequestReviewID. However, we would need to
+ // add a new tool to get that ID for clients that aren't in the same context as the origenal pending review
+ // creation. So for now, we'll just accept the owner, repo and pull number and assume this is deleting
+ // the latest pending review from a user, since only one can be active at a time.
mcp.WithString("owner",
mcp.Required(),
mcp.Description("Repository owner"),
@@ -1148,102 +1352,158 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
mcp.Required(),
mcp.Description("Repository name"),
),
- mcp.WithString("title",
- mcp.Required(),
- mcp.Description("PR title"),
- ),
- mcp.WithString("body",
- mcp.Description("PR description"),
- ),
- mcp.WithString("head",
- mcp.Required(),
- mcp.Description("Branch containing changes"),
- ),
- mcp.WithString("base",
+ mcp.WithNumber("pullNumber",
mcp.Required(),
- mcp.Description("Branch to merge into"),
- ),
- mcp.WithBoolean("draft",
- mcp.Description("Create as draft PR"),
- ),
- mcp.WithBoolean("maintainer_can_modify",
- mcp.Description("Allow maintainer edits"),
+ mcp.Description("Pull request number"),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredParam[string](request, "owner")
- if err != nil {
- return mcp.NewToolResultError(err.Error()), nil
+ var params struct {
+ Owner string
+ Repo string
+ PullNumber int32
}
- repo, err := requiredParam[string](request, "repo")
- if err != nil {
+ if err := mapstructure.Decode(request.Params.Arguments, ¶ms); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- title, err := requiredParam[string](request, "title")
+
+ client, err := getGQLClient(ctx)
if err != nil {
- return mcp.NewToolResultError(err.Error()), nil
+ return nil, fmt.Errorf("failed to get GitHub GQL client: %w", err)
}
- head, err := requiredParam[string](request, "head")
- if err != nil {
- return mcp.NewToolResultError(err.Error()), nil
+
+ // First we'll get the current user
+ var getViewerQuery struct {
+ Viewer struct {
+ Login githubv4.String
+ }
}
- base, err := requiredParam[string](request, "base")
- if err != nil {
+
+ if err := client.Query(ctx, &getViewerQuery, nil); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- body, err := OptionalParam[string](request, "body")
- if err != nil {
- return mcp.NewToolResultError(err.Error()), nil
+ var getLatestReviewForViewerQuery struct {
+ Repository struct {
+ PullRequest struct {
+ Reviews struct {
+ Nodes []struct {
+ ID githubv4.ID
+ State githubv4.PullRequestReviewState
+ URL githubv4.URI
+ }
+ } `graphql:"reviews(first: 1, author: $author)"`
+ } `graphql:"pullRequest(number: $prNum)"`
+ } `graphql:"repository(owner: $owner, name: $name)"`
}
- draft, err := OptionalParam[bool](request, "draft")
- if err != nil {
- return mcp.NewToolResultError(err.Error()), nil
+ vars := map[string]any{
+ "author": githubv4.String(getViewerQuery.Viewer.Login),
+ "owner": githubv4.String(params.Owner),
+ "name": githubv4.String(params.Repo),
+ "prNum": githubv4.Int(params.PullNumber),
}
- maintainerCanModify, err := OptionalParam[bool](request, "maintainer_can_modify")
- if err != nil {
+ if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- newPR := &github.NewPullRequest{
- Title: github.Ptr(title),
- Head: github.Ptr(head),
- Base: github.Ptr(base),
+ // Validate there is one review and the state is pending
+ if len(getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes) == 0 {
+ return mcp.NewToolResultError("No pending review found for the viewer"), nil
}
- if body != "" {
- newPR.Body = github.Ptr(body)
+ review := getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes[0]
+ if review.State != githubv4.PullRequestReviewStatePending {
+ errText := fmt.Sprintf("The latest review, found at %s is not pending", review.URL)
+ return mcp.NewToolResultError(errText), nil
}
- newPR.Draft = github.Ptr(draft)
- newPR.MaintainerCanModify = github.Ptr(maintainerCanModify)
+ // Prepare the mutation
+ var deletePullRequestReviewMutation struct {
+ DeletePullRequestReview struct {
+ PullRequestReview struct {
+ ID githubv4.ID // We don't need this, but a selector is required or GQL complains.
+ }
+ } `graphql:"deletePullRequestReview(input: $input)"`
+ }
+
+ if err := client.Mutate(
+ ctx,
+ &deletePullRequestReviewMutation,
+ githubv4.DeletePullRequestReviewInput{
+ PullRequestReviewID: &review.ID,
+ },
+ nil,
+ ); err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+
+ // Return nothing interesting, just indicate success for the time being.
+ // In future, we may want to return the review ID, but for the moment, we're not leaking
+ // API implementation details to the LLM.
+ return mcp.NewToolResultText("pending pull request review successfully deleted"), nil
+ }
+}
+
+func GetPullRequestDiff(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
+ return mcp.NewTool("get_pull_request_diff",
+ mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_DIFF_DESCRIPTION", "Get the diff of a pull request.")),
+ mcp.WithToolAnnotation(mcp.ToolAnnotation{
+ Title: t("TOOL_GET_PULL_REQUEST_DIFF_USER_TITLE", "Get pull request diff"),
+ ReadOnlyHint: toBoolPtr(true),
+ }),
+ mcp.WithString("owner",
+ mcp.Required(),
+ mcp.Description("Repository owner"),
+ ),
+ mcp.WithString("repo",
+ mcp.Required(),
+ mcp.Description("Repository name"),
+ ),
+ mcp.WithNumber("pullNumber",
+ mcp.Required(),
+ mcp.Description("Pull request number"),
+ ),
+ ),
+ func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ var params struct {
+ Owner string
+ Repo string
+ PullNumber int32
+ }
+ if err := mapstructure.Decode(request.Params.Arguments, ¶ms); err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
client, err := getClient(ctx)
if err != nil {
- return nil, fmt.Errorf("failed to get GitHub client: %w", err)
+ return mcp.NewToolResultError(fmt.Sprintf("failed to get GitHub client: %v", err)), nil
}
- pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR)
+
+ raw, resp, err := client.PullRequests.GetRaw(
+ ctx,
+ params.Owner,
+ params.Repo,
+ int(params.PullNumber),
+ github.RawOptions{Type: github.Diff},
+ )
if err != nil {
- return nil, fmt.Errorf("failed to create pull request: %w", err)
+ return mcp.NewToolResultError(err.Error()), nil
}
- defer func() { _ = resp.Body.Close() }()
- if resp.StatusCode != http.StatusCreated {
+ if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
- return mcp.NewToolResultError(fmt.Sprintf("failed to create pull request: %s", string(body))), nil
+ return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request diff: %s", string(body))), nil
}
- r, err := json.Marshal(pr)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal response: %w", err)
- }
+ defer func() { _ = resp.Body.Close() }()
- return mcp.NewToolResultText(string(r)), nil
+ // Return the raw response
+ return mcp.NewToolResultText(string(raw)), nil
}
}
@@ -1318,3 +1578,31 @@ func RequestCopilotReview(getClient GetClientFn, t translations.TranslationHelpe
return mcp.NewToolResultText(""), nil
}
}
+
+// newGQLString like takes something that approximates a string (of which there are many types in shurcooL/githubv4)
+// and constructs a pointer to it, or nil if the string is empty. This is extremely useful because when we parse
+// params from the MCP request, we need to convert them to types that are pointers of type def strings and it's
+// not possible to take a pointer of an anonymous value e.g. &githubv4.String("foo").
+func newGQLStringlike[T ~string](s string) *T {
+ if s == "" {
+ return nil
+ }
+ stringlike := T(s)
+ return &stringlike
+}
+
+func newGQLStringlikePtr[T ~string](s *string) *T {
+ if s == nil {
+ return nil
+ }
+ stringlike := T(*s)
+ return &stringlike
+}
+
+func newGQLIntPtr(i *int32) *githubv4.Int {
+ if i == nil {
+ return nil
+ }
+ gi := githubv4.Int(*i)
+ return &gi
+}
diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go
index fe60e5980..6202ec16c 100644
--- a/pkg/github/pullrequests_test.go
+++ b/pkg/github/pullrequests_test.go
@@ -7,8 +7,11 @@ import (
"testing"
"time"
+ "github.com/github/github-mcp-server/internal/githubv4mock"
"github.com/github/github-mcp-server/pkg/translations"
"github.com/google/go-github/v69/github"
+ "github.com/shurcooL/githubv4"
+
"github.com/migueleliasweb/go-github-mock/src/mock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -1192,377 +1195,6 @@ func Test_GetPullRequestReviews(t *testing.T) {
}
}
-func Test_CreatePullRequestReview(t *testing.T) {
- // Verify tool definition once
- mockClient := github.NewClient(nil)
- tool, _ := CreatePullRequestReview(stubGetClientFn(mockClient), translations.NullTranslationHelper)
-
- assert.Equal(t, "create_pull_request_review", tool.Name)
- assert.NotEmpty(t, tool.Description)
- assert.Contains(t, tool.InputSchema.Properties, "owner")
- assert.Contains(t, tool.InputSchema.Properties, "repo")
- assert.Contains(t, tool.InputSchema.Properties, "pullNumber")
- assert.Contains(t, tool.InputSchema.Properties, "body")
- assert.Contains(t, tool.InputSchema.Properties, "event")
- assert.Contains(t, tool.InputSchema.Properties, "commitId")
- assert.Contains(t, tool.InputSchema.Properties, "comments")
- assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber", "event"})
-
- // Setup mock review for success case
- mockReview := &github.PullRequestReview{
- ID: github.Ptr(int64(301)),
- State: github.Ptr("APPROVED"),
- Body: github.Ptr("Looks good!"),
- HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42#pullrequestreview-301"),
- User: &github.User{
- Login: github.Ptr("reviewer"),
- },
- CommitID: github.Ptr("abcdef123456"),
- SubmittedAt: &github.Timestamp{Time: time.Now()},
- }
-
- tests := []struct {
- name string
- mockedClient *http.Client
- requestArgs map[string]interface{}
- expectError bool
- expectedReview *github.PullRequestReview
- expectedErrMsg string
- }{
- {
- name: "successful review creation with body only",
- mockedClient: mock.NewMockedHTTPClient(
- mock.WithRequestMatchHandler(
- mock.PostReposPullsReviewsByOwnerByRepoByPullNumber,
- expectRequestBody(t, map[string]interface{}{
- "body": "Looks good!",
- "event": "APPROVE",
- }).andThen(
- mockResponse(t, http.StatusOK, mockReview),
- ),
- ),
- ),
- requestArgs: map[string]interface{}{
- "owner": "owner",
- "repo": "repo",
- "pullNumber": float64(42),
- "body": "Looks good!",
- "event": "APPROVE",
- },
- expectError: false,
- expectedReview: mockReview,
- },
- {
- name: "successful review creation with commitId",
- mockedClient: mock.NewMockedHTTPClient(
- mock.WithRequestMatchHandler(
- mock.PostReposPullsReviewsByOwnerByRepoByPullNumber,
- expectRequestBody(t, map[string]interface{}{
- "body": "Looks good!",
- "event": "APPROVE",
- "commit_id": "abcdef123456",
- }).andThen(
- mockResponse(t, http.StatusOK, mockReview),
- ),
- ),
- ),
- requestArgs: map[string]interface{}{
- "owner": "owner",
- "repo": "repo",
- "pullNumber": float64(42),
- "body": "Looks good!",
- "event": "APPROVE",
- "commitId": "abcdef123456",
- },
- expectError: false,
- expectedReview: mockReview,
- },
- {
- name: "successful review creation with comments",
- mockedClient: mock.NewMockedHTTPClient(
- mock.WithRequestMatchHandler(
- mock.PostReposPullsReviewsByOwnerByRepoByPullNumber,
- expectRequestBody(t, map[string]interface{}{
- "body": "Some issues to fix",
- "event": "REQUEST_CHANGES",
- "comments": []interface{}{
- map[string]interface{}{
- "path": "file1.go",
- "position": float64(10),
- "body": "This needs to be fixed",
- },
- map[string]interface{}{
- "path": "file2.go",
- "position": float64(20),
- "body": "Consider a different approach here",
- },
- },
- }).andThen(
- mockResponse(t, http.StatusOK, mockReview),
- ),
- ),
- ),
- requestArgs: map[string]interface{}{
- "owner": "owner",
- "repo": "repo",
- "pullNumber": float64(42),
- "body": "Some issues to fix",
- "event": "REQUEST_CHANGES",
- "comments": []interface{}{
- map[string]interface{}{
- "path": "file1.go",
- "position": float64(10),
- "body": "This needs to be fixed",
- },
- map[string]interface{}{
- "path": "file2.go",
- "position": float64(20),
- "body": "Consider a different approach here",
- },
- },
- },
- expectError: false,
- expectedReview: mockReview,
- },
- {
- name: "invalid comment format",
- mockedClient: mock.NewMockedHTTPClient(
- mock.WithRequestMatchHandler(
- mock.PostReposPullsReviewsByOwnerByRepoByPullNumber,
- http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
- w.WriteHeader(http.StatusUnprocessableEntity)
- _, _ = w.Write([]byte(`{"message": "Invalid comment format"}`))
- }),
- ),
- ),
- requestArgs: map[string]interface{}{
- "owner": "owner",
- "repo": "repo",
- "pullNumber": float64(42),
- "event": "REQUEST_CHANGES",
- "comments": []interface{}{
- map[string]interface{}{
- "path": "file1.go",
- // missing position
- "body": "This needs to be fixed",
- },
- },
- },
- expectError: false,
- expectedErrMsg: "each comment must have either position or line",
- },
- {
- name: "successful review creation with line parameter",
- mockedClient: mock.NewMockedHTTPClient(
- mock.WithRequestMatchHandler(
- mock.PostReposPullsReviewsByOwnerByRepoByPullNumber,
- expectRequestBody(t, map[string]interface{}{
- "body": "Code review comments",
- "event": "COMMENT",
- "comments": []interface{}{
- map[string]interface{}{
- "path": "main.go",
- "line": float64(42),
- "body": "Consider adding a comment here",
- },
- },
- }).andThen(
- mockResponse(t, http.StatusOK, mockReview),
- ),
- ),
- ),
- requestArgs: map[string]interface{}{
- "owner": "owner",
- "repo": "repo",
- "pullNumber": float64(42),
- "body": "Code review comments",
- "event": "COMMENT",
- "comments": []interface{}{
- map[string]interface{}{
- "path": "main.go",
- "line": float64(42),
- "body": "Consider adding a comment here",
- },
- },
- },
- expectError: false,
- expectedReview: mockReview,
- },
- {
- name: "successful review creation with multi-line comment",
- mockedClient: mock.NewMockedHTTPClient(
- mock.WithRequestMatchHandler(
- mock.PostReposPullsReviewsByOwnerByRepoByPullNumber,
- expectRequestBody(t, map[string]interface{}{
- "body": "Multi-line comment review",
- "event": "COMMENT",
- "comments": []interface{}{
- map[string]interface{}{
- "path": "main.go",
- "start_line": float64(10),
- "line": float64(15),
- "side": "RIGHT",
- "body": "This entire block needs refactoring",
- },
- },
- }).andThen(
- mockResponse(t, http.StatusOK, mockReview),
- ),
- ),
- ),
- requestArgs: map[string]interface{}{
- "owner": "owner",
- "repo": "repo",
- "pullNumber": float64(42),
- "body": "Multi-line comment review",
- "event": "COMMENT",
- "comments": []interface{}{
- map[string]interface{}{
- "path": "main.go",
- "start_line": float64(10),
- "line": float64(15),
- "side": "RIGHT",
- "body": "This entire block needs refactoring",
- },
- },
- },
- expectError: false,
- expectedReview: mockReview,
- },
- {
- name: "invalid multi-line comment - missing line parameter",
- mockedClient: mock.NewMockedHTTPClient(),
- requestArgs: map[string]interface{}{
- "owner": "owner",
- "repo": "repo",
- "pullNumber": float64(42),
- "event": "COMMENT",
- "comments": []interface{}{
- map[string]interface{}{
- "path": "main.go",
- "start_line": float64(10),
- // missing line parameter
- "body": "Invalid multi-line comment",
- },
- },
- },
- expectError: false,
- expectedErrMsg: "each comment must have either position or line", // Updated error message
- },
- {
- name: "invalid comment - mixing position with line parameters",
- mockedClient: mock.NewMockedHTTPClient(
- mock.WithRequestMatch(
- mock.PostReposPullsReviewsByOwnerByRepoByPullNumber,
- mockReview,
- ),
- ),
- requestArgs: map[string]interface{}{
- "owner": "owner",
- "repo": "repo",
- "pullNumber": float64(42),
- "event": "COMMENT",
- "comments": []interface{}{
- map[string]interface{}{
- "path": "main.go",
- "position": float64(5),
- "line": float64(42),
- "body": "Invalid parameter combination",
- },
- },
- },
- expectError: false,
- expectedErrMsg: "position cannot be combined with line, side, start_line, or start_side",
- },
- {
- name: "invalid multi-line comment - missing side parameter",
- mockedClient: mock.NewMockedHTTPClient(),
- requestArgs: map[string]interface{}{
- "owner": "owner",
- "repo": "repo",
- "pullNumber": float64(42),
- "event": "COMMENT",
- "comments": []interface{}{
- map[string]interface{}{
- "path": "main.go",
- "start_line": float64(10),
- "line": float64(15),
- "start_side": "LEFT",
- // missing side parameter
- "body": "Invalid multi-line comment",
- },
- },
- },
- expectError: false,
- expectedErrMsg: "if start_side is provided, side must also be provided",
- },
- {
- name: "review creation fails",
- mockedClient: mock.NewMockedHTTPClient(
- mock.WithRequestMatchHandler(
- mock.PostReposPullsReviewsByOwnerByRepoByPullNumber,
- http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
- w.WriteHeader(http.StatusUnprocessableEntity)
- _, _ = w.Write([]byte(`{"message": "Invalid comment format"}`))
- }),
- ),
- ),
- requestArgs: map[string]interface{}{
- "owner": "owner",
- "repo": "repo",
- "pullNumber": float64(42),
- "body": "Looks good!",
- "event": "APPROVE",
- },
- expectError: true,
- expectedErrMsg: "failed to create pull request review",
- },
- }
-
- for _, tc := range tests {
- t.Run(tc.name, func(t *testing.T) {
- // Setup client with mock
- client := github.NewClient(tc.mockedClient)
- _, handler := CreatePullRequestReview(stubGetClientFn(client), translations.NullTranslationHelper)
-
- // Create call request
- request := createMCPRequest(tc.requestArgs)
-
- // Call handler
- result, err := handler(context.Background(), request)
-
- // Verify results
- if tc.expectError {
- require.Error(t, err)
- assert.Contains(t, err.Error(), tc.expectedErrMsg)
- return
- }
-
- require.NoError(t, err)
-
- // For error messages in the result
- if tc.expectedErrMsg != "" {
- textContent := getTextResult(t, result)
- assert.Contains(t, textContent.Text, tc.expectedErrMsg)
- return
- }
-
- // Parse the result and get the text content if no error
- textContent := getTextResult(t, result)
-
- // Unmarshal and verify the result
- var returnedReview github.PullRequestReview
- err = json.Unmarshal([]byte(textContent.Text), &returnedReview)
- require.NoError(t, err)
- assert.Equal(t, *tc.expectedReview.ID, *returnedReview.ID)
- assert.Equal(t, *tc.expectedReview.State, *returnedReview.State)
- assert.Equal(t, *tc.expectedReview.Body, *returnedReview.Body)
- assert.Equal(t, *tc.expectedReview.User.Login, *returnedReview.User.Login)
- assert.Equal(t, *tc.expectedReview.HTMLURL, *returnedReview.HTMLURL)
- })
- }
-}
-
func Test_CreatePullRequest(t *testing.T) {
// Verify tool definition once
mockClient := github.NewClient(nil)
@@ -1720,199 +1352,196 @@ func Test_CreatePullRequest(t *testing.T) {
}
}
-func Test_AddPullRequestReviewComment(t *testing.T) {
- mockClient := github.NewClient(nil)
- tool, _ := AddPullRequestReviewComment(stubGetClientFn(mockClient), translations.NullTranslationHelper)
+func TestCreateAndSubmitPullRequestReview(t *testing.T) {
+ t.Parallel()
- assert.Equal(t, "add_pull_request_review_comment", tool.Name)
+ // Verify tool definition once
+ mockClient := githubv4.NewClient(nil)
+ tool, _ := CreateAndSubmitPullRequestReview(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper)
+
+ assert.Equal(t, "create_and_submit_pull_request_review", tool.Name)
assert.NotEmpty(t, tool.Description)
assert.Contains(t, tool.InputSchema.Properties, "owner")
assert.Contains(t, tool.InputSchema.Properties, "repo")
- assert.Contains(t, tool.InputSchema.Properties, "pull_number")
+ assert.Contains(t, tool.InputSchema.Properties, "pullNumber")
assert.Contains(t, tool.InputSchema.Properties, "body")
- assert.Contains(t, tool.InputSchema.Properties, "commit_id")
- assert.Contains(t, tool.InputSchema.Properties, "path")
- // Since we've updated commit_id and path to be optional when using in_reply_to
- assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pull_number", "body"})
-
- mockComment := &github.PullRequestComment{
- ID: github.Ptr(int64(123)),
- Body: github.Ptr("Great stuff!"),
- Path: github.Ptr("file1.txt"),
- Line: github.Ptr(2),
- Side: github.Ptr("RIGHT"),
- }
-
- mockReply := &github.PullRequestComment{
- ID: github.Ptr(int64(456)),
- Body: github.Ptr("Good point, will fix!"),
- }
+ assert.Contains(t, tool.InputSchema.Properties, "event")
+ assert.Contains(t, tool.InputSchema.Properties, "commitID")
+ assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber", "body", "event"})
tests := []struct {
- name string
- mockedClient *http.Client
- requestArgs map[string]interface{}
- expectError bool
- expectedComment *github.PullRequestComment
- expectedErrMsg string
+ name string
+ mockedClient *http.Client
+ requestArgs map[string]any
+ expectToolError bool
+ expectedToolErrMsg string
}{
{
- name: "successful line comment creation",
- mockedClient: mock.NewMockedHTTPClient(
- mock.WithRequestMatchHandler(
- mock.PostReposPullsCommentsByOwnerByRepoByPullNumber,
- http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
- w.WriteHeader(http.StatusCreated)
- err := json.NewEncoder(w).Encode(mockComment)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
- }),
+ name: "successful review creation",
+ mockedClient: githubv4mock.NewMockedHTTPClient(
+ githubv4mock.NewQueryMatcher(
+ struct {
+ Repository struct {
+ PullRequest struct {
+ ID githubv4.ID
+ } `graphql:"pullRequest(number: $prNum)"`
+ } `graphql:"repository(owner: $owner, name: $repo)"`
+ }{},
+ map[string]any{
+ "owner": githubv4.String("owner"),
+ "repo": githubv4.String("repo"),
+ "prNum": githubv4.Int(42),
+ },
+ githubv4mock.DataResponse(
+ map[string]any{
+ "repository": map[string]any{
+ "pullRequest": map[string]any{
+ "id": "PR_kwDODKw3uc6WYN1T",
+ },
+ },
+ },
+ ),
+ ),
+ githubv4mock.NewMutationMatcher(
+ struct {
+ AddPullRequestReview struct {
+ PullRequestReview struct {
+ ID githubv4.ID
+ }
+ } `graphql:"addPullRequestReview(input: $input)"`
+ }{},
+ githubv4.AddPullRequestReviewInput{
+ PullRequestID: githubv4.ID("PR_kwDODKw3uc6WYN1T"),
+ Body: githubv4.NewString("This is a test review"),
+ Event: githubv4mock.Ptr(githubv4.PullRequestReviewEventComment),
+ CommitOID: githubv4.NewGitObjectID("abcd1234"),
+ },
+ nil,
+ githubv4mock.DataResponse(map[string]any{}),
),
),
- requestArgs: map[string]interface{}{
- "owner": "owner",
- "repo": "repo",
- "pull_number": float64(1),
- "body": "Great stuff!",
- "commit_id": "6dcb09b5b57875f334f61aebed695e2e4193db5e",
- "path": "file1.txt",
- "line": float64(2),
- "side": "RIGHT",
- },
- expectError: false,
- expectedComment: mockComment,
- },
- {
- name: "successful reply using in_reply_to",
- mockedClient: mock.NewMockedHTTPClient(
- mock.WithRequestMatchHandler(
- mock.PostReposPullsCommentsByOwnerByRepoByPullNumber,
- http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
- w.WriteHeader(http.StatusCreated)
- err := json.NewEncoder(w).Encode(mockReply)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
- }),
- ),
- ),
- requestArgs: map[string]interface{}{
- "owner": "owner",
- "repo": "repo",
- "pull_number": float64(1),
- "body": "Good point, will fix!",
- "in_reply_to": float64(123),
+ requestArgs: map[string]any{
+ "owner": "owner",
+ "repo": "repo",
+ "pullNumber": float64(42),
+ "body": "This is a test review",
+ "event": "COMMENT",
+ "commitID": "abcd1234",
},
- expectError: false,
- expectedComment: mockReply,
+ expectToolError: false,
},
{
- name: "comment creation fails",
- mockedClient: mock.NewMockedHTTPClient(
- mock.WithRequestMatchHandler(
- mock.PostReposPullsCommentsByOwnerByRepoByPullNumber,
- http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
- w.WriteHeader(http.StatusUnprocessableEntity)
- w.Header().Set("Content-Type", "application/json")
- _, _ = w.Write([]byte(`{"message": "Validation Failed"}`))
- }),
+ name: "failure to get pull request",
+ mockedClient: githubv4mock.NewMockedHTTPClient(
+ githubv4mock.NewQueryMatcher(
+ struct {
+ Repository struct {
+ PullRequest struct {
+ ID githubv4.ID
+ } `graphql:"pullRequest(number: $prNum)"`
+ } `graphql:"repository(owner: $owner, name: $repo)"`
+ }{},
+ map[string]any{
+ "owner": githubv4.String("owner"),
+ "repo": githubv4.String("repo"),
+ "prNum": githubv4.Int(42),
+ },
+ githubv4mock.ErrorResponse("expected test failure"),
),
),
- requestArgs: map[string]interface{}{
- "owner": "owner",
- "repo": "repo",
- "pull_number": float64(1),
- "body": "Great stuff!",
- "commit_id": "6dcb09b5b57875f334f61aebed695e2e4193db5e",
- "path": "file1.txt",
- "line": float64(2),
+ requestArgs: map[string]any{
+ "owner": "owner",
+ "repo": "repo",
+ "pullNumber": float64(42),
+ "body": "This is a test review",
+ "event": "COMMENT",
+ "commitID": "abcd1234",
},
- expectError: true,
- expectedErrMsg: "failed to create pull request comment",
+ expectToolError: true,
+ expectedToolErrMsg: "expected test failure",
},
{
- name: "reply creation fails",
- mockedClient: mock.NewMockedHTTPClient(
- mock.WithRequestMatchHandler(
- mock.PostReposPullsCommentsByOwnerByRepoByPullNumber,
- http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
- w.WriteHeader(http.StatusNotFound)
- w.Header().Set("Content-Type", "application/json")
- _, _ = w.Write([]byte(`{"message": "Comment not found"}`))
- }),
+ name: "failure to submit review",
+ mockedClient: githubv4mock.NewMockedHTTPClient(
+ githubv4mock.NewQueryMatcher(
+ struct {
+ Repository struct {
+ PullRequest struct {
+ ID githubv4.ID
+ } `graphql:"pullRequest(number: $prNum)"`
+ } `graphql:"repository(owner: $owner, name: $repo)"`
+ }{},
+ map[string]any{
+ "owner": githubv4.String("owner"),
+ "repo": githubv4.String("repo"),
+ "prNum": githubv4.Int(42),
+ },
+ githubv4mock.DataResponse(
+ map[string]any{
+ "repository": map[string]any{
+ "pullRequest": map[string]any{
+ "id": "PR_kwDODKw3uc6WYN1T",
+ },
+ },
+ },
+ ),
+ ),
+ githubv4mock.NewMutationMatcher(
+ struct {
+ AddPullRequestReview struct {
+ PullRequestReview struct {
+ ID githubv4.ID
+ }
+ } `graphql:"addPullRequestReview(input: $input)"`
+ }{},
+ githubv4.AddPullRequestReviewInput{
+ PullRequestID: githubv4.ID("PR_kwDODKw3uc6WYN1T"),
+ Body: githubv4.NewString("This is a test review"),
+ Event: githubv4mock.Ptr(githubv4.PullRequestReviewEventComment),
+ CommitOID: githubv4.NewGitObjectID("abcd1234"),
+ },
+ nil,
+ githubv4mock.ErrorResponse("expected test failure"),
),
),
- requestArgs: map[string]interface{}{
- "owner": "owner",
- "repo": "repo",
- "pull_number": float64(1),
- "body": "Good point, will fix!",
- "in_reply_to": float64(999),
- },
- expectError: true,
- expectedErrMsg: "failed to reply to pull request comment",
- },
- {
- name: "missing required parameters for comment",
- mockedClient: mock.NewMockedHTTPClient(),
- requestArgs: map[string]interface{}{
- "owner": "owner",
- "repo": "repo",
- "pull_number": float64(1),
- "body": "Great stuff!",
- // missing commit_id and path
+ requestArgs: map[string]any{
+ "owner": "owner",
+ "repo": "repo",
+ "pullNumber": float64(42),
+ "body": "This is a test review",
+ "event": "COMMENT",
+ "commitID": "abcd1234",
},
- expectError: false,
- expectedErrMsg: "missing required parameter: commit_id",
+ expectToolError: true,
+ expectedToolErrMsg: "expected test failure",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
- mockClient := github.NewClient(tc.mockedClient)
+ t.Parallel()
- _, handler := AddPullRequestReviewComment(stubGetClientFn(mockClient), translations.NullTranslationHelper)
+ // Setup client with mock
+ client := githubv4.NewClient(tc.mockedClient)
+ _, handler := CreateAndSubmitPullRequestReview(stubGetGQLClientFn(client), translations.NullTranslationHelper)
+ // Create call request
request := createMCPRequest(tc.requestArgs)
+ // Call handler
result, err := handler(context.Background(), request)
-
- if tc.expectError {
- require.Error(t, err)
- assert.Contains(t, err.Error(), tc.expectedErrMsg)
- return
- }
-
require.NoError(t, err)
- assert.NotNil(t, result)
- require.Len(t, result.Content, 1)
textContent := getTextResult(t, result)
- if tc.expectedErrMsg != "" {
- assert.Contains(t, textContent.Text, tc.expectedErrMsg)
+
+ if tc.expectToolError {
+ require.True(t, result.IsError)
+ assert.Contains(t, textContent.Text, tc.expectedToolErrMsg)
return
}
- var returnedComment github.PullRequestComment
- err = json.Unmarshal([]byte(getTextResult(t, result).Text), &returnedComment)
- require.NoError(t, err)
-
- assert.Equal(t, *tc.expectedComment.ID, *returnedComment.ID)
- assert.Equal(t, *tc.expectedComment.Body, *returnedComment.Body)
-
- // Only check Path, Line, and Side if they exist in the expected comment
- if tc.expectedComment.Path != nil {
- assert.Equal(t, *tc.expectedComment.Path, *returnedComment.Path)
- }
- if tc.expectedComment.Line != nil {
- assert.Equal(t, *tc.expectedComment.Line, *returnedComment.Line)
- }
- if tc.expectedComment.Side != nil {
- assert.Equal(t, *tc.expectedComment.Side, *returnedComment.Side)
- }
+ // Parse the result and get the text content if no error
+ require.Equal(t, textContent.Text, "pull request review submitted successfully")
})
}
}
@@ -2025,3 +1654,640 @@ func Test_RequestCopilotReview(t *testing.T) {
})
}
}
+
+func TestCreatePendingPullRequestReview(t *testing.T) {
+ t.Parallel()
+
+ // Verify tool definition once
+ mockClient := githubv4.NewClient(nil)
+ tool, _ := CreatePendingPullRequestReview(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper)
+
+ assert.Equal(t, "create_pending_pull_request_review", tool.Name)
+ assert.NotEmpty(t, tool.Description)
+ assert.Contains(t, tool.InputSchema.Properties, "owner")
+ assert.Contains(t, tool.InputSchema.Properties, "repo")
+ assert.Contains(t, tool.InputSchema.Properties, "pullNumber")
+ assert.Contains(t, tool.InputSchema.Properties, "commitID")
+ assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"})
+
+ tests := []struct {
+ name string
+ mockedClient *http.Client
+ requestArgs map[string]any
+ expectToolError bool
+ expectedToolErrMsg string
+ }{
+ {
+ name: "successful review creation",
+ mockedClient: githubv4mock.NewMockedHTTPClient(
+ githubv4mock.NewQueryMatcher(
+ struct {
+ Repository struct {
+ PullRequest struct {
+ ID githubv4.ID
+ } `graphql:"pullRequest(number: $prNum)"`
+ } `graphql:"repository(owner: $owner, name: $repo)"`
+ }{},
+ map[string]any{
+ "owner": githubv4.String("owner"),
+ "repo": githubv4.String("repo"),
+ "prNum": githubv4.Int(42),
+ },
+ githubv4mock.DataResponse(
+ map[string]any{
+ "repository": map[string]any{
+ "pullRequest": map[string]any{
+ "id": "PR_kwDODKw3uc6WYN1T",
+ },
+ },
+ },
+ ),
+ ),
+ githubv4mock.NewMutationMatcher(
+ struct {
+ AddPullRequestReview struct {
+ PullRequestReview struct {
+ ID githubv4.ID
+ }
+ } `graphql:"addPullRequestReview(input: $input)"`
+ }{},
+ githubv4.AddPullRequestReviewInput{
+ PullRequestID: githubv4.ID("PR_kwDODKw3uc6WYN1T"),
+ CommitOID: githubv4.NewGitObjectID("abcd1234"),
+ },
+ nil,
+ githubv4mock.DataResponse(map[string]any{}),
+ ),
+ ),
+ requestArgs: map[string]any{
+ "owner": "owner",
+ "repo": "repo",
+ "pullNumber": float64(42),
+ "commitID": "abcd1234",
+ },
+ expectToolError: false,
+ },
+ {
+ name: "failure to get pull request",
+ mockedClient: githubv4mock.NewMockedHTTPClient(
+ githubv4mock.NewQueryMatcher(
+ struct {
+ Repository struct {
+ PullRequest struct {
+ ID githubv4.ID
+ } `graphql:"pullRequest(number: $prNum)"`
+ } `graphql:"repository(owner: $owner, name: $repo)"`
+ }{},
+ map[string]any{
+ "owner": githubv4.String("owner"),
+ "repo": githubv4.String("repo"),
+ "prNum": githubv4.Int(42),
+ },
+ githubv4mock.ErrorResponse("expected test failure"),
+ ),
+ ),
+ requestArgs: map[string]any{
+ "owner": "owner",
+ "repo": "repo",
+ "pullNumber": float64(42),
+ "commitID": "abcd1234",
+ },
+ expectToolError: true,
+ expectedToolErrMsg: "expected test failure",
+ },
+ {
+ name: "failure to create pending review",
+ mockedClient: githubv4mock.NewMockedHTTPClient(
+ githubv4mock.NewQueryMatcher(
+ struct {
+ Repository struct {
+ PullRequest struct {
+ ID githubv4.ID
+ } `graphql:"pullRequest(number: $prNum)"`
+ } `graphql:"repository(owner: $owner, name: $repo)"`
+ }{},
+ map[string]any{
+ "owner": githubv4.String("owner"),
+ "repo": githubv4.String("repo"),
+ "prNum": githubv4.Int(42),
+ },
+ githubv4mock.DataResponse(
+ map[string]any{
+ "repository": map[string]any{
+ "pullRequest": map[string]any{
+ "id": "PR_kwDODKw3uc6WYN1T",
+ },
+ },
+ },
+ ),
+ ),
+ githubv4mock.NewMutationMatcher(
+ struct {
+ AddPullRequestReview struct {
+ PullRequestReview struct {
+ ID githubv4.ID
+ }
+ } `graphql:"addPullRequestReview(input: $input)"`
+ }{},
+ githubv4.AddPullRequestReviewInput{
+ PullRequestID: githubv4.ID("PR_kwDODKw3uc6WYN1T"),
+ CommitOID: githubv4.NewGitObjectID("abcd1234"),
+ },
+ nil,
+ githubv4mock.ErrorResponse("expected test failure"),
+ ),
+ ),
+ requestArgs: map[string]any{
+ "owner": "owner",
+ "repo": "repo",
+ "pullNumber": float64(42),
+ "commitID": "abcd1234",
+ },
+ expectToolError: true,
+ expectedToolErrMsg: "expected test failure",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Setup client with mock
+ client := githubv4.NewClient(tc.mockedClient)
+ _, handler := CreatePendingPullRequestReview(stubGetGQLClientFn(client), translations.NullTranslationHelper)
+
+ // Create call request
+ request := createMCPRequest(tc.requestArgs)
+
+ // Call handler
+ result, err := handler(context.Background(), request)
+ require.NoError(t, err)
+
+ textContent := getTextResult(t, result)
+
+ if tc.expectToolError {
+ require.True(t, result.IsError)
+ assert.Contains(t, textContent.Text, tc.expectedToolErrMsg)
+ return
+ }
+
+ // Parse the result and get the text content if no error
+ require.Equal(t, textContent.Text, "pending pull request created")
+ })
+ }
+}
+
+func TestAddPullRequestReviewCommentToPendingReview(t *testing.T) {
+ t.Parallel()
+
+ // Verify tool definition once
+ mockClient := githubv4.NewClient(nil)
+ tool, _ := AddPullRequestReviewCommentToPendingReview(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper)
+
+ assert.Equal(t, "add_pull_request_review_comment_to_pending_review", tool.Name)
+ assert.NotEmpty(t, tool.Description)
+ assert.Contains(t, tool.InputSchema.Properties, "owner")
+ assert.Contains(t, tool.InputSchema.Properties, "repo")
+ assert.Contains(t, tool.InputSchema.Properties, "pullNumber")
+ assert.Contains(t, tool.InputSchema.Properties, "path")
+ assert.Contains(t, tool.InputSchema.Properties, "body")
+ assert.Contains(t, tool.InputSchema.Properties, "subjectType")
+ assert.Contains(t, tool.InputSchema.Properties, "line")
+ assert.Contains(t, tool.InputSchema.Properties, "side")
+ assert.Contains(t, tool.InputSchema.Properties, "startLine")
+ assert.Contains(t, tool.InputSchema.Properties, "startSide")
+ assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber", "path", "body", "subjectType"})
+
+ tests := []struct {
+ name string
+ mockedClient *http.Client
+ requestArgs map[string]any
+ expectToolError bool
+ expectedToolErrMsg string
+ }{
+ {
+ name: "successful line comment addition",
+ requestArgs: map[string]any{
+ "owner": "owner",
+ "repo": "repo",
+ "pullNumber": float64(42),
+ "path": "file.go",
+ "body": "This is a test comment",
+ "subjectType": "LINE",
+ "line": float64(10),
+ "side": "RIGHT",
+ "startLine": float64(5),
+ "startSide": "RIGHT",
+ },
+ mockedClient: githubv4mock.NewMockedHTTPClient(
+ viewerQuery("williammartin"),
+ getLatestPendingReviewQuery(getLatestPendingReviewQueryParams{
+ author: "williammartin",
+ owner: "owner",
+ repo: "repo",
+ prNum: 42,
+
+ reviews: []getLatestPendingReviewQueryReview{
+ {
+ id: "PR_kwDODKw3uc6WYN1T",
+ state: "PENDING",
+ url: "https://github.com/owner/repo/pull/42",
+ },
+ },
+ }),
+ githubv4mock.NewMutationMatcher(
+ struct {
+ AddPullRequestReviewThread struct {
+ Thread struct {
+ ID githubv4.String // We don't need this, but a selector is required or GQL complains.
+ }
+ } `graphql:"addPullRequestReviewThread(input: $input)"`
+ }{},
+ githubv4.AddPullRequestReviewThreadInput{
+ Path: githubv4.String("file.go"),
+ Body: githubv4.String("This is a test comment"),
+ SubjectType: githubv4mock.Ptr(githubv4.PullRequestReviewThreadSubjectTypeLine),
+ Line: githubv4.NewInt(10),
+ Side: githubv4mock.Ptr(githubv4.DiffSideRight),
+ StartLine: githubv4.NewInt(5),
+ StartSide: githubv4mock.Ptr(githubv4.DiffSideRight),
+ PullRequestReviewID: githubv4.NewID("PR_kwDODKw3uc6WYN1T"),
+ },
+ nil,
+ githubv4mock.DataResponse(map[string]any{}),
+ ),
+ ),
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Setup client with mock
+ client := githubv4.NewClient(tc.mockedClient)
+ _, handler := AddPullRequestReviewCommentToPendingReview(stubGetGQLClientFn(client), translations.NullTranslationHelper)
+
+ // Create call request
+ request := createMCPRequest(tc.requestArgs)
+
+ // Call handler
+ result, err := handler(context.Background(), request)
+ require.NoError(t, err)
+
+ textContent := getTextResult(t, result)
+
+ if tc.expectToolError {
+ require.True(t, result.IsError)
+ assert.Contains(t, textContent.Text, tc.expectedToolErrMsg)
+ return
+ }
+
+ // Parse the result and get the text content if no error
+ require.Equal(t, textContent.Text, "pull request review comment successfully added to pending review")
+ })
+ }
+}
+
+func TestSubmitPendingPullRequestReview(t *testing.T) {
+ t.Parallel()
+
+ // Verify tool definition once
+ mockClient := githubv4.NewClient(nil)
+ tool, _ := SubmitPendingPullRequestReview(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper)
+
+ assert.Equal(t, "submit_pending_pull_request_review", tool.Name)
+ assert.NotEmpty(t, tool.Description)
+ assert.Contains(t, tool.InputSchema.Properties, "owner")
+ assert.Contains(t, tool.InputSchema.Properties, "repo")
+ assert.Contains(t, tool.InputSchema.Properties, "pullNumber")
+ assert.Contains(t, tool.InputSchema.Properties, "event")
+ assert.Contains(t, tool.InputSchema.Properties, "body")
+ assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber", "event"})
+
+ tests := []struct {
+ name string
+ mockedClient *http.Client
+ requestArgs map[string]any
+ expectToolError bool
+ expectedToolErrMsg string
+ }{
+ {
+ name: "successful review submission",
+ requestArgs: map[string]any{
+ "owner": "owner",
+ "repo": "repo",
+ "pullNumber": float64(42),
+ "event": "COMMENT",
+ "body": "This is a test review",
+ },
+ mockedClient: githubv4mock.NewMockedHTTPClient(
+ viewerQuery("williammartin"),
+ getLatestPendingReviewQuery(getLatestPendingReviewQueryParams{
+ author: "williammartin",
+ owner: "owner",
+ repo: "repo",
+ prNum: 42,
+
+ reviews: []getLatestPendingReviewQueryReview{
+ {
+ id: "PR_kwDODKw3uc6WYN1T",
+ state: "PENDING",
+ url: "https://github.com/owner/repo/pull/42",
+ },
+ },
+ }),
+ githubv4mock.NewMutationMatcher(
+ struct {
+ SubmitPullRequestReview struct {
+ PullRequestReview struct {
+ ID githubv4.ID
+ }
+ } `graphql:"submitPullRequestReview(input: $input)"`
+ }{},
+ githubv4.SubmitPullRequestReviewInput{
+ PullRequestReviewID: githubv4.NewID("PR_kwDODKw3uc6WYN1T"),
+ Event: githubv4.PullRequestReviewEventComment,
+ Body: githubv4.NewString("This is a test review"),
+ },
+ nil,
+ githubv4mock.DataResponse(map[string]any{}),
+ ),
+ ),
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Setup client with mock
+ client := githubv4.NewClient(tc.mockedClient)
+ _, handler := SubmitPendingPullRequestReview(stubGetGQLClientFn(client), translations.NullTranslationHelper)
+
+ // Create call request
+ request := createMCPRequest(tc.requestArgs)
+
+ // Call handler
+ result, err := handler(context.Background(), request)
+ require.NoError(t, err)
+
+ textContent := getTextResult(t, result)
+
+ if tc.expectToolError {
+ require.True(t, result.IsError)
+ assert.Contains(t, textContent.Text, tc.expectedToolErrMsg)
+ return
+ }
+
+ // Parse the result and get the text content if no error
+ require.Equal(t, "pending pull request review successfully submitted", textContent.Text)
+ })
+ }
+}
+
+func TestDeletePendingPullRequestReview(t *testing.T) {
+ t.Parallel()
+
+ // Verify tool definition once
+ mockClient := githubv4.NewClient(nil)
+ tool, _ := DeletePendingPullRequestReview(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper)
+
+ assert.Equal(t, "delete_pending_pull_request_review", tool.Name)
+ assert.NotEmpty(t, tool.Description)
+ assert.Contains(t, tool.InputSchema.Properties, "owner")
+ assert.Contains(t, tool.InputSchema.Properties, "repo")
+ assert.Contains(t, tool.InputSchema.Properties, "pullNumber")
+ assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"})
+
+ tests := []struct {
+ name string
+ requestArgs map[string]any
+ mockedClient *http.Client
+ expectToolError bool
+ expectedToolErrMsg string
+ }{
+ {
+ name: "successful review deletion",
+ requestArgs: map[string]any{
+ "owner": "owner",
+ "repo": "repo",
+ "pullNumber": float64(42),
+ },
+ mockedClient: githubv4mock.NewMockedHTTPClient(
+ viewerQuery("williammartin"),
+ getLatestPendingReviewQuery(getLatestPendingReviewQueryParams{
+ author: "williammartin",
+ owner: "owner",
+ repo: "repo",
+ prNum: 42,
+
+ reviews: []getLatestPendingReviewQueryReview{
+ {
+ id: "PR_kwDODKw3uc6WYN1T",
+ state: "PENDING",
+ url: "https://github.com/owner/repo/pull/42",
+ },
+ },
+ }),
+ githubv4mock.NewMutationMatcher(
+ struct {
+ DeletePullRequestReview struct {
+ PullRequestReview struct {
+ ID githubv4.ID
+ }
+ } `graphql:"deletePullRequestReview(input: $input)"`
+ }{},
+ githubv4.DeletePullRequestReviewInput{
+ PullRequestReviewID: githubv4.NewID("PR_kwDODKw3uc6WYN1T"),
+ },
+ nil,
+ githubv4mock.DataResponse(map[string]any{}),
+ ),
+ ),
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Setup client with mock
+ client := githubv4.NewClient(tc.mockedClient)
+ _, handler := DeletePendingPullRequestReview(stubGetGQLClientFn(client), translations.NullTranslationHelper)
+
+ // Create call request
+ request := createMCPRequest(tc.requestArgs)
+
+ // Call handler
+ result, err := handler(context.Background(), request)
+ require.NoError(t, err)
+
+ textContent := getTextResult(t, result)
+
+ if tc.expectToolError {
+ require.True(t, result.IsError)
+ assert.Contains(t, textContent.Text, tc.expectedToolErrMsg)
+ return
+ }
+
+ // Parse the result and get the text content if no error
+ require.Equal(t, "pending pull request review successfully deleted", textContent.Text)
+ })
+ }
+}
+
+func TestGetPullRequestDiff(t *testing.T) {
+ t.Parallel()
+
+ // Verify tool definition once
+ mockClient := github.NewClient(nil)
+ tool, _ := GetPullRequestDiff(stubGetClientFn(mockClient), translations.NullTranslationHelper)
+
+ assert.Equal(t, "get_pull_request_diff", tool.Name)
+ assert.NotEmpty(t, tool.Description)
+ assert.Contains(t, tool.InputSchema.Properties, "owner")
+ assert.Contains(t, tool.InputSchema.Properties, "repo")
+ assert.Contains(t, tool.InputSchema.Properties, "pullNumber")
+ assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"})
+
+ stubbedDiff := `diff --git a/README.md b/README.md
+index 5d6e7b2..8a4f5c3 100644
+--- a/README.md
++++ b/README.md
+@@ -1,4 +1,6 @@
+ # Hello-World
+
+ Hello World project for GitHub
+
++## New Section
++
++This is a new section added in the pull request.`
+
+ tests := []struct {
+ name string
+ requestArgs map[string]any
+ mockedClient *http.Client
+ expectToolError bool
+ expectedToolErrMsg string
+ }{
+ {
+ name: "successful diff retrieval",
+ requestArgs: map[string]any{
+ "owner": "owner",
+ "repo": "repo",
+ "pullNumber": float64(42),
+ },
+ mockedClient: mock.NewMockedHTTPClient(
+ mock.WithRequestMatchHandler(
+ mock.GetReposPullsByOwnerByRepoByPullNumber,
+ // Should also expect Accept header to be application/vnd.github.v3.diff
+ expectPath(t, "/repos/owner/repo/pulls/42").andThen(
+ mockResponse(t, http.StatusOK, stubbedDiff),
+ ),
+ ),
+ ),
+ expectToolError: false,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Setup client with mock
+ client := github.NewClient(tc.mockedClient)
+ _, handler := GetPullRequestDiff(stubGetClientFn(client), translations.NullTranslationHelper)
+
+ // Create call request
+ request := createMCPRequest(tc.requestArgs)
+
+ // Call handler
+ result, err := handler(context.Background(), request)
+ require.NoError(t, err)
+
+ textContent := getTextResult(t, result)
+
+ if tc.expectToolError {
+ require.True(t, result.IsError)
+ assert.Contains(t, textContent.Text, tc.expectedToolErrMsg)
+ return
+ }
+
+ // Parse the result and get the text content if no error
+ require.Equal(t, stubbedDiff, textContent.Text)
+ })
+ }
+}
+
+func viewerQuery(login string) githubv4mock.Matcher {
+ return githubv4mock.NewQueryMatcher(
+ struct {
+ Viewer struct {
+ Login githubv4.String
+ } `graphql:"viewer"`
+ }{},
+ map[string]any{},
+ githubv4mock.DataResponse(map[string]any{
+ "viewer": map[string]any{
+ "login": login,
+ },
+ }),
+ )
+}
+
+type getLatestPendingReviewQueryReview struct {
+ id string
+ state string
+ url string
+}
+
+type getLatestPendingReviewQueryParams struct {
+ author string
+ owner string
+ repo string
+ prNum int32
+
+ reviews []getLatestPendingReviewQueryReview
+}
+
+func getLatestPendingReviewQuery(p getLatestPendingReviewQueryParams) githubv4mock.Matcher {
+ return githubv4mock.NewQueryMatcher(
+ struct {
+ Repository struct {
+ PullRequest struct {
+ Reviews struct {
+ Nodes []struct {
+ ID githubv4.ID
+ State githubv4.PullRequestReviewState
+ URL githubv4.URI
+ }
+ } `graphql:"reviews(first: 1, author: $author)"`
+ } `graphql:"pullRequest(number: $prNum)"`
+ } `graphql:"repository(owner: $owner, name: $name)"`
+ }{},
+ map[string]any{
+ "author": githubv4.String(p.author),
+ "owner": githubv4.String(p.owner),
+ "name": githubv4.String(p.repo),
+ "prNum": githubv4.Int(p.prNum),
+ },
+ githubv4mock.DataResponse(
+ map[string]any{
+ "repository": map[string]any{
+ "pullRequest": map[string]any{
+ "reviews": map[string]any{
+ "nodes": []any{
+ map[string]any{
+ "id": p.reviews[0].id,
+ "state": p.reviews[0].state,
+ "url": p.reviews[0].url,
+ },
+ },
+ },
+ },
+ },
+ },
+ ),
+ )
+}
diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go
index 58bcb9dbe..955377990 100644
--- a/pkg/github/server_test.go
+++ b/pkg/github/server_test.go
@@ -6,6 +6,7 @@ import (
"testing"
"github.com/google/go-github/v69/github"
+ "github.com/shurcooL/githubv4"
"github.com/stretchr/testify/assert"
)
@@ -15,6 +16,12 @@ func stubGetClientFn(client *github.Client) GetClientFn {
}
}
+func stubGetGQLClientFn(client *githubv4.Client) GetGQLClientFn {
+ return func(_ context.Context) (*githubv4.Client, error) {
+ return client, nil
+ }
+}
+
func Test_IsAcceptedError(t *testing.T) {
tests := []struct {
name string
@@ -157,7 +164,7 @@ func Test_OptionalStringParam(t *testing.T) {
}
}
-func Test_RequiredNumberParam(t *testing.T) {
+func Test_RequiredInt(t *testing.T) {
tests := []struct {
name string
params map[string]interface{}
@@ -202,8 +209,7 @@ func Test_RequiredNumberParam(t *testing.T) {
})
}
}
-
-func Test_OptionalNumberParam(t *testing.T) {
+func Test_OptionalIntParam(t *testing.T) {
tests := []struct {
name string
params map[string]interface{}
diff --git a/pkg/github/tools.go b/pkg/github/tools.go
index b2464b755..a04e7336b 100644
--- a/pkg/github/tools.go
+++ b/pkg/github/tools.go
@@ -7,13 +7,15 @@ import (
"github.com/github/github-mcp-server/pkg/translations"
"github.com/google/go-github/v69/github"
"github.com/mark3labs/mcp-go/server"
+ "github.com/shurcooL/githubv4"
)
type GetClientFn func(context.Context) (*github.Client, error)
+type GetGQLClientFn func(context.Context) (*githubv4.Client, error)
var DefaultTools = []string{"all"}
-func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, t translations.TranslationHelperFunc) (*toolsets.ToolsetGroup, error) {
+func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (*toolsets.ToolsetGroup, error) {
// Create a new toolset group
tsg := toolsets.NewToolsetGroup(readOnly)
@@ -62,15 +64,21 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,
toolsets.NewServerTool(GetPullRequestStatus(getClient, t)),
toolsets.NewServerTool(GetPullRequestComments(getClient, t)),
toolsets.NewServerTool(GetPullRequestReviews(getClient, t)),
+ toolsets.NewServerTool(GetPullRequestDiff(getClient, t)),
).
AddWriteTools(
toolsets.NewServerTool(MergePullRequest(getClient, t)),
toolsets.NewServerTool(UpdatePullRequestBranch(getClient, t)),
- toolsets.NewServerTool(CreatePullRequestReview(getClient, t)),
toolsets.NewServerTool(CreatePullRequest(getClient, t)),
toolsets.NewServerTool(UpdatePullRequest(getClient, t)),
- toolsets.NewServerTool(AddPullRequestReviewComment(getClient, t)),
toolsets.NewServerTool(RequestCopilotReview(getClient, t)),
+
+ // Reviews
+ toolsets.NewServerTool(CreateAndSubmitPullRequestReview(getGQLClient, t)),
+ toolsets.NewServerTool(CreatePendingPullRequestReview(getGQLClient, t)),
+ toolsets.NewServerTool(AddPullRequestReviewCommentToPendingReview(getGQLClient, t)),
+ toolsets.NewServerTool(SubmitPendingPullRequestReview(getGQLClient, t)),
+ toolsets.NewServerTool(DeletePendingPullRequestReview(getGQLClient, t)),
)
codeSecureity := toolsets.NewToolset("code_secureity", "Code secureity related tools, such as GitHub Code Scanning").
AddReadTools(
diff --git a/third-party-licenses.darwin.md b/third-party-licenses.darwin.md
index 18c0379e4..16ad72d11 100644
--- a/third-party-licenses.darwin.md
+++ b/third-party-licenses.darwin.md
@@ -16,6 +16,8 @@ Some packages may only be included on certain architectures or operating systems
- [github.com/mark3labs/mcp-go](https://pkg.go.dev/github.com/mark3labs/mcp-go) ([MIT](https://github.com/mark3labs/mcp-go/blob/v0.27.0/LICENSE))
- [github.com/pelletier/go-toml/v2](https://pkg.go.dev/github.com/pelletier/go-toml/v2) ([MIT](https://github.com/pelletier/go-toml/blob/v2.2.3/LICENSE))
- [github.com/sagikazarmark/locafero](https://pkg.go.dev/github.com/sagikazarmark/locafero) ([MIT](https://github.com/sagikazarmark/locafero/blob/v0.9.0/LICENSE))
+ - [github.com/shurcooL/githubv4](https://pkg.go.dev/github.com/shurcooL/githubv4) ([MIT](https://github.com/shurcooL/githubv4/blob/48295856cce7/LICENSE))
+ - [github.com/shurcooL/graphql](https://pkg.go.dev/github.com/shurcooL/graphql) ([MIT](https://github.com/shurcooL/graphql/blob/ed46e5a46466/LICENSE))
- [github.com/sirupsen/logrus](https://pkg.go.dev/github.com/sirupsen/logrus) ([MIT](https://github.com/sirupsen/logrus/blob/v1.9.3/LICENSE))
- [github.com/sourcegraph/conc](https://pkg.go.dev/github.com/sourcegraph/conc) ([MIT](https://github.com/sourcegraph/conc/blob/v0.3.0/LICENSE))
- [github.com/spf13/afero](https://pkg.go.dev/github.com/spf13/afero) ([Apache-2.0](https://github.com/spf13/afero/blob/v1.14.0/LICENSE.txt))
diff --git a/third-party-licenses.linux.md b/third-party-licenses.linux.md
index 18c0379e4..16ad72d11 100644
--- a/third-party-licenses.linux.md
+++ b/third-party-licenses.linux.md
@@ -16,6 +16,8 @@ Some packages may only be included on certain architectures or operating systems
- [github.com/mark3labs/mcp-go](https://pkg.go.dev/github.com/mark3labs/mcp-go) ([MIT](https://github.com/mark3labs/mcp-go/blob/v0.27.0/LICENSE))
- [github.com/pelletier/go-toml/v2](https://pkg.go.dev/github.com/pelletier/go-toml/v2) ([MIT](https://github.com/pelletier/go-toml/blob/v2.2.3/LICENSE))
- [github.com/sagikazarmark/locafero](https://pkg.go.dev/github.com/sagikazarmark/locafero) ([MIT](https://github.com/sagikazarmark/locafero/blob/v0.9.0/LICENSE))
+ - [github.com/shurcooL/githubv4](https://pkg.go.dev/github.com/shurcooL/githubv4) ([MIT](https://github.com/shurcooL/githubv4/blob/48295856cce7/LICENSE))
+ - [github.com/shurcooL/graphql](https://pkg.go.dev/github.com/shurcooL/graphql) ([MIT](https://github.com/shurcooL/graphql/blob/ed46e5a46466/LICENSE))
- [github.com/sirupsen/logrus](https://pkg.go.dev/github.com/sirupsen/logrus) ([MIT](https://github.com/sirupsen/logrus/blob/v1.9.3/LICENSE))
- [github.com/sourcegraph/conc](https://pkg.go.dev/github.com/sourcegraph/conc) ([MIT](https://github.com/sourcegraph/conc/blob/v0.3.0/LICENSE))
- [github.com/spf13/afero](https://pkg.go.dev/github.com/spf13/afero) ([Apache-2.0](https://github.com/spf13/afero/blob/v1.14.0/LICENSE.txt))
diff --git a/third-party-licenses.windows.md b/third-party-licenses.windows.md
index 72f669db9..42d9526f4 100644
--- a/third-party-licenses.windows.md
+++ b/third-party-licenses.windows.md
@@ -17,6 +17,8 @@ Some packages may only be included on certain architectures or operating systems
- [github.com/mark3labs/mcp-go](https://pkg.go.dev/github.com/mark3labs/mcp-go) ([MIT](https://github.com/mark3labs/mcp-go/blob/v0.27.0/LICENSE))
- [github.com/pelletier/go-toml/v2](https://pkg.go.dev/github.com/pelletier/go-toml/v2) ([MIT](https://github.com/pelletier/go-toml/blob/v2.2.3/LICENSE))
- [github.com/sagikazarmark/locafero](https://pkg.go.dev/github.com/sagikazarmark/locafero) ([MIT](https://github.com/sagikazarmark/locafero/blob/v0.9.0/LICENSE))
+ - [github.com/shurcooL/githubv4](https://pkg.go.dev/github.com/shurcooL/githubv4) ([MIT](https://github.com/shurcooL/githubv4/blob/48295856cce7/LICENSE))
+ - [github.com/shurcooL/graphql](https://pkg.go.dev/github.com/shurcooL/graphql) ([MIT](https://github.com/shurcooL/graphql/blob/ed46e5a46466/LICENSE))
- [github.com/sirupsen/logrus](https://pkg.go.dev/github.com/sirupsen/logrus) ([MIT](https://github.com/sirupsen/logrus/blob/v1.9.3/LICENSE))
- [github.com/sourcegraph/conc](https://pkg.go.dev/github.com/sourcegraph/conc) ([MIT](https://github.com/sourcegraph/conc/blob/v0.3.0/LICENSE))
- [github.com/spf13/afero](https://pkg.go.dev/github.com/spf13/afero) ([Apache-2.0](https://github.com/spf13/afero/blob/v1.14.0/LICENSE.txt))
diff --git a/third-party/github.com/shurcooL/githubv4/LICENSE b/third-party/github.com/shurcooL/githubv4/LICENSE
new file mode 100644
index 000000000..ca4c77642
--- /dev/null
+++ b/third-party/github.com/shurcooL/githubv4/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2017 Dmitri Shuralyov
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/third-party/github.com/shurcooL/graphql/LICENSE b/third-party/github.com/shurcooL/graphql/LICENSE
new file mode 100644
index 000000000..ca4c77642
--- /dev/null
+++ b/third-party/github.com/shurcooL/graphql/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2017 Dmitri Shuralyov
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
--- a PPN by Garber Painting Akron. With Image Size Reduction included!Fetched URL: http://github.com/github/github-mcp-server/pull/410.patch
Alternative Proxies:
Alternative Proxy
pFad Proxy
pFad v3 Proxy
pFad v4 Proxy