From 3dddc2abcd67a95916edbcab572768ff1079afbb Mon Sep 17 00:00:00 2001 From: Javier Uruen Val Date: Sun, 16 Mar 2025 20:18:47 +0100 Subject: [PATCH] add initial tests --- go.mod | 9 +- go.sum | 12 +- pkg/github/code_scanning.go | 5 +- pkg/github/code_scanning_test.go | 230 +++++++ pkg/github/helper_test.go | 49 ++ pkg/github/issues.go | 7 +- pkg/github/issues_test.go | 371 ++++++++++++ pkg/github/pullrequests.go | 24 +- pkg/github/pullrequests_test.go | 990 +++++++++++++++++++++++++++++++ pkg/github/repositories.go | 10 +- pkg/github/repositories_test.go | 909 ++++++++++++++++++++++++++++ pkg/github/search_test.go | 429 ++++++++++++++ pkg/github/server.go | 10 +- pkg/github/server_test.go | 168 ++++++ 14 files changed, 3203 insertions(+), 20 deletions(-) create mode 100644 pkg/github/code_scanning_test.go create mode 100644 pkg/github/helper_test.go create mode 100644 pkg/github/issues_test.go create mode 100644 pkg/github/pullrequests_test.go create mode 100644 pkg/github/repositories_test.go create mode 100644 pkg/github/search_test.go create mode 100644 pkg/github/server_test.go diff --git a/go.mod b/go.mod index e53b8b6b1..4338a69d3 100644 --- a/go.mod +++ b/go.mod @@ -6,21 +6,27 @@ require ( github.com/aws/smithy-go v1.22.3 github.com/google/go-github/v69 v69.2.0 github.com/mark3labs/mcp-go v0.11.2 + github.com/migueleliasweb/go-github-mock v1.1.0 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.19.0 + github.com/stretchr/testify v1.9.0 golang.org/x/exp v0.0.0-20230905200255-921286631fa9 ) require ( + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/google/go-github/v64 v64.0.0 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/gorilla/mux v1.8.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect @@ -31,7 +37,8 @@ require ( go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/sys v0.18.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/text v0.19.0 // indirect + golang.org/x/time v0.5.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 42c7171f6..5aa0482da 100644 --- a/go.sum +++ b/go.sum @@ -12,12 +12,16 @@ github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyT github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-github/v64 v64.0.0 h1:4G61sozmY3eiPAjjoOHponXDBONm+utovTKbyUb2Qdg= +github.com/google/go-github/v64 v64.0.0/go.mod h1:xB3vqMQNdHzilXBiO2I+M7iEFtHf+DP/omBOv6tQzVo= github.com/google/go-github/v69 v69.2.0 h1:wR+Wi/fN2zdUx9YxSmYE0ktiX9IAR/BeePzeaUUbEHE= github.com/google/go-github/v69 v69.2.0/go.mod h1:xne4jymxLR6Uj9b7J7PyTpkMYstEMMwGZa0Aehh1azM= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -30,6 +34,8 @@ github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0V github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mark3labs/mcp-go v0.11.2 h1:mCxWFUTrcXOtJIn9t7F8bxAL8rpE/ZZTTnx3PU/VNdA= github.com/mark3labs/mcp-go v0.11.2/go.mod h1:cjMlBU0cv/cj9kjlgmRhoJ5JREdS7YX83xeIG9Ko/jE= +github.com/migueleliasweb/go-github-mock v1.1.0 h1:GKaOBPsrPGkAKgtfuWY8MclS1xR6MInkx1SexJucMwE= +github.com/migueleliasweb/go-github-mock v1.1.0/go.mod h1:pYe/XlGs4BGMfRY4vmeixVsODHnVDDhJ9zoi0qzSMHc= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= @@ -80,8 +86,10 @@ golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqR golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index da7147443..0d9547ebc 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "net/http" "github.com/google/go-github/v69/github" "github.com/mark3labs/mcp-go/mcp" @@ -38,7 +39,7 @@ func getCodeScanningAlert(client *github.Client) (tool mcp.Tool, handler server. } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -90,7 +91,7 @@ func listCodeScanningAlerts(client *github.Client) (tool mcp.Tool, handler serve } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) diff --git a/pkg/github/code_scanning_test.go b/pkg/github/code_scanning_test.go new file mode 100644 index 000000000..149c8b039 --- /dev/null +++ b/pkg/github/code_scanning_test.go @@ -0,0 +1,230 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/google/go-github/v69/github" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_GetCodeScanningAlert(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := getCodeScanningAlert(mockClient) + + assert.Equal(t, "get_code_scanning_alert", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "alert_number") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "alert_number"}) + + // Setup mock alert for success case + mockAlert := &github.Alert{ + Number: github.Ptr(42), + State: github.Ptr("open"), + Rule: &github.Rule{ID: github.Ptr("test-rule"), Description: github.Ptr("Test Rule Description")}, + HTMLURL: github.Ptr("https://github.com/owner/repo/security/code-scanning/42"), + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedAlert *github.Alert + expectedErrMsg string + }{ + { + name: "successful alert fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposCodeScanningAlertsByOwnerByRepoByAlertNumber, + mockAlert, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "alert_number": float64(42), + }, + expectError: false, + expectedAlert: mockAlert, + }, + { + name: "alert fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposCodeScanningAlertsByOwnerByRepoByAlertNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "alert_number": float64(9999), + }, + expectError: true, + expectedErrMsg: "failed to get alert", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getCodeScanningAlert(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedAlert github.Alert + err = json.Unmarshal([]byte(textContent.Text), &returnedAlert) + assert.NoError(t, err) + assert.Equal(t, *tc.expectedAlert.Number, *returnedAlert.Number) + assert.Equal(t, *tc.expectedAlert.State, *returnedAlert.State) + assert.Equal(t, *tc.expectedAlert.Rule.ID, *returnedAlert.Rule.ID) + assert.Equal(t, *tc.expectedAlert.HTMLURL, *returnedAlert.HTMLURL) + + }) + } +} + +func Test_ListCodeScanningAlerts(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := listCodeScanningAlerts(mockClient) + + assert.Equal(t, "list_code_scanning_alerts", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "ref") + assert.Contains(t, tool.InputSchema.Properties, "state") + assert.Contains(t, tool.InputSchema.Properties, "severity") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + + // Setup mock alerts for success case + mockAlerts := []*github.Alert{ + { + Number: github.Ptr(42), + State: github.Ptr("open"), + Rule: &github.Rule{ID: github.Ptr("test-rule-1"), Description: github.Ptr("Test Rule 1")}, + HTMLURL: github.Ptr("https://github.com/owner/repo/security/code-scanning/42"), + }, + { + Number: github.Ptr(43), + State: github.Ptr("fixed"), + Rule: &github.Rule{ID: github.Ptr("test-rule-2"), Description: github.Ptr("Test Rule 2")}, + HTMLURL: github.Ptr("https://github.com/owner/repo/security/code-scanning/43"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedAlerts []*github.Alert + expectedErrMsg string + }{ + { + name: "successful alerts listing", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposCodeScanningAlertsByOwnerByRepo, + mockAlerts, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "ref": "main", + "state": "open", + "severity": "high", + }, + expectError: false, + expectedAlerts: mockAlerts, + }, + { + name: "alerts listing fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposCodeScanningAlertsByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message": "Unauthorized access"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + }, + expectError: true, + expectedErrMsg: "failed to list alerts", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := listCodeScanningAlerts(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedAlerts []*github.Alert + err = json.Unmarshal([]byte(textContent.Text), &returnedAlerts) + assert.NoError(t, err) + assert.Len(t, returnedAlerts, len(tc.expectedAlerts)) + for i, alert := range returnedAlerts { + assert.Equal(t, *tc.expectedAlerts[i].Number, *alert.Number) + assert.Equal(t, *tc.expectedAlerts[i].State, *alert.State) + assert.Equal(t, *tc.expectedAlerts[i].Rule.ID, *alert.Rule.ID) + assert.Equal(t, *tc.expectedAlerts[i].HTMLURL, *alert.HTMLURL) + } + }) + } +} diff --git a/pkg/github/helper_test.go b/pkg/github/helper_test.go new file mode 100644 index 000000000..5e71f4187 --- /dev/null +++ b/pkg/github/helper_test.go @@ -0,0 +1,49 @@ +package github + +import ( + "encoding/json" + "github.com/stretchr/testify/assert" + "net/http" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" +) + +// mockResponse is a helper function to create a mock HTTP response handler +// that returns a specified status code and marshalled body. +func mockResponse(t *testing.T, code int, body interface{}) http.HandlerFunc { + t.Helper() + return func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(code) + b, err := json.Marshal(body) + require.NoError(t, err) + _, _ = w.Write(b) + } +} + +// createMCPRequest is a helper function to create a MCP request with the given arguments. +func createMCPRequest(args map[string]interface{}) mcp.CallToolRequest { + return mcp.CallToolRequest{ + Params: struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments,omitempty"` + Meta *struct { + ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"` + } `json:"_meta,omitempty"` + }{ + Arguments: args, + }, + } +} + +// getTextResult is a helper function that returns a text result from a tool call. +func getTextResult(t *testing.T, result *mcp.CallToolResult) mcp.TextContent { + t.Helper() + assert.NotNil(t, result) + require.Len(t, result.Content, 1) + require.IsType(t, mcp.TextContent{}, result.Content[0]) + textContent := result.Content[0].(mcp.TextContent) + assert.Equal(t, "text", textContent.Type) + return textContent +} diff --git a/pkg/github/issues.go b/pkg/github/issues.go index c7c172892..6a43e59d5 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "net/http" "github.com/google/go-github/v69/github" "github.com/mark3labs/mcp-go/mcp" @@ -39,7 +40,7 @@ func getIssue(client *github.Client) (tool mcp.Tool, handler server.ToolHandlerF } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -93,7 +94,7 @@ func addIssueComment(client *github.Client) (tool mcp.Tool, handler server.ToolH } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 201 { + if resp.StatusCode != http.StatusCreated { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -165,7 +166,7 @@ func searchIssues(client *github.Client) (tool mcp.Tool, handler server.ToolHand } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go new file mode 100644 index 000000000..7e9944b35 --- /dev/null +++ b/pkg/github/issues_test.go @@ -0,0 +1,371 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_GetIssue(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := getIssue(mockClient) + + assert.Equal(t, "get_issue", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "issue_number") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "issue_number"}) + + // Setup mock issue for success case + mockIssue := &github.Issue{ + Number: github.Ptr(42), + Title: github.Ptr("Test Issue"), + Body: github.Ptr("This is a test issue"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/issues/42"), + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedIssue *github.Issue + expectedErrMsg string + }{ + { + name: "successful issue retrieval", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposIssuesByOwnerByRepoByIssueNumber, + mockIssue, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "issue_number": float64(42), + }, + expectError: false, + expectedIssue: mockIssue, + }, + { + name: "issue not found", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposIssuesByOwnerByRepoByIssueNumber, + mockResponse(t, http.StatusNotFound, `{"message": "Issue not found"}`), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "issue_number": float64(999), + }, + expectError: true, + expectedErrMsg: "failed to get issue", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getIssue(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedIssue github.Issue + err = json.Unmarshal([]byte(textContent.Text), &returnedIssue) + require.NoError(t, err) + assert.Equal(t, *tc.expectedIssue.Number, *returnedIssue.Number) + assert.Equal(t, *tc.expectedIssue.Title, *returnedIssue.Title) + assert.Equal(t, *tc.expectedIssue.Body, *returnedIssue.Body) + }) + } +} + +func Test_AddIssueComment(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := addIssueComment(mockClient) + + assert.Equal(t, "add_issue_comment", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "issue_number") + assert.Contains(t, tool.InputSchema.Properties, "body") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "issue_number", "body"}) + + // Setup mock comment for success case + mockComment := &github.IssueComment{ + ID: github.Ptr(int64(123)), + Body: github.Ptr("This is a test comment"), + User: &github.User{ + Login: github.Ptr("testuser"), + }, + HTMLURL: github.Ptr("https://github.com/owner/repo/issues/42#issuecomment-123"), + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedComment *github.IssueComment + expectedErrMsg string + }{ + { + name: "successful comment creation", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostReposIssuesCommentsByOwnerByRepoByIssueNumber, + mockResponse(t, http.StatusCreated, mockComment), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "issue_number": float64(42), + "body": "This is a test comment", + }, + expectError: false, + expectedComment: mockComment, + }, + { + name: "comment creation fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostReposIssuesCommentsByOwnerByRepoByIssueNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message": "Invalid request"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "issue_number": float64(42), + "body": "", + }, + expectError: true, + expectedErrMsg: "failed to create comment", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := addIssueComment(client) + + // Create call request + request := mcp.CallToolRequest{ + Params: struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments,omitempty"` + Meta *struct { + ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"` + } `json:"_meta,omitempty"` + }{ + Arguments: tc.requestArgs, + }, + } + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedComment github.IssueComment + err = json.Unmarshal([]byte(textContent.Text), &returnedComment) + require.NoError(t, err) + assert.Equal(t, *tc.expectedComment.ID, *returnedComment.ID) + assert.Equal(t, *tc.expectedComment.Body, *returnedComment.Body) + assert.Equal(t, *tc.expectedComment.User.Login, *returnedComment.User.Login) + + }) + } +} + +func Test_SearchIssues(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := searchIssues(mockClient) + + assert.Equal(t, "search_issues", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "q") + assert.Contains(t, tool.InputSchema.Properties, "sort") + assert.Contains(t, tool.InputSchema.Properties, "order") + assert.Contains(t, tool.InputSchema.Properties, "per_page") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"q"}) + + // Setup mock search results + mockSearchResult := &github.IssuesSearchResult{ + Total: github.Ptr(2), + IncompleteResults: github.Ptr(false), + Issues: []*github.Issue{ + { + Number: github.Ptr(42), + Title: github.Ptr("Bug: Something is broken"), + Body: github.Ptr("This is a bug report"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/issues/42"), + Comments: github.Ptr(5), + User: &github.User{ + Login: github.Ptr("user1"), + }, + }, + { + Number: github.Ptr(43), + Title: github.Ptr("Feature: Add new functionality"), + Body: github.Ptr("This is a feature request"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/issues/43"), + Comments: github.Ptr(3), + User: &github.User{ + Login: github.Ptr("user2"), + }, + }, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedResult *github.IssuesSearchResult + expectedErrMsg string + }{ + { + name: "successful issues search with all parameters", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetSearchIssues, + mockSearchResult, + ), + ), + requestArgs: map[string]interface{}{ + "q": "repo:owner/repo is:issue is:open", + "sort": "created", + "order": "desc", + "page": float64(1), + "per_page": float64(30), + }, + expectError: false, + expectedResult: mockSearchResult, + }, + { + name: "issues search with minimal parameters", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetSearchIssues, + mockSearchResult, + ), + ), + requestArgs: map[string]interface{}{ + "q": "repo:owner/repo is:issue is:open", + }, + expectError: false, + expectedResult: mockSearchResult, + }, + { + name: "search issues fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetSearchIssues, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message": "Validation Failed"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "q": "invalid:query", + }, + expectError: true, + expectedErrMsg: "failed to search issues", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := searchIssues(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedResult github.IssuesSearchResult + err = json.Unmarshal([]byte(textContent.Text), &returnedResult) + require.NoError(t, err) + assert.Equal(t, *tc.expectedResult.Total, *returnedResult.Total) + assert.Equal(t, *tc.expectedResult.IncompleteResults, *returnedResult.IncompleteResults) + assert.Len(t, returnedResult.Issues, len(tc.expectedResult.Issues)) + for i, issue := range returnedResult.Issues { + assert.Equal(t, *tc.expectedResult.Issues[i].Number, *issue.Number) + assert.Equal(t, *tc.expectedResult.Issues[i].Title, *issue.Title) + assert.Equal(t, *tc.expectedResult.Issues[i].State, *issue.State) + assert.Equal(t, *tc.expectedResult.Issues[i].HTMLURL, *issue.HTMLURL) + assert.Equal(t, *tc.expectedResult.Issues[i].User.Login, *issue.User.Login) + } + }) + } +} diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 1cf5f7240..b2f191b49 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "net/http" "github.com/google/go-github/v69/github" "github.com/mark3labs/mcp-go/mcp" @@ -39,7 +40,7 @@ func getPullRequest(client *github.Client) (tool mcp.Tool, handler server.ToolHa } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -140,7 +141,7 @@ func listPullRequests(client *github.Client) (tool mcp.Tool, handler server.Tool } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -211,7 +212,7 @@ func mergePullRequest(client *github.Client) (tool mcp.Tool, handler server.Tool } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -257,7 +258,7 @@ func getPullRequestFiles(client *github.Client) (tool mcp.Tool, handler server.T } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -303,7 +304,7 @@ func getPullRequestStatus(client *github.Client) (tool mcp.Tool, handler server. } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -318,7 +319,7 @@ func getPullRequestStatus(client *github.Client) (tool mcp.Tool, handler server. } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -371,11 +372,16 @@ func updatePullRequestBranch(client *github.Client) (tool mcp.Tool, handler serv result, resp, err := client.PullRequests.UpdateBranch(ctx, owner, repo, pullNumber, opts) if err != nil { + // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, + // and it's not a real error. + if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { + return mcp.NewToolResultText("Pull request branch update is in progress"), nil + } return nil, fmt.Errorf("failed to update pull request branch: %w", err) } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 202 { + if resp.StatusCode != http.StatusAccepted { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -426,7 +432,7 @@ func getPullRequestComments(client *github.Client) (tool mcp.Tool, handler serve } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -471,7 +477,7 @@ func getPullRequestReviews(client *github.Client) (tool mcp.Tool, handler server } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go new file mode 100644 index 000000000..bbafc9211 --- /dev/null +++ b/pkg/github/pullrequests_test.go @@ -0,0 +1,990 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + "time" + + "github.com/google/go-github/v69/github" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_GetPullRequest(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := getPullRequest(mockClient) + + assert.Equal(t, "get_pull_request", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "pull_number") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pull_number"}) + + // Setup mock PR for success case + mockPR := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), + Head: &github.PullRequestBranch{ + SHA: github.Ptr("abcd1234"), + Ref: github.Ptr("feature-branch"), + }, + Base: &github.PullRequestBranch{ + Ref: github.Ptr("main"), + }, + Body: github.Ptr("This is a test PR"), + User: &github.User{ + Login: github.Ptr("testuser"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedPR *github.PullRequest + expectedErrMsg string + }{ + { + name: "successful PR fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockPR, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + }, + expectError: false, + expectedPR: mockPR, + }, + { + name: "PR fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposPullsByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(999), + }, + expectError: true, + expectedErrMsg: "failed to get pull request", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getPullRequest(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedPR github.PullRequest + err = json.Unmarshal([]byte(textContent.Text), &returnedPR) + require.NoError(t, err) + assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number) + assert.Equal(t, *tc.expectedPR.Title, *returnedPR.Title) + assert.Equal(t, *tc.expectedPR.State, *returnedPR.State) + assert.Equal(t, *tc.expectedPR.HTMLURL, *returnedPR.HTMLURL) + }) + } +} + +func Test_ListPullRequests(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := listPullRequests(mockClient) + + assert.Equal(t, "list_pull_requests", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "state") + assert.Contains(t, tool.InputSchema.Properties, "head") + assert.Contains(t, tool.InputSchema.Properties, "base") + assert.Contains(t, tool.InputSchema.Properties, "sort") + assert.Contains(t, tool.InputSchema.Properties, "direction") + assert.Contains(t, tool.InputSchema.Properties, "per_page") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + + // Setup mock PRs for success case + mockPRs := []*github.PullRequest{ + { + Number: github.Ptr(42), + Title: github.Ptr("First PR"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), + }, + { + Number: github.Ptr(43), + Title: github.Ptr("Second PR"), + State: github.Ptr("closed"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/43"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedPRs []*github.PullRequest + expectedErrMsg string + }{ + { + name: "successful PRs listing", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepo, + mockPRs, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "state": "all", + "sort": "created", + "direction": "desc", + "per_page": float64(30), + "page": float64(1), + }, + expectError: false, + expectedPRs: mockPRs, + }, + { + name: "PRs listing fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposPullsByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message": "Invalid request"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "state": "invalid", + }, + expectError: true, + expectedErrMsg: "failed to list pull requests", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := listPullRequests(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedPRs []*github.PullRequest + err = json.Unmarshal([]byte(textContent.Text), &returnedPRs) + require.NoError(t, err) + assert.Len(t, returnedPRs, 2) + assert.Equal(t, *tc.expectedPRs[0].Number, *returnedPRs[0].Number) + assert.Equal(t, *tc.expectedPRs[0].Title, *returnedPRs[0].Title) + assert.Equal(t, *tc.expectedPRs[0].State, *returnedPRs[0].State) + assert.Equal(t, *tc.expectedPRs[1].Number, *returnedPRs[1].Number) + assert.Equal(t, *tc.expectedPRs[1].Title, *returnedPRs[1].Title) + assert.Equal(t, *tc.expectedPRs[1].State, *returnedPRs[1].State) + }) + } +} + +func Test_MergePullRequest(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := mergePullRequest(mockClient) + + assert.Equal(t, "merge_pull_request", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "pull_number") + assert.Contains(t, tool.InputSchema.Properties, "commit_title") + assert.Contains(t, tool.InputSchema.Properties, "commit_message") + assert.Contains(t, tool.InputSchema.Properties, "merge_method") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pull_number"}) + + // Setup mock merge result for success case + mockMergeResult := &github.PullRequestMergeResult{ + Merged: github.Ptr(true), + Message: github.Ptr("Pull Request successfully merged"), + SHA: github.Ptr("abcd1234efgh5678"), + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedMergeResult *github.PullRequestMergeResult + expectedErrMsg string + }{ + { + name: "successful merge", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutReposPullsMergeByOwnerByRepoByPullNumber, + mockMergeResult, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + "commit_title": "Merge PR #42", + "commit_message": "Merging awesome feature", + "merge_method": "squash", + }, + expectError: false, + expectedMergeResult: mockMergeResult, + }, + { + name: "merge fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PutReposPullsMergeByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusMethodNotAllowed) + _, _ = w.Write([]byte(`{"message": "Pull request cannot be merged"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + }, + expectError: true, + expectedErrMsg: "failed to merge pull request", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := mergePullRequest(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedResult github.PullRequestMergeResult + err = json.Unmarshal([]byte(textContent.Text), &returnedResult) + require.NoError(t, err) + assert.Equal(t, *tc.expectedMergeResult.Merged, *returnedResult.Merged) + assert.Equal(t, *tc.expectedMergeResult.Message, *returnedResult.Message) + assert.Equal(t, *tc.expectedMergeResult.SHA, *returnedResult.SHA) + }) + } +} + +func Test_GetPullRequestFiles(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := getPullRequestFiles(mockClient) + + assert.Equal(t, "get_pull_request_files", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "pull_number") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pull_number"}) + + // Setup mock PR files for success case + mockFiles := []*github.CommitFile{ + { + Filename: github.Ptr("file1.go"), + Status: github.Ptr("modified"), + Additions: github.Ptr(10), + Deletions: github.Ptr(5), + Changes: github.Ptr(15), + Patch: github.Ptr("@@ -1,5 +1,10 @@"), + }, + { + Filename: github.Ptr("file2.go"), + Status: github.Ptr("added"), + Additions: github.Ptr(20), + Deletions: github.Ptr(0), + Changes: github.Ptr(20), + Patch: github.Ptr("@@ -0,0 +1,20 @@"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedFiles []*github.CommitFile + expectedErrMsg string + }{ + { + name: "successful files fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsFilesByOwnerByRepoByPullNumber, + mockFiles, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + }, + expectError: false, + expectedFiles: mockFiles, + }, + { + name: "files fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposPullsFilesByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(999), + }, + expectError: true, + expectedErrMsg: "failed to get pull request files", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getPullRequestFiles(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedFiles []*github.CommitFile + err = json.Unmarshal([]byte(textContent.Text), &returnedFiles) + require.NoError(t, err) + assert.Len(t, returnedFiles, len(tc.expectedFiles)) + for i, file := range returnedFiles { + assert.Equal(t, *tc.expectedFiles[i].Filename, *file.Filename) + assert.Equal(t, *tc.expectedFiles[i].Status, *file.Status) + assert.Equal(t, *tc.expectedFiles[i].Additions, *file.Additions) + assert.Equal(t, *tc.expectedFiles[i].Deletions, *file.Deletions) + } + }) + } +} + +func Test_GetPullRequestStatus(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := getPullRequestStatus(mockClient) + + assert.Equal(t, "get_pull_request_status", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "pull_number") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pull_number"}) + + // Setup mock PR for successful PR fetch + mockPR := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), + Head: &github.PullRequestBranch{ + SHA: github.Ptr("abcd1234"), + Ref: github.Ptr("feature-branch"), + }, + } + + // Setup mock status for success case + mockStatus := &github.CombinedStatus{ + State: github.Ptr("success"), + TotalCount: github.Ptr(3), + Statuses: []*github.RepoStatus{ + { + State: github.Ptr("success"), + Context: github.Ptr("continuous-integration/travis-ci"), + Description: github.Ptr("Build succeeded"), + TargetURL: github.Ptr("https://travis-ci.org/owner/repo/builds/123"), + }, + { + State: github.Ptr("success"), + Context: github.Ptr("codecov/patch"), + Description: github.Ptr("Coverage increased"), + TargetURL: github.Ptr("https://codecov.io/gh/owner/repo/pull/42"), + }, + { + State: github.Ptr("success"), + Context: github.Ptr("lint/golangci-lint"), + Description: github.Ptr("No issues found"), + TargetURL: github.Ptr("https://golangci.com/r/owner/repo/pull/42"), + }, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedStatus *github.CombinedStatus + expectedErrMsg string + }{ + { + name: "successful status fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockPR, + ), + mock.WithRequestMatch( + mock.GetReposCommitsStatusByOwnerByRepoByRef, + mockStatus, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + }, + expectError: false, + expectedStatus: mockStatus, + }, + { + name: "PR fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposPullsByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(999), + }, + expectError: true, + expectedErrMsg: "failed to get pull request", + }, + { + name: "status fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockPR, + ), + mock.WithRequestMatchHandler( + mock.GetReposCommitsStatusesByOwnerByRepoByRef, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + }, + expectError: true, + expectedErrMsg: "failed to get combined status", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getPullRequestStatus(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedStatus github.CombinedStatus + err = json.Unmarshal([]byte(textContent.Text), &returnedStatus) + require.NoError(t, err) + assert.Equal(t, *tc.expectedStatus.State, *returnedStatus.State) + assert.Equal(t, *tc.expectedStatus.TotalCount, *returnedStatus.TotalCount) + assert.Len(t, returnedStatus.Statuses, len(tc.expectedStatus.Statuses)) + for i, status := range returnedStatus.Statuses { + assert.Equal(t, *tc.expectedStatus.Statuses[i].State, *status.State) + assert.Equal(t, *tc.expectedStatus.Statuses[i].Context, *status.Context) + assert.Equal(t, *tc.expectedStatus.Statuses[i].Description, *status.Description) + } + }) + } +} + +func Test_UpdatePullRequestBranch(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := updatePullRequestBranch(mockClient) + + assert.Equal(t, "update_pull_request_branch", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "pull_number") + assert.Contains(t, tool.InputSchema.Properties, "expected_head_sha") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pull_number"}) + + // Setup mock update result for success case + mockUpdateResult := &github.PullRequestBranchUpdateResponse{ + Message: github.Ptr("Branch was updated successfully"), + URL: github.Ptr("https://api.github.com/repos/owner/repo/pulls/42"), + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedUpdateResult *github.PullRequestBranchUpdateResponse + expectedErrMsg string + }{ + { + name: "successful branch update", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PutReposPullsUpdateBranchByOwnerByRepoByPullNumber, + mockResponse(t, http.StatusAccepted, mockUpdateResult), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + "expected_head_sha": "abcd1234", + }, + expectError: false, + expectedUpdateResult: mockUpdateResult, + }, + { + name: "branch update without expected SHA", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PutReposPullsUpdateBranchByOwnerByRepoByPullNumber, + mockResponse(t, http.StatusAccepted, mockUpdateResult), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + }, + expectError: false, + expectedUpdateResult: mockUpdateResult, + }, + { + name: "branch update fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PutReposPullsUpdateBranchByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusConflict) + _, _ = w.Write([]byte(`{"message": "Merge conflict"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + }, + expectError: true, + expectedErrMsg: "failed to update pull request branch", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := updatePullRequestBranch(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + assert.Contains(t, textContent.Text, "is in progress") + }) + } +} + +func Test_GetPullRequestComments(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := getPullRequestComments(mockClient) + + assert.Equal(t, "get_pull_request_comments", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "pull_number") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pull_number"}) + + // Setup mock PR comments for success case + mockComments := []*github.PullRequestComment{ + { + ID: github.Ptr(int64(101)), + Body: github.Ptr("This looks good"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42#discussion_r101"), + User: &github.User{ + Login: github.Ptr("reviewer1"), + }, + Path: github.Ptr("file1.go"), + Position: github.Ptr(5), + CommitID: github.Ptr("abcdef123456"), + CreatedAt: &github.Timestamp{Time: time.Now().Add(-24 * time.Hour)}, + UpdatedAt: &github.Timestamp{Time: time.Now().Add(-24 * time.Hour)}, + }, + { + ID: github.Ptr(int64(102)), + Body: github.Ptr("Please fix this"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42#discussion_r102"), + User: &github.User{ + Login: github.Ptr("reviewer2"), + }, + Path: github.Ptr("file2.go"), + Position: github.Ptr(10), + CommitID: github.Ptr("abcdef123456"), + CreatedAt: &github.Timestamp{Time: time.Now().Add(-12 * time.Hour)}, + UpdatedAt: &github.Timestamp{Time: time.Now().Add(-12 * time.Hour)}, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedComments []*github.PullRequestComment + expectedErrMsg string + }{ + { + name: "successful comments fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsCommentsByOwnerByRepoByPullNumber, + mockComments, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + }, + expectError: false, + expectedComments: mockComments, + }, + { + name: "comments fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposPullsCommentsByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(999), + }, + expectError: true, + expectedErrMsg: "failed to get pull request comments", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getPullRequestComments(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedComments []*github.PullRequestComment + err = json.Unmarshal([]byte(textContent.Text), &returnedComments) + require.NoError(t, err) + assert.Len(t, returnedComments, len(tc.expectedComments)) + for i, comment := range returnedComments { + assert.Equal(t, *tc.expectedComments[i].ID, *comment.ID) + assert.Equal(t, *tc.expectedComments[i].Body, *comment.Body) + assert.Equal(t, *tc.expectedComments[i].User.Login, *comment.User.Login) + assert.Equal(t, *tc.expectedComments[i].Path, *comment.Path) + assert.Equal(t, *tc.expectedComments[i].HTMLURL, *comment.HTMLURL) + } + }) + } +} + +func Test_GetPullRequestReviews(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := getPullRequestReviews(mockClient) + + assert.Equal(t, "get_pull_request_reviews", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "pull_number") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pull_number"}) + + // Setup mock PR reviews for success case + mockReviews := []*github.PullRequestReview{ + { + ID: github.Ptr(int64(201)), + State: github.Ptr("APPROVED"), + Body: github.Ptr("LGTM"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42#pullrequestreview-201"), + User: &github.User{ + Login: github.Ptr("approver"), + }, + CommitID: github.Ptr("abcdef123456"), + SubmittedAt: &github.Timestamp{Time: time.Now().Add(-24 * time.Hour)}, + }, + { + ID: github.Ptr(int64(202)), + State: github.Ptr("CHANGES_REQUESTED"), + Body: github.Ptr("Please address the following issues"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42#pullrequestreview-202"), + User: &github.User{ + Login: github.Ptr("reviewer"), + }, + CommitID: github.Ptr("abcdef123456"), + SubmittedAt: &github.Timestamp{Time: time.Now().Add(-12 * time.Hour)}, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedReviews []*github.PullRequestReview + expectedErrMsg string + }{ + { + name: "successful reviews fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsReviewsByOwnerByRepoByPullNumber, + mockReviews, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + }, + expectError: false, + expectedReviews: mockReviews, + }, + { + name: "reviews fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposPullsReviewsByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(999), + }, + expectError: true, + expectedErrMsg: "failed to get pull request reviews", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getPullRequestReviews(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedReviews []*github.PullRequestReview + err = json.Unmarshal([]byte(textContent.Text), &returnedReviews) + require.NoError(t, err) + assert.Len(t, returnedReviews, len(tc.expectedReviews)) + for i, review := range returnedReviews { + assert.Equal(t, *tc.expectedReviews[i].ID, *review.ID) + assert.Equal(t, *tc.expectedReviews[i].State, *review.State) + assert.Equal(t, *tc.expectedReviews[i].Body, *review.Body) + assert.Equal(t, *tc.expectedReviews[i].User.Login, *review.User.Login) + assert.Equal(t, *tc.expectedReviews[i].HTMLURL, *review.HTMLURL) + } + }) + } +} diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 37e07597e..607f9d926 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "net/http" "github.com/aws/smithy-go/ptr" "github.com/google/go-github/v69/github" @@ -206,7 +207,7 @@ func createRepository(client *github.Client) (tool mcp.Tool, handler server.Tool } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 201 { + if resp.StatusCode != http.StatusCreated { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -314,11 +315,16 @@ func forkRepository(client *github.Client) (tool mcp.Tool, handler server.ToolHa forkedRepo, resp, err := client.Repositories.CreateFork(ctx, owner, repo, opts) if err != nil { + // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, + // and it's not a real error. + if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { + return mcp.NewToolResultText("Fork is in progress"), nil + } return nil, fmt.Errorf("failed to fork repository: %w", err) } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 202 { + if resp.StatusCode != http.StatusAccepted { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go new file mode 100644 index 000000000..4e39b47f3 --- /dev/null +++ b/pkg/github/repositories_test.go @@ -0,0 +1,909 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + "time" + + "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_GetFileContents(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := getFileContents(mockClient) + + assert.Equal(t, "get_file_contents", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "path") + assert.Contains(t, tool.InputSchema.Properties, "branch") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "path"}) + + // Setup mock file content for success case + mockFileContent := &github.RepositoryContent{ + Type: github.Ptr("file"), + Name: github.Ptr("README.md"), + Path: github.Ptr("README.md"), + Content: github.Ptr("IyBUZXN0IFJlcG9zaXRvcnkKClRoaXMgaXMgYSB0ZXN0IHJlcG9zaXRvcnku"), // Base64 encoded "# Test Repository\n\nThis is a test repository." + SHA: github.Ptr("abc123"), + Size: github.Ptr(42), + HTMLURL: github.Ptr("https://github.com/owner/repo/blob/main/README.md"), + DownloadURL: github.Ptr("https://raw.githubusercontent.com/owner/repo/main/README.md"), + } + + // Setup mock directory content for success case + mockDirContent := []*github.RepositoryContent{ + { + Type: github.Ptr("file"), + Name: github.Ptr("README.md"), + Path: github.Ptr("README.md"), + SHA: github.Ptr("abc123"), + Size: github.Ptr(42), + HTMLURL: github.Ptr("https://github.com/owner/repo/blob/main/README.md"), + }, + { + Type: github.Ptr("dir"), + Name: github.Ptr("src"), + Path: github.Ptr("src"), + SHA: github.Ptr("def456"), + HTMLURL: github.Ptr("https://github.com/owner/repo/tree/main/src"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedResult interface{} + expectedErrMsg string + }{ + { + name: "successful file content fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposContentsByOwnerByRepoByPath, + mockFileContent, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "README.md", + "branch": "main", + }, + expectError: false, + expectedResult: mockFileContent, + }, + { + name: "successful directory content fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposContentsByOwnerByRepoByPath, + mockDirContent, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "src", + }, + expectError: false, + expectedResult: mockDirContent, + }, + { + name: "content fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposContentsByOwnerByRepoByPath, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "nonexistent.md", + "branch": "main", + }, + expectError: true, + expectedErrMsg: "failed to get file contents", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getFileContents(client) + + // Create call request + request := mcp.CallToolRequest{ + Params: struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments,omitempty"` + Meta *struct { + ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"` + } `json:"_meta,omitempty"` + }{ + Arguments: tc.requestArgs, + }, + } + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Verify based on expected type + switch expected := tc.expectedResult.(type) { + case *github.RepositoryContent: + var returnedContent github.RepositoryContent + err = json.Unmarshal([]byte(textContent.Text), &returnedContent) + require.NoError(t, err) + assert.Equal(t, *expected.Name, *returnedContent.Name) + assert.Equal(t, *expected.Path, *returnedContent.Path) + assert.Equal(t, *expected.Type, *returnedContent.Type) + case []*github.RepositoryContent: + var returnedContents []*github.RepositoryContent + err = json.Unmarshal([]byte(textContent.Text), &returnedContents) + require.NoError(t, err) + assert.Len(t, returnedContents, len(expected)) + for i, content := range returnedContents { + assert.Equal(t, *expected[i].Name, *content.Name) + assert.Equal(t, *expected[i].Path, *content.Path) + assert.Equal(t, *expected[i].Type, *content.Type) + } + } + }) + } +} + +func Test_ForkRepository(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := forkRepository(mockClient) + + assert.Equal(t, "fork_repository", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "organization") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + + // Setup mock forked repo for success case + mockForkedRepo := &github.Repository{ + ID: github.Ptr(int64(123456)), + Name: github.Ptr("repo"), + FullName: github.Ptr("new-owner/repo"), + Owner: &github.User{ + Login: github.Ptr("new-owner"), + }, + HTMLURL: github.Ptr("https://github.com/new-owner/repo"), + DefaultBranch: github.Ptr("main"), + Fork: github.Ptr(true), + ForksCount: github.Ptr(0), + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedRepo *github.Repository + expectedErrMsg string + }{ + { + name: "successful repository fork", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostReposForksByOwnerByRepo, + mockResponse(t, http.StatusAccepted, mockForkedRepo), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + }, + expectError: false, + expectedRepo: mockForkedRepo, + }, + { + name: "repository fork fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostReposForksByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"message": "Forbidden"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + }, + expectError: true, + expectedErrMsg: "failed to fork repository", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := forkRepository(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + assert.Contains(t, textContent.Text, "Fork is in progress") + }) + } +} + +func Test_CreateBranch(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := createBranch(mockClient) + + assert.Equal(t, "create_branch", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "branch") + assert.Contains(t, tool.InputSchema.Properties, "from_branch") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "branch"}) + + // Setup mock repository for default branch test + mockRepo := &github.Repository{ + DefaultBranch: github.Ptr("main"), + } + + // Setup mock reference for from_branch tests + mockSourceRef := &github.Reference{ + Ref: github.Ptr("refs/heads/main"), + Object: &github.GitObject{ + SHA: github.Ptr("abc123def456"), + }, + } + + // Setup mock created reference + mockCreatedRef := &github.Reference{ + Ref: github.Ptr("refs/heads/new-feature"), + Object: &github.GitObject{ + SHA: github.Ptr("abc123def456"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedRef *github.Reference + expectedErrMsg string + }{ + { + name: "successful branch creation with from_branch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposGitRefByOwnerByRepoByRef, + mockSourceRef, + ), + mock.WithRequestMatch( + mock.PostReposGitRefsByOwnerByRepo, + mockCreatedRef, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "branch": "new-feature", + "from_branch": "main", + }, + expectError: false, + expectedRef: mockCreatedRef, + }, + { + name: "successful branch creation with default branch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposByOwnerByRepo, + mockRepo, + ), + mock.WithRequestMatch( + mock.GetReposGitRefByOwnerByRepoByRef, + mockSourceRef, + ), + mock.WithRequestMatch( + mock.PostReposGitRefsByOwnerByRepo, + mockCreatedRef, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "branch": "new-feature", + }, + expectError: false, + expectedRef: mockCreatedRef, + }, + { + name: "fail to get repository", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Repository not found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "nonexistent-repo", + "branch": "new-feature", + }, + expectError: true, + expectedErrMsg: "failed to get repository", + }, + { + name: "fail to get reference", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposGitRefByOwnerByRepoByRef, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Reference not found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "branch": "new-feature", + "from_branch": "nonexistent-branch", + }, + expectError: true, + expectedErrMsg: "failed to get reference", + }, + { + name: "fail to create branch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposGitRefByOwnerByRepoByRef, + mockSourceRef, + ), + mock.WithRequestMatchHandler( + mock.PostReposGitRefsByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message": "Reference already exists"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "branch": "existing-branch", + "from_branch": "main", + }, + expectError: true, + expectedErrMsg: "failed to create branch", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := createBranch(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedRef github.Reference + err = json.Unmarshal([]byte(textContent.Text), &returnedRef) + require.NoError(t, err) + assert.Equal(t, *tc.expectedRef.Ref, *returnedRef.Ref) + assert.Equal(t, *tc.expectedRef.Object.SHA, *returnedRef.Object.SHA) + }) + } +} + +func Test_ListCommits(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := listCommits(mockClient) + + assert.Equal(t, "list_commits", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "sha") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.Contains(t, tool.InputSchema.Properties, "per_page") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + + // Setup mock commits for success case + mockCommits := []*github.RepositoryCommit{ + { + SHA: github.Ptr("abc123def456"), + Commit: &github.Commit{ + Message: github.Ptr("First commit"), + Author: &github.CommitAuthor{ + Name: github.Ptr("Test User"), + Email: github.Ptr("test@example.com"), + Date: &github.Timestamp{Time: time.Now().Add(-48 * time.Hour)}, + }, + }, + Author: &github.User{ + Login: github.Ptr("testuser"), + }, + HTMLURL: github.Ptr("https://github.com/owner/repo/commit/abc123def456"), + }, + { + SHA: github.Ptr("def456abc789"), + Commit: &github.Commit{ + Message: github.Ptr("Second commit"), + Author: &github.CommitAuthor{ + Name: github.Ptr("Another User"), + Email: github.Ptr("another@example.com"), + Date: &github.Timestamp{Time: time.Now().Add(-24 * time.Hour)}, + }, + }, + Author: &github.User{ + Login: github.Ptr("anotheruser"), + }, + HTMLURL: github.Ptr("https://github.com/owner/repo/commit/def456abc789"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedCommits []*github.RepositoryCommit + expectedErrMsg string + }{ + { + name: "successful commits fetch with default params", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposCommitsByOwnerByRepo, + mockCommits, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + }, + expectError: false, + expectedCommits: mockCommits, + }, + { + name: "successful commits fetch with branch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposCommitsByOwnerByRepo, + mockCommits, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "sha": "main", + }, + expectError: false, + expectedCommits: mockCommits, + }, + { + name: "successful commits fetch with pagination", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposCommitsByOwnerByRepo, + mockCommits, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "page": float64(2), + "per_page": float64(10), + }, + expectError: false, + expectedCommits: mockCommits, + }, + { + name: "commits fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposCommitsByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "nonexistent-repo", + }, + expectError: true, + expectedErrMsg: "failed to list commits", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := listCommits(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedCommits []*github.RepositoryCommit + err = json.Unmarshal([]byte(textContent.Text), &returnedCommits) + require.NoError(t, err) + assert.Len(t, returnedCommits, len(tc.expectedCommits)) + for i, commit := range returnedCommits { + assert.Equal(t, *tc.expectedCommits[i].SHA, *commit.SHA) + assert.Equal(t, *tc.expectedCommits[i].Commit.Message, *commit.Commit.Message) + assert.Equal(t, *tc.expectedCommits[i].Author.Login, *commit.Author.Login) + assert.Equal(t, *tc.expectedCommits[i].HTMLURL, *commit.HTMLURL) + } + }) + } +} + +func Test_CreateOrUpdateFile(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := createOrUpdateFile(mockClient) + + assert.Equal(t, "create_or_update_file", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "path") + assert.Contains(t, tool.InputSchema.Properties, "content") + assert.Contains(t, tool.InputSchema.Properties, "message") + assert.Contains(t, tool.InputSchema.Properties, "branch") + assert.Contains(t, tool.InputSchema.Properties, "sha") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "path", "content", "message", "branch"}) + + // Setup mock file content response + mockFileResponse := &github.RepositoryContentResponse{ + Content: &github.RepositoryContent{ + Name: github.Ptr("example.md"), + Path: github.Ptr("docs/example.md"), + SHA: github.Ptr("abc123def456"), + Size: github.Ptr(42), + HTMLURL: github.Ptr("https://github.com/owner/repo/blob/main/docs/example.md"), + DownloadURL: github.Ptr("https://raw.githubusercontent.com/owner/repo/main/docs/example.md"), + }, + Commit: github.Commit{ + SHA: github.Ptr("def456abc789"), + Message: github.Ptr("Add example file"), + Author: &github.CommitAuthor{ + Name: github.Ptr("Test User"), + Email: github.Ptr("test@example.com"), + Date: &github.Timestamp{Time: time.Now()}, + }, + HTMLURL: github.Ptr("https://github.com/owner/repo/commit/def456abc789"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedContent *github.RepositoryContentResponse + expectedErrMsg string + }{ + { + name: "successful file creation", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutReposContentsByOwnerByRepoByPath, + mockFileResponse, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "docs/example.md", + "content": "# Example\n\nThis is an example file.", + "message": "Add example file", + "branch": "main", + }, + expectError: false, + expectedContent: mockFileResponse, + }, + { + name: "successful file update with SHA", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutReposContentsByOwnerByRepoByPath, + mockFileResponse, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "docs/example.md", + "content": "# Updated Example\n\nThis file has been updated.", + "message": "Update example file", + "branch": "main", + "sha": "abc123def456", + }, + expectError: false, + expectedContent: mockFileResponse, + }, + { + name: "file creation fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PutReposContentsByOwnerByRepoByPath, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message": "Invalid request"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "docs/example.md", + "content": "#Invalid Content", + "message": "Invalid request", + "branch": "nonexistent-branch", + }, + expectError: true, + expectedErrMsg: "failed to create/update file", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := createOrUpdateFile(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedContent github.RepositoryContentResponse + err = json.Unmarshal([]byte(textContent.Text), &returnedContent) + require.NoError(t, err) + + // Verify content + assert.Equal(t, *tc.expectedContent.Content.Name, *returnedContent.Content.Name) + assert.Equal(t, *tc.expectedContent.Content.Path, *returnedContent.Content.Path) + assert.Equal(t, *tc.expectedContent.Content.SHA, *returnedContent.Content.SHA) + + // Verify commit + assert.Equal(t, *tc.expectedContent.Commit.SHA, *returnedContent.Commit.SHA) + assert.Equal(t, *tc.expectedContent.Commit.Message, *returnedContent.Commit.Message) + }) + } +} + +func Test_CreateRepository(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := createRepository(mockClient) + + assert.Equal(t, "create_repository", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "name") + assert.Contains(t, tool.InputSchema.Properties, "description") + assert.Contains(t, tool.InputSchema.Properties, "private") + assert.Contains(t, tool.InputSchema.Properties, "auto_init") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"name"}) + + // Setup mock repository response + mockRepo := &github.Repository{ + Name: github.Ptr("test-repo"), + Description: github.Ptr("Test repository"), + Private: github.Ptr(true), + HTMLURL: github.Ptr("https://github.com/testuser/test-repo"), + CloneURL: github.Ptr("https://github.com/testuser/test-repo.git"), + CreatedAt: &github.Timestamp{Time: time.Now()}, + Owner: &github.User{ + Login: github.Ptr("testuser"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedRepo *github.Repository + expectedErrMsg string + }{ + { + name: "successful repository creation with all parameters", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{ + Pattern: "/user/repos", + Method: "POST", + }, + mockResponse(t, http.StatusCreated, mockRepo), + ), + ), + requestArgs: map[string]interface{}{ + "name": "test-repo", + "description": "Test repository", + "private": true, + "auto_init": true, + }, + expectError: false, + expectedRepo: mockRepo, + }, + { + name: "successful repository creation with minimal parameters", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{ + Pattern: "/user/repos", + Method: "POST", + }, + mockResponse(t, http.StatusCreated, mockRepo), + ), + ), + requestArgs: map[string]interface{}{ + "name": "test-repo", + }, + expectError: false, + expectedRepo: mockRepo, + }, + { + name: "repository creation fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{ + Pattern: "/user/repos", + Method: "POST", + }, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message": "Repository creation failed"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "name": "invalid-repo", + }, + expectError: true, + expectedErrMsg: "failed to create repository", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := createRepository(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedRepo github.Repository + err = json.Unmarshal([]byte(textContent.Text), &returnedRepo) + assert.NoError(t, err) + + // Verify repository details + assert.Equal(t, *tc.expectedRepo.Name, *returnedRepo.Name) + assert.Equal(t, *tc.expectedRepo.Description, *returnedRepo.Description) + assert.Equal(t, *tc.expectedRepo.Private, *returnedRepo.Private) + assert.Equal(t, *tc.expectedRepo.HTMLURL, *returnedRepo.HTMLURL) + assert.Equal(t, *tc.expectedRepo.Owner.Login, *returnedRepo.Owner.Login) + }) + } +} diff --git a/pkg/github/search_test.go b/pkg/github/search_test.go new file mode 100644 index 000000000..d43fd843e --- /dev/null +++ b/pkg/github/search_test.go @@ -0,0 +1,429 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/google/go-github/v69/github" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_SearchRepositories(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := searchRepositories(mockClient) + + assert.Equal(t, "search_repositories", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "query") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.Contains(t, tool.InputSchema.Properties, "per_page") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"query"}) + + // Setup mock search results + mockSearchResult := &github.RepositoriesSearchResult{ + Total: github.Ptr(2), + IncompleteResults: github.Ptr(false), + Repositories: []*github.Repository{ + { + ID: github.Ptr(int64(12345)), + Name: github.Ptr("repo-1"), + FullName: github.Ptr("owner/repo-1"), + HTMLURL: github.Ptr("https://github.com/owner/repo-1"), + Description: github.Ptr("Test repository 1"), + StargazersCount: github.Ptr(100), + }, + { + ID: github.Ptr(int64(67890)), + Name: github.Ptr("repo-2"), + FullName: github.Ptr("owner/repo-2"), + HTMLURL: github.Ptr("https://github.com/owner/repo-2"), + Description: github.Ptr("Test repository 2"), + StargazersCount: github.Ptr(50), + }, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedResult *github.RepositoriesSearchResult + expectedErrMsg string + }{ + { + name: "successful repository search", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetSearchRepositories, + mockSearchResult, + ), + ), + requestArgs: map[string]interface{}{ + "query": "golang test", + "page": float64(1), + "per_page": float64(30), + }, + expectError: false, + expectedResult: mockSearchResult, + }, + { + name: "repository search with default pagination", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetSearchRepositories, + mockSearchResult, + ), + ), + requestArgs: map[string]interface{}{ + "query": "golang test", + }, + expectError: false, + expectedResult: mockSearchResult, + }, + { + name: "search fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetSearchRepositories, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message": "Invalid query"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "query": "invalid:query", + }, + expectError: true, + expectedErrMsg: "failed to search repositories", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := searchRepositories(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedResult github.RepositoriesSearchResult + err = json.Unmarshal([]byte(textContent.Text), &returnedResult) + require.NoError(t, err) + assert.Equal(t, *tc.expectedResult.Total, *returnedResult.Total) + assert.Equal(t, *tc.expectedResult.IncompleteResults, *returnedResult.IncompleteResults) + assert.Len(t, returnedResult.Repositories, len(tc.expectedResult.Repositories)) + for i, repo := range returnedResult.Repositories { + assert.Equal(t, *tc.expectedResult.Repositories[i].ID, *repo.ID) + assert.Equal(t, *tc.expectedResult.Repositories[i].Name, *repo.Name) + assert.Equal(t, *tc.expectedResult.Repositories[i].FullName, *repo.FullName) + assert.Equal(t, *tc.expectedResult.Repositories[i].HTMLURL, *repo.HTMLURL) + } + + }) + } +} + +func Test_SearchCode(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := searchCode(mockClient) + + assert.Equal(t, "search_code", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "q") + assert.Contains(t, tool.InputSchema.Properties, "sort") + assert.Contains(t, tool.InputSchema.Properties, "order") + assert.Contains(t, tool.InputSchema.Properties, "per_page") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"q"}) + + // Setup mock search results + mockSearchResult := &github.CodeSearchResult{ + Total: github.Ptr(2), + IncompleteResults: github.Ptr(false), + CodeResults: []*github.CodeResult{ + { + Name: github.Ptr("file1.go"), + Path: github.Ptr("path/to/file1.go"), + SHA: github.Ptr("abc123def456"), + HTMLURL: github.Ptr("https://github.com/owner/repo/blob/main/path/to/file1.go"), + Repository: &github.Repository{Name: github.Ptr("repo"), FullName: github.Ptr("owner/repo")}, + }, + { + Name: github.Ptr("file2.go"), + Path: github.Ptr("path/to/file2.go"), + SHA: github.Ptr("def456abc123"), + HTMLURL: github.Ptr("https://github.com/owner/repo/blob/main/path/to/file2.go"), + Repository: &github.Repository{Name: github.Ptr("repo"), FullName: github.Ptr("owner/repo")}, + }, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedResult *github.CodeSearchResult + expectedErrMsg string + }{ + { + name: "successful code search with all parameters", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetSearchCode, + mockSearchResult, + ), + ), + requestArgs: map[string]interface{}{ + "q": "fmt.Println language:go", + "sort": "indexed", + "order": "desc", + "page": float64(1), + "per_page": float64(30), + }, + expectError: false, + expectedResult: mockSearchResult, + }, + { + name: "code search with minimal parameters", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetSearchCode, + mockSearchResult, + ), + ), + requestArgs: map[string]interface{}{ + "q": "fmt.Println language:go", + }, + expectError: false, + expectedResult: mockSearchResult, + }, + { + name: "search code fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetSearchCode, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message": "Validation Failed"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "q": "invalid:query", + }, + expectError: true, + expectedErrMsg: "failed to search code", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := searchCode(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedResult github.CodeSearchResult + err = json.Unmarshal([]byte(textContent.Text), &returnedResult) + require.NoError(t, err) + assert.Equal(t, *tc.expectedResult.Total, *returnedResult.Total) + assert.Equal(t, *tc.expectedResult.IncompleteResults, *returnedResult.IncompleteResults) + assert.Len(t, returnedResult.CodeResults, len(tc.expectedResult.CodeResults)) + for i, code := range returnedResult.CodeResults { + assert.Equal(t, *tc.expectedResult.CodeResults[i].Name, *code.Name) + assert.Equal(t, *tc.expectedResult.CodeResults[i].Path, *code.Path) + assert.Equal(t, *tc.expectedResult.CodeResults[i].SHA, *code.SHA) + assert.Equal(t, *tc.expectedResult.CodeResults[i].HTMLURL, *code.HTMLURL) + assert.Equal(t, *tc.expectedResult.CodeResults[i].Repository.FullName, *code.Repository.FullName) + } + }) + } +} + +func Test_SearchUsers(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := searchUsers(mockClient) + + assert.Equal(t, "search_users", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "q") + assert.Contains(t, tool.InputSchema.Properties, "sort") + assert.Contains(t, tool.InputSchema.Properties, "order") + assert.Contains(t, tool.InputSchema.Properties, "per_page") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"q"}) + + // Setup mock search results + mockSearchResult := &github.UsersSearchResult{ + Total: github.Ptr(2), + IncompleteResults: github.Ptr(false), + Users: []*github.User{ + { + Login: github.Ptr("user1"), + ID: github.Ptr(int64(1001)), + HTMLURL: github.Ptr("https://github.com/user1"), + AvatarURL: github.Ptr("https://avatars.githubusercontent.com/u/1001"), + Type: github.Ptr("User"), + Followers: github.Ptr(100), + Following: github.Ptr(50), + }, + { + Login: github.Ptr("user2"), + ID: github.Ptr(int64(1002)), + HTMLURL: github.Ptr("https://github.com/user2"), + AvatarURL: github.Ptr("https://avatars.githubusercontent.com/u/1002"), + Type: github.Ptr("User"), + Followers: github.Ptr(200), + Following: github.Ptr(75), + }, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedResult *github.UsersSearchResult + expectedErrMsg string + }{ + { + name: "successful users search with all parameters", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetSearchUsers, + mockSearchResult, + ), + ), + requestArgs: map[string]interface{}{ + "q": "location:finland language:go", + "sort": "followers", + "order": "desc", + "page": float64(1), + "per_page": float64(30), + }, + expectError: false, + expectedResult: mockSearchResult, + }, + { + name: "users search with minimal parameters", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetSearchUsers, + mockSearchResult, + ), + ), + requestArgs: map[string]interface{}{ + "q": "location:finland language:go", + }, + expectError: false, + expectedResult: mockSearchResult, + }, + { + name: "search users fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetSearchUsers, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message": "Validation Failed"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "q": "invalid:query", + }, + expectError: true, + expectedErrMsg: "failed to search users", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := searchUsers(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + require.NotNil(t, result) + + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedResult github.UsersSearchResult + err = json.Unmarshal([]byte(textContent.Text), &returnedResult) + require.NoError(t, err) + assert.Equal(t, *tc.expectedResult.Total, *returnedResult.Total) + assert.Equal(t, *tc.expectedResult.IncompleteResults, *returnedResult.IncompleteResults) + assert.Len(t, returnedResult.Users, len(tc.expectedResult.Users)) + for i, user := range returnedResult.Users { + assert.Equal(t, *tc.expectedResult.Users[i].Login, *user.Login) + assert.Equal(t, *tc.expectedResult.Users[i].ID, *user.ID) + assert.Equal(t, *tc.expectedResult.Users[i].HTMLURL, *user.HTMLURL) + assert.Equal(t, *tc.expectedResult.Users[i].AvatarURL, *user.AvatarURL) + assert.Equal(t, *tc.expectedResult.Users[i].Type, *user.Type) + assert.Equal(t, *tc.expectedResult.Users[i].Followers, *user.Followers) + } + }) + } +} diff --git a/pkg/github/server.go b/pkg/github/server.go index b3ef7016b..0a90b4d1b 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -3,8 +3,10 @@ package github import ( "context" "encoding/json" + "errors" "fmt" "io" + "net/http" "github.com/google/go-github/v69/github" "github.com/mark3labs/mcp-go/mcp" @@ -73,7 +75,7 @@ func getMe(client *github.Client) (tool mcp.Tool, handler server.ToolHandlerFunc } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -89,3 +91,9 @@ func getMe(client *github.Client) (tool mcp.Tool, handler server.ToolHandlerFunc return mcp.NewToolResultText(string(r)), nil } } + +// isAcceptedError checks if the error is an accepted error. +func isAcceptedError(err error) bool { + var acceptedError *github.AcceptedError + return errors.As(err, &acceptedError) +} diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go new file mode 100644 index 000000000..d56993ded --- /dev/null +++ b/pkg/github/server_test.go @@ -0,0 +1,168 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + "time" + + "github.com/google/go-github/v69/github" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_GetMe(t *testing.T) { + // Verify tool definition + mockClient := github.NewClient(nil) + tool, _ := getMe(mockClient) + + assert.Equal(t, "get_me", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "reason") + assert.Empty(t, tool.InputSchema.Required) // No required parameters + + // Setup mock user response + mockUser := &github.User{ + Login: github.Ptr("testuser"), + Name: github.Ptr("Test User"), + Email: github.Ptr("test@example.com"), + Bio: github.Ptr("GitHub user for testing"), + Company: github.Ptr("Test Company"), + Location: github.Ptr("Test Location"), + HTMLURL: github.Ptr("https://github.com/testuser"), + CreatedAt: &github.Timestamp{Time: time.Now().Add(-365 * 24 * time.Hour)}, + Type: github.Ptr("User"), + Plan: &github.Plan{ + Name: github.Ptr("pro"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedUser *github.User + expectedErrMsg string + }{ + { + name: "successful get user", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetUser, + mockUser, + ), + ), + requestArgs: map[string]interface{}{}, + expectError: false, + expectedUser: mockUser, + }, + { + name: "successful get user with reason", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetUser, + mockUser, + ), + ), + requestArgs: map[string]interface{}{ + "reason": "Testing API", + }, + expectError: false, + expectedUser: mockUser, + }, + { + name: "get user fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetUser, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message": "Unauthorized"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{}, + expectError: true, + expectedErrMsg: "failed to get user", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getMe(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse result and get text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedUser github.User + err = json.Unmarshal([]byte(textContent.Text), &returnedUser) + require.NoError(t, err) + + // Verify user details + assert.Equal(t, *tc.expectedUser.Login, *returnedUser.Login) + assert.Equal(t, *tc.expectedUser.Name, *returnedUser.Name) + assert.Equal(t, *tc.expectedUser.Email, *returnedUser.Email) + assert.Equal(t, *tc.expectedUser.Bio, *returnedUser.Bio) + assert.Equal(t, *tc.expectedUser.HTMLURL, *returnedUser.HTMLURL) + assert.Equal(t, *tc.expectedUser.Type, *returnedUser.Type) + }) + } +} + +func Test_IsAcceptedError(t *testing.T) { + tests := []struct { + name string + err error + expectAccepted bool + }{ + { + name: "github AcceptedError", + err: &github.AcceptedError{}, + expectAccepted: true, + }, + { + name: "regular error", + err: fmt.Errorf("some other error"), + expectAccepted: false, + }, + { + name: "nil error", + err: nil, + expectAccepted: false, + }, + { + name: "wrapped AcceptedError", + err: fmt.Errorf("wrapped: %w", &github.AcceptedError{}), + expectAccepted: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := isAcceptedError(tc.err) + assert.Equal(t, tc.expectAccepted, result) + }) + } +} pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy