From 4d7388ae1938f03348ae742bd008872e2a9a2a8c Mon Sep 17 00:00:00 2001 From: Javier Uruen Val Date: Mon, 24 Mar 2025 07:39:01 +0100 Subject: [PATCH 1/2] validate tools params --- pkg/github/code_scanning.go | 40 +++- pkg/github/issues.go | 197 +++++++++++++------ pkg/github/issues_test.go | 26 ++- pkg/github/pullrequests.go | 215 ++++++++++++++------- pkg/github/repositories.go | 174 +++++++++++------ pkg/github/search.go | 81 ++++---- pkg/github/server.go | 115 +++++++++++ pkg/github/server_test.go | 369 ++++++++++++++++++++++++++++++++++++ 8 files changed, 989 insertions(+), 228 deletions(-) diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 6fc0936a4..e7c8a4e22 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -30,9 +30,18 @@ func getCodeScanningAlert(client *github.Client, t translations.TranslationHelpe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, _ := request.Params.Arguments["owner"].(string) - repo, _ := request.Params.Arguments["repo"].(string) - alertNumber, _ := request.Params.Arguments["alert_number"].(float64) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + alertNumber, err := requiredNumberParam(request, "alert_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) if err != nil { @@ -80,11 +89,26 @@ func listCodeScanningAlerts(client *github.Client, t translations.TranslationHel ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, _ := request.Params.Arguments["owner"].(string) - repo, _ := request.Params.Arguments["repo"].(string) - ref, _ := request.Params.Arguments["ref"].(string) - state, _ := request.Params.Arguments["state"].(string) - severity, _ := request.Params.Arguments["severity"].(string) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + ref, err := optionalStringParam(request, "ref") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + state, err := optionalStringParam(request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + severity, err := optionalStringParam(request, "severity") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity}) if err != nil { diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 36130b985..521acfdaf 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -32,9 +32,18 @@ func getIssue(client *github.Client, t translations.TranslationHelperFunc) (tool ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - issueNumber := int(request.Params.Arguments["issue_number"].(float64)) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + issueNumber, err := requiredNumberParam(request, "issue_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber) if err != nil { @@ -81,10 +90,22 @@ func addIssueComment(client *github.Client, t translations.TranslationHelperFunc ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - issueNumber := int(request.Params.Arguments["issue_number"].(float64)) - body := request.Params.Arguments["body"].(string) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + issueNumber, err := requiredNumberParam(request, "issue_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + body, err := requiredStringParam(request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } comment := &github.IssueComment{ Body: github.Ptr(body), @@ -135,22 +156,25 @@ func searchIssues(client *github.Client, t translations.TranslationHelperFunc) ( ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query := request.Params.Arguments["q"].(string) - sort := "" - if s, ok := request.Params.Arguments["sort"].(string); ok { - sort = s + query, err := requiredStringParam(request, "q") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - order := "" - if o, ok := request.Params.Arguments["order"].(string); ok { - order = o + sort, err := optionalStringParam(request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - perPage := 30 - if pp, ok := request.Params.Arguments["per_page"].(float64); ok { - perPage = int(pp) + order, err := optionalStringParam(request, "order") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - page := 1 - if p, ok := request.Params.Arguments["page"].(float64); ok { - page = int(p) + perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + page, err := optionalNumberParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.SearchOptions{ @@ -212,26 +236,34 @@ func createIssue(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - title := request.Params.Arguments["title"].(string) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + title, err := requiredStringParam(request, "title") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // Optional parameters - var body string - if b, ok := request.Params.Arguments["body"].(string); ok { - body = b + body, err := optionalStringParam(request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - // Parse assignees if present - assignees := []string{} // default to empty slice, can't be nil - if a, ok := request.Params.Arguments["assignees"].(string); ok && a != "" { - assignees = parseCommaSeparatedList(a) + // Get assignees + assignees, err := optionalCommaSeparatedListParam(request, "assignees") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - - // Parse labels if present - labels := []string{} // default to empty slice, can't be nil - if l, ok := request.Params.Arguments["labels"].(string); ok && l != "" { - labels = parseCommaSeparatedList(l) + // Get labels + labels, err := optionalCommaSeparatedListParam(request, "labels") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } // Create the issue request @@ -300,29 +332,43 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } opts := &github.IssueListByRepoOptions{} // Set optional parameters if provided - if state, ok := request.Params.Arguments["state"].(string); ok && state != "" { - opts.State = state + opts.State, err = optionalStringParam(request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - if labels, ok := request.Params.Arguments["labels"].(string); ok && labels != "" { - opts.Labels = parseCommaSeparatedList(labels) + opts.Labels, err = optionalCommaSeparatedListParam(request, "labels") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - if sort, ok := request.Params.Arguments["sort"].(string); ok && sort != "" { - opts.Sort = sort + opts.Sort, err = optionalStringParam(request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - if direction, ok := request.Params.Arguments["direction"].(string); ok && direction != "" { - opts.Direction = direction + opts.Direction, err = optionalStringParam(request, "direction") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - if since, ok := request.Params.Arguments["since"].(string); ok && since != "" { + since, err := optionalStringParam(request, "since") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if since != "" { timestamp, err := parseISOTimestamp(since) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to list issues: %s", err.Error())), nil @@ -397,38 +443,69 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - issueNumber := int(request.Params.Arguments["issue_number"].(float64)) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + issueNumber, err := requiredNumberParam(request, "issue_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // Create the issue request with only provided fields issueRequest := &github.IssueRequest{} // Set optional parameters if provided - if title, ok := request.Params.Arguments["title"].(string); ok && title != "" { + title, err := optionalStringParam(request, "title") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if title != "" { issueRequest.Title = github.Ptr(title) } - if body, ok := request.Params.Arguments["body"].(string); ok && body != "" { + body, err := optionalStringParam(request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if body != "" { issueRequest.Body = github.Ptr(body) } - if state, ok := request.Params.Arguments["state"].(string); ok && state != "" { + state, err := optionalStringParam(request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if state != "" { issueRequest.State = github.Ptr(state) } - if labels, ok := request.Params.Arguments["labels"].(string); ok && labels != "" { - labelsList := parseCommaSeparatedList(labels) - issueRequest.Labels = &labelsList + labels, err := optionalCommaSeparatedListParam(request, "labels") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if len(labels) > 0 { + issueRequest.Labels = &labels } - if assignees, ok := request.Params.Arguments["assignees"].(string); ok && assignees != "" { - assigneesList := parseCommaSeparatedList(assignees) - issueRequest.Assignees = &assigneesList + assignees, err := optionalCommaSeparatedListParam(request, "assignees") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if len(assignees) > 0 { + issueRequest.Assignees = &assignees } - if milestone, ok := request.Params.Arguments["milestone"].(float64); ok { - milestoneNum := int(milestone) + milestone, err := optionalNumberParam(request, "milestone") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if milestone != 0 { + milestoneNum := milestone issueRequest.Milestone = &milestoneNum } diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index 4e8250fd2..c2de65797 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -176,8 +176,8 @@ func Test_AddIssueComment(t *testing.T) { "issue_number": float64(42), "body": "", }, - expectError: true, - expectedErrMsg: "failed to create comment", + expectError: false, + expectedErrMsg: "missing required parameter: body", }, } @@ -210,6 +210,13 @@ func Test_AddIssueComment(t *testing.T) { return } + if tc.expectedErrMsg != "" { + require.NotNil(t, result) + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + return + } + require.NoError(t, err) // Parse the result and get the text content if no error @@ -419,8 +426,8 @@ func Test_CreateIssue(t *testing.T) { "repo": "repo", "title": "Test Issue", "body": "This is a test issue", - "assignees": []interface{}{"user1", "user2"}, - "labels": []interface{}{"bug", "help wanted"}, + "assignees": "user1, user2", + "labels": "bug, help wanted", }, expectError: false, expectedIssue: mockIssue, @@ -467,8 +474,8 @@ func Test_CreateIssue(t *testing.T) { "repo": "repo", "title": "", }, - expectError: true, - expectedErrMsg: "failed to create issue", + expectError: false, + expectedErrMsg: "missing required parameter: title", }, } @@ -491,6 +498,13 @@ func Test_CreateIssue(t *testing.T) { return } + if tc.expectedErrMsg != "" { + require.NotNil(t, result) + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + return + } + require.NoError(t, err) textContent := getTextResult(t, result) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index e0414394a..d5caab97c 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -31,9 +31,18 @@ func getPullRequest(client *github.Client, t translations.TranslationHelperFunc) ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredNumberParam(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { @@ -93,35 +102,41 @@ func listPullRequests(client *github.Client, t translations.TranslationHelperFun ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - state := "" - if s, ok := request.Params.Arguments["state"].(string); ok { - state = s + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - head := "" - if h, ok := request.Params.Arguments["head"].(string); ok { - head = h + state, err := optionalStringParam(request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - base := "" - if b, ok := request.Params.Arguments["base"].(string); ok { - base = b + head, err := optionalStringParam(request, "head") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - sort := "" - if s, ok := request.Params.Arguments["sort"].(string); ok { - sort = s + base, err := optionalStringParam(request, "base") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - direction := "" - if d, ok := request.Params.Arguments["direction"].(string); ok { - direction = d + sort, err := optionalStringParam(request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - perPage := 30 - if pp, ok := request.Params.Arguments["per_page"].(float64); ok { - perPage = int(pp) + direction, err := optionalStringParam(request, "direction") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - page := 1 - if p, ok := request.Params.Arguments["page"].(float64); ok { - page = int(p) + perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + page, err := optionalNumberParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.PullRequestListOptions{ @@ -186,20 +201,29 @@ func mergePullRequest(client *github.Client, t translations.TranslationHelperFun ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) - commitTitle := "" - if ct, ok := request.Params.Arguments["commit_title"].(string); ok { - commitTitle = ct + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - commitMessage := "" - if cm, ok := request.Params.Arguments["commit_message"].(string); ok { - commitMessage = cm + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - mergeMethod := "" - if mm, ok := request.Params.Arguments["merge_method"].(string); ok { - mergeMethod = mm + pullNumber, err := requiredNumberParam(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + commitTitle, err := optionalStringParam(request, "commit_title") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + commitMessage, err := optionalStringParam(request, "commit_message") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + mergeMethod, err := optionalStringParam(request, "merge_method") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } options := &github.PullRequestOptions{ @@ -248,9 +272,18 @@ func getPullRequestFiles(client *github.Client, t translations.TranslationHelper ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredNumberParam(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } opts := &github.ListOptions{} files, resp, err := client.PullRequests.ListFiles(ctx, owner, repo, pullNumber, opts) @@ -294,10 +327,18 @@ func getPullRequestStatus(client *github.Client, t translations.TranslationHelpe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) - + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredNumberParam(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // First get the PR to find the head SHA pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { @@ -358,14 +399,22 @@ func updatePullRequestBranch(client *github.Client, t translations.TranslationHe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) - expectedHeadSHA := "" - if sha, ok := request.Params.Arguments["expected_head_sha"].(string); ok { - expectedHeadSHA = sha + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredNumberParam(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + expectedHeadSHA, err := optionalStringParam(request, "expected_head_sha") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - opts := &github.PullRequestBranchUpdateOptions{} if expectedHeadSHA != "" { opts.ExpectedHeadSHA = github.Ptr(expectedHeadSHA) @@ -417,9 +466,18 @@ func getPullRequestComments(client *github.Client, t translations.TranslationHel ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredNumberParam(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } opts := &github.PullRequestListCommentsOptions{ ListOptions: github.ListOptions{ @@ -468,9 +526,18 @@ func getPullRequestReviews(client *github.Client, t translations.TranslationHelp ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredNumberParam(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil) if err != nil { @@ -526,10 +593,22 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) - event := request.Params.Arguments["event"].(string) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredNumberParam(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + event, err := requiredStringParam(request, "event") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // Create review request reviewRequest := &github.PullRequestReviewRequest{ @@ -537,12 +616,20 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe } // Add body if provided - if body, ok := request.Params.Arguments["body"].(string); ok && body != "" { + body, err := optionalStringParam(request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if body != "" { reviewRequest.Body = github.Ptr(body) } // Add commit ID if provided - if commitID, ok := request.Params.Arguments["commit_id"].(string); ok && commitID != "" { + commitID, err := optionalStringParam(request, "commit_id") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if commitID != "" { reviewRequest.CommitID = github.Ptr(commitID) } diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 6e3b176df..f222b1f80 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -37,19 +37,25 @@ func listCommits(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - sha := "" - if s, ok := request.Params.Arguments["sha"].(string); ok { - sha = s + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - page := 1 - if p, ok := request.Params.Arguments["page"].(float64); ok { - page = int(p) + sha, err := optionalStringParam(request, "sha") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - perPage := 30 - if pp, ok := request.Params.Arguments["per_page"].(float64); ok { - perPage = int(pp) + page, err := optionalNumberParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.CommitsListOptions{ @@ -116,12 +122,30 @@ func createOrUpdateFile(client *github.Client, t translations.TranslationHelperF ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - path := request.Params.Arguments["path"].(string) - content := request.Params.Arguments["content"].(string) - message := request.Params.Arguments["message"].(string) - branch := request.Params.Arguments["branch"].(string) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + path, err := requiredStringParam(request, "path") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + content, err := requiredStringParam(request, "content") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + message, err := requiredStringParam(request, "message") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + branch, err := requiredStringParam(request, "branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // Convert content to base64 contentBytes := []byte(content) @@ -134,7 +158,11 @@ func createOrUpdateFile(client *github.Client, t translations.TranslationHelperF } // If SHA is provided, set it (for updates) - if sha, ok := request.Params.Arguments["sha"].(string); ok && sha != "" { + sha, err := optionalStringParam(request, "sha") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if sha != "" { opts.SHA = ptr.String(sha) } @@ -181,25 +209,28 @@ func createRepository(client *github.Client, t translations.TranslationHelperFun ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - name := request.Params.Arguments["name"].(string) - description := "" - if desc, ok := request.Params.Arguments["description"].(string); ok { - description = desc + name, err := requiredStringParam(request, "name") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + description, err := optionalStringParam(request, "description") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - private := false - if priv, ok := request.Params.Arguments["private"].(bool); ok { - private = priv + private, err := optionalBooleanParam(request, "private") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - autoInit := false - if init, ok := request.Params.Arguments["auto_init"].(bool); ok { - autoInit = init + autoInit, err := optionalBooleanParam(request, "auto_init") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } repo := &github.Repository{ - Name: github.String(name), - Description: github.String(description), - Private: github.Bool(private), - AutoInit: github.Bool(autoInit), + Name: github.Ptr(name), + Description: github.Ptr(description), + Private: github.Ptr(private), + AutoInit: github.Ptr(autoInit), } createdRepo, resp, err := client.Repositories.Create(ctx, "", repo) @@ -246,12 +277,21 @@ func getFileContents(client *github.Client, t translations.TranslationHelperFunc ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - path := request.Params.Arguments["path"].(string) - branch := "" - if b, ok := request.Params.Arguments["branch"].(string); ok { - branch = b + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + path, err := requiredStringParam(request, "path") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + branch, err := optionalStringParam(request, "branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.RepositoryContentGetOptions{Ref: branch} @@ -302,11 +342,17 @@ func forkRepository(client *github.Client, t translations.TranslationHelperFunc) ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - org := "" - if o, ok := request.Params.Arguments["organization"].(string); ok { - org = o + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + org, err := optionalStringParam(request, "organization") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.RepositoryCreateForkOptions{} @@ -363,17 +409,25 @@ func createBranch(client *github.Client, t translations.TranslationHelperFunc) ( ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - branch := request.Params.Arguments["branch"].(string) - fromBranch := "" - if fb, ok := request.Params.Arguments["from_branch"].(string); ok { - fromBranch = fb + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + branch, err := requiredStringParam(request, "branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + fromBranch, err := optionalStringParam(request, "from_branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } // Get the source branch SHA var ref *github.Reference - var err error if fromBranch == "" { // Get default branch if from_branch not specified @@ -440,10 +494,22 @@ func pushFiles(client *github.Client, t translations.TranslationHelperFunc) (too ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - branch := request.Params.Arguments["branch"].(string) - message := request.Params.Arguments["message"].(string) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + branch, err := requiredStringParam(request, "branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + message, err := requiredStringParam(request, "message") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // Parse files parameter - this should be an array of objects with path and content filesObj, ok := request.Params.Arguments["files"].([]interface{}) diff --git a/pkg/github/search.go b/pkg/github/search.go index 353c6fb21..d7ea49049 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -28,14 +28,17 @@ func searchRepositories(client *github.Client, t translations.TranslationHelperF ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query := request.Params.Arguments["query"].(string) - page := 1 - if p, ok := request.Params.Arguments["page"].(float64); ok { - page = int(p) + query, err := requiredStringParam(request, "query") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - perPage := 30 - if pp, ok := request.Params.Arguments["per_page"].(float64); ok { - perPage = int(pp) + page, err := optionalNumberParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.SearchOptions{ @@ -90,22 +93,25 @@ func searchCode(client *github.Client, t translations.TranslationHelperFunc) (to ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query := request.Params.Arguments["q"].(string) - sort := "" - if s, ok := request.Params.Arguments["sort"].(string); ok { - sort = s + query, err := requiredStringParam(request, "q") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - order := "" - if o, ok := request.Params.Arguments["order"].(string); ok { - order = o + sort, err := optionalStringParam(request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - perPage := 30 - if pp, ok := request.Params.Arguments["per_page"].(float64); ok { - perPage = int(pp) + order, err := optionalStringParam(request, "order") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - page := 1 - if p, ok := request.Params.Arguments["page"].(float64); ok { - page = int(p) + perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + page, err := optionalNumberParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.SearchOptions{ @@ -162,22 +168,25 @@ func searchUsers(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query := request.Params.Arguments["q"].(string) - sort := "" - if s, ok := request.Params.Arguments["sort"].(string); ok { - sort = s - } - order := "" - if o, ok := request.Params.Arguments["order"].(string); ok { - order = o - } - perPage := 30 - if pp, ok := request.Params.Arguments["per_page"].(float64); ok { - perPage = int(pp) - } - page := 1 - if p, ok := request.Params.Arguments["page"].(float64); ok { - page = int(p) + query, err := requiredStringParam(request, "q") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + sort, err := optionalStringParam(request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + order, err := optionalStringParam(request, "order") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + page, err := optionalNumberParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.SearchOptions{ diff --git a/pkg/github/server.go b/pkg/github/server.go index a0993e2f3..42e230833 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -138,3 +138,118 @@ func parseCommaSeparatedList(input string) []string { return result } + +// requiredStringParam checks if the parameter is present in the request and is of type string. +func requiredStringParam(r mcp.CallToolRequest, p string) (string, error) { + // Check if the parameter is present in the request + if _, ok := r.Params.Arguments[p]; !ok { + return "", fmt.Errorf("missing required parameter: %s", p) + } + + // Check if the parameter is of the expected type + if _, ok := r.Params.Arguments[p].(string); !ok { + return "", fmt.Errorf("parameter %s is not of type string", p) + } + + // Check if the parameter is not the zero value + v := r.Params.Arguments[p].(string) + if v == "" { + return v, fmt.Errorf("missing required parameter: %s", p) + } + + return v, nil +} + +// requiredNumberParam checks if the parameter is present in the request and is of type number. +func requiredNumberParam(r mcp.CallToolRequest, p string) (int, error) { + // Check if the parameter is present in the request + if _, ok := r.Params.Arguments[p]; !ok { + return 0, fmt.Errorf("missing required parameter: %s", p) + } + + // Check if the parameter is of the expected type + if _, ok := r.Params.Arguments[p].(float64); !ok { + return 0, fmt.Errorf("parameter %s is not of type number", p) + } + + return int(r.Params.Arguments[p].(float64)), nil +} + +// optionalStringParam checks if an optional parameter is present in the request and is of type string. +func optionalStringParam(r mcp.CallToolRequest, p string) (value string, err error) { + // Check if the parameter is present in the request + if _, ok := r.Params.Arguments[p]; !ok { + return "", nil + } + + // Check if the parameter is of the expected type + if _, ok := r.Params.Arguments[p].(string); !ok { + return "", fmt.Errorf("parameter %s is not of type string", p) + } + + return r.Params.Arguments[p].(string), nil +} + +// optionalNumberParam checks if an optional parameter is present in the request and is of type number. +func optionalNumberParam(r mcp.CallToolRequest, p string) (int, error) { + // Check if the parameter is present in the request + if _, ok := r.Params.Arguments[p]; !ok { + return 0, nil + } + + // Check if the parameter is of the expected type + if _, ok := r.Params.Arguments[p].(float64); !ok { + return 0, fmt.Errorf("parameter %s is not of type number", p) + } + + return int(r.Params.Arguments[p].(float64)), nil +} + +// optionalNumberParamWithDefault checks if an optional parameter is present in the request and is of type number. +// If the parameter is not present or is zero, it returns the default value. +func optionalNumberParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) { + v, err := optionalNumberParam(r, p) + if err != nil { + return 0, err + } + if v == 0 { + return d, nil + } + return v, nil +} + +// optionalCommaSeparatedListParam checks if an optional parameter is present in the request and is of type string. +// If the parameter is presents, it uses parseCommaSeparatedList to parse the string into a list of strings. +// If the parameter is not present or is empty, it returns an empty list. +func optionalCommaSeparatedListParam(r mcp.CallToolRequest, p string) ([]string, error) { + // Check if the parameter is present in the request + if _, ok := r.Params.Arguments[p]; !ok { + return []string{}, nil //default to empty list, not nil + } + + // Check if the parameter is of the expected type + if _, ok := r.Params.Arguments[p].(string); !ok { + return nil, fmt.Errorf("parameter %s is not of type string", p) + } + + l := parseCommaSeparatedList(r.Params.Arguments[p].(string)) + if len(l) == 0 { + return []string{}, nil // default to empty list, not nil + } + return l, nil +} + +// optionalBooleanParam checks if an optional parameter is present in the request and is of type boolean. +func optionalBooleanParam(r mcp.CallToolRequest, p string) (bool, error) { + // Check if the parameter is present in the request + if _, ok := r.Params.Arguments[p]; !ok { + return false, nil + } + + // Check if the parameter is of the expected type + if _, ok := r.Params.Arguments[p].(bool); !ok { + return false, fmt.Errorf("parameter %s is not of type bool", p) + } + + return r.Params.Arguments[p].(bool), nil +} diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 316a0efa2..a081d31de 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -228,3 +228,372 @@ func Test_ParseCommaSeparatedList(t *testing.T) { }) } } + +func Test_RequiredStringParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected string + expectError bool + }{ + { + name: "valid string parameter", + params: map[string]interface{}{"name": "test-value"}, + paramName: "name", + expected: "test-value", + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "name", + expected: "", + expectError: true, + }, + { + name: "empty string parameter", + params: map[string]interface{}{"name": ""}, + paramName: "name", + expected: "", + expectError: true, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"name": 123}, + paramName: "name", + expected: "", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := requiredStringParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalStringParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected string + expectError bool + }{ + { + name: "valid string parameter", + params: map[string]interface{}{"name": "test-value"}, + paramName: "name", + expected: "test-value", + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "name", + expected: "", + expectError: false, + }, + { + name: "empty string parameter", + params: map[string]interface{}{"name": ""}, + paramName: "name", + expected: "", + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"name": 123}, + paramName: "name", + expected: "", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := optionalStringParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_RequiredNumberParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected int + expectError bool + }{ + { + name: "valid number parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + expected: 0, + expectError: true, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := requiredNumberParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalNumberParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected int + expectError bool + }{ + { + name: "valid number parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + expected: 0, + expectError: false, + }, + { + name: "zero value", + params: map[string]interface{}{"count": float64(0)}, + paramName: "count", + expected: 0, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := optionalNumberParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalNumberParamWithDefault(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + defaultVal int + expected int + expectError bool + }{ + { + name: "valid number parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + defaultVal: 10, + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + defaultVal: 10, + expected: 10, + expectError: false, + }, + { + name: "zero value", + params: map[string]interface{}{"count": float64(0)}, + paramName: "count", + defaultVal: 10, + expected: 10, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + defaultVal: 10, + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := optionalNumberParamWithDefault(request, tc.paramName, tc.defaultVal) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalCommaSeparatedListParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected []string + expectError bool + }{ + { + name: "valid comma-separated list", + params: map[string]interface{}{"tags": "one,two,three"}, + paramName: "tags", + expected: []string{"one", "two", "three"}, + expectError: false, + }, + { + name: "empty list", + params: map[string]interface{}{"tags": ""}, + paramName: "tags", + expected: []string{}, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "tags", + expected: []string{}, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"tags": 123}, + paramName: "tags", + expected: nil, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := optionalCommaSeparatedListParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalBooleanParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected bool + expectError bool + }{ + { + name: "true value", + params: map[string]interface{}{"flag": true}, + paramName: "flag", + expected: true, + expectError: false, + }, + { + name: "false value", + params: map[string]interface{}{"flag": false}, + paramName: "flag", + expected: false, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "flag", + expected: false, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"flag": "not-a-boolean"}, + paramName: "flag", + expected: false, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := optionalBooleanParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} From 3c6467ffef6b2493d69533d34776d5cdd97b65ce Mon Sep 17 00:00:00 2001 From: Javier Uruen Val Date: Mon, 24 Mar 2025 14:47:54 +0100 Subject: [PATCH 2/2] use generic for helper functions --- pkg/github/code_scanning.go | 16 ++--- pkg/github/issues.go | 58 ++++++++-------- pkg/github/pullrequests.go | 80 +++++++++++----------- pkg/github/repositories.go | 62 ++++++++--------- pkg/github/search.go | 26 +++---- pkg/github/server.go | 132 ++++++++++++++++-------------------- pkg/github/server_test.go | 12 ++-- 7 files changed, 186 insertions(+), 200 deletions(-) diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index e7c8a4e22..380dc02cf 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -30,15 +30,15 @@ func getCodeScanningAlert(client *github.Client, t translations.TranslationHelpe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - alertNumber, err := requiredNumberParam(request, "alert_number") + alertNumber, err := requiredInt(request, "alert_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -89,23 +89,23 @@ func listCodeScanningAlerts(client *github.Client, t translations.TranslationHel ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - ref, err := optionalStringParam(request, "ref") + ref, err := optionalParam[string](request, "ref") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - state, err := optionalStringParam(request, "state") + state, err := optionalParam[string](request, "state") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - severity, err := optionalStringParam(request, "severity") + severity, err := optionalParam[string](request, "severity") if err != nil { return mcp.NewToolResultError(err.Error()), nil } diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 521acfdaf..a62213ea6 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -32,15 +32,15 @@ func getIssue(client *github.Client, t translations.TranslationHelperFunc) (tool ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - issueNumber, err := requiredNumberParam(request, "issue_number") + issueNumber, err := requiredInt(request, "issue_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -90,19 +90,19 @@ func addIssueComment(client *github.Client, t translations.TranslationHelperFunc ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - issueNumber, err := requiredNumberParam(request, "issue_number") + issueNumber, err := requiredInt(request, "issue_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - body, err := requiredStringParam(request, "body") + body, err := requiredParam[string](request, "body") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -156,23 +156,23 @@ func searchIssues(client *github.Client, t translations.TranslationHelperFunc) ( ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query, err := requiredStringParam(request, "q") + query, err := requiredParam[string](request, "q") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - sort, err := optionalStringParam(request, "sort") + sort, err := optionalParam[string](request, "sort") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - order, err := optionalStringParam(request, "order") + order, err := optionalParam[string](request, "order") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - page, err := optionalNumberParamWithDefault(request, "page", 1) + page, err := optionalIntParamWithDefault(request, "page", 1) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -236,21 +236,21 @@ func createIssue(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - title, err := requiredStringParam(request, "title") + title, err := requiredParam[string](request, "title") if err != nil { return mcp.NewToolResultError(err.Error()), nil } // Optional parameters - body, err := optionalStringParam(request, "body") + body, err := optionalParam[string](request, "body") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -332,11 +332,11 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -344,7 +344,7 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to opts := &github.IssueListByRepoOptions{} // Set optional parameters if provided - opts.State, err = optionalStringParam(request, "state") + opts.State, err = optionalParam[string](request, "state") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -354,17 +354,17 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to return mcp.NewToolResultError(err.Error()), nil } - opts.Sort, err = optionalStringParam(request, "sort") + opts.Sort, err = optionalParam[string](request, "sort") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - opts.Direction, err = optionalStringParam(request, "direction") + opts.Direction, err = optionalParam[string](request, "direction") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - since, err := optionalStringParam(request, "since") + since, err := optionalParam[string](request, "since") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -443,15 +443,15 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - issueNumber, err := requiredNumberParam(request, "issue_number") + issueNumber, err := requiredInt(request, "issue_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -460,7 +460,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t issueRequest := &github.IssueRequest{} // Set optional parameters if provided - title, err := optionalStringParam(request, "title") + title, err := optionalParam[string](request, "title") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -468,7 +468,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t issueRequest.Title = github.Ptr(title) } - body, err := optionalStringParam(request, "body") + body, err := optionalParam[string](request, "body") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -476,7 +476,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t issueRequest.Body = github.Ptr(body) } - state, err := optionalStringParam(request, "state") + state, err := optionalParam[string](request, "state") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -500,7 +500,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t issueRequest.Assignees = &assignees } - milestone, err := optionalNumberParam(request, "milestone") + milestone, err := optionalIntParam(request, "milestone") if err != nil { return mcp.NewToolResultError(err.Error()), nil } diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index d5caab97c..dc8b64819 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -31,15 +31,15 @@ func getPullRequest(client *github.Client, t translations.TranslationHelperFunc) ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredNumberParam(request, "pull_number") + pullNumber, err := requiredInt(request, "pull_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -102,39 +102,39 @@ func listPullRequests(client *github.Client, t translations.TranslationHelperFun ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - state, err := optionalStringParam(request, "state") + state, err := optionalParam[string](request, "state") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - head, err := optionalStringParam(request, "head") + head, err := optionalParam[string](request, "head") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - base, err := optionalStringParam(request, "base") + base, err := optionalParam[string](request, "base") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - sort, err := optionalStringParam(request, "sort") + sort, err := optionalParam[string](request, "sort") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - direction, err := optionalStringParam(request, "direction") + direction, err := optionalParam[string](request, "direction") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - page, err := optionalNumberParamWithDefault(request, "page", 1) + page, err := optionalIntParamWithDefault(request, "page", 1) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -201,27 +201,27 @@ func mergePullRequest(client *github.Client, t translations.TranslationHelperFun ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredNumberParam(request, "pull_number") + pullNumber, err := requiredInt(request, "pull_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - commitTitle, err := optionalStringParam(request, "commit_title") + commitTitle, err := optionalParam[string](request, "commit_title") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - commitMessage, err := optionalStringParam(request, "commit_message") + commitMessage, err := optionalParam[string](request, "commit_message") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - mergeMethod, err := optionalStringParam(request, "merge_method") + mergeMethod, err := optionalParam[string](request, "merge_method") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -272,15 +272,15 @@ func getPullRequestFiles(client *github.Client, t translations.TranslationHelper ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredNumberParam(request, "pull_number") + pullNumber, err := requiredInt(request, "pull_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -327,15 +327,15 @@ func getPullRequestStatus(client *github.Client, t translations.TranslationHelpe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredNumberParam(request, "pull_number") + pullNumber, err := requiredInt(request, "pull_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -399,19 +399,19 @@ func updatePullRequestBranch(client *github.Client, t translations.TranslationHe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredNumberParam(request, "pull_number") + pullNumber, err := requiredInt(request, "pull_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - expectedHeadSHA, err := optionalStringParam(request, "expected_head_sha") + expectedHeadSHA, err := optionalParam[string](request, "expected_head_sha") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -466,15 +466,15 @@ func getPullRequestComments(client *github.Client, t translations.TranslationHel ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredNumberParam(request, "pull_number") + pullNumber, err := requiredInt(request, "pull_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -526,15 +526,15 @@ func getPullRequestReviews(client *github.Client, t translations.TranslationHelp ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredNumberParam(request, "pull_number") + pullNumber, err := requiredInt(request, "pull_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -593,19 +593,19 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredNumberParam(request, "pull_number") + pullNumber, err := requiredInt(request, "pull_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - event, err := requiredStringParam(request, "event") + event, err := requiredParam[string](request, "event") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -616,7 +616,7 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe } // Add body if provided - body, err := optionalStringParam(request, "body") + body, err := optionalParam[string](request, "body") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -625,7 +625,7 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe } // Add commit ID if provided - commitID, err := optionalStringParam(request, "commit_id") + commitID, err := optionalParam[string](request, "commit_id") if err != nil { return mcp.NewToolResultError(err.Error()), nil } diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index f222b1f80..f507b8973 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -37,23 +37,23 @@ func listCommits(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - sha, err := optionalStringParam(request, "sha") + sha, err := optionalParam[string](request, "sha") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - page, err := optionalNumberParamWithDefault(request, "page", 1) + page, err := optionalIntParamWithDefault(request, "page", 1) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -122,27 +122,27 @@ func createOrUpdateFile(client *github.Client, t translations.TranslationHelperF ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - path, err := requiredStringParam(request, "path") + path, err := requiredParam[string](request, "path") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - content, err := requiredStringParam(request, "content") + content, err := requiredParam[string](request, "content") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - message, err := requiredStringParam(request, "message") + message, err := requiredParam[string](request, "message") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - branch, err := requiredStringParam(request, "branch") + branch, err := requiredParam[string](request, "branch") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -158,7 +158,7 @@ func createOrUpdateFile(client *github.Client, t translations.TranslationHelperF } // If SHA is provided, set it (for updates) - sha, err := optionalStringParam(request, "sha") + sha, err := optionalParam[string](request, "sha") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -209,19 +209,19 @@ func createRepository(client *github.Client, t translations.TranslationHelperFun ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - name, err := requiredStringParam(request, "name") + name, err := requiredParam[string](request, "name") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - description, err := optionalStringParam(request, "description") + description, err := optionalParam[string](request, "description") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - private, err := optionalBooleanParam(request, "private") + private, err := optionalParam[bool](request, "private") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - autoInit, err := optionalBooleanParam(request, "auto_init") + autoInit, err := optionalParam[bool](request, "auto_init") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -277,19 +277,19 @@ func getFileContents(client *github.Client, t translations.TranslationHelperFunc ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - path, err := requiredStringParam(request, "path") + path, err := requiredParam[string](request, "path") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - branch, err := optionalStringParam(request, "branch") + branch, err := optionalParam[string](request, "branch") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -342,15 +342,15 @@ func forkRepository(client *github.Client, t translations.TranslationHelperFunc) ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - org, err := optionalStringParam(request, "organization") + org, err := optionalParam[string](request, "organization") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -409,19 +409,19 @@ func createBranch(client *github.Client, t translations.TranslationHelperFunc) ( ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - branch, err := requiredStringParam(request, "branch") + branch, err := requiredParam[string](request, "branch") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - fromBranch, err := optionalStringParam(request, "from_branch") + fromBranch, err := optionalParam[string](request, "from_branch") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -494,19 +494,19 @@ func pushFiles(client *github.Client, t translations.TranslationHelperFunc) (too ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - branch, err := requiredStringParam(request, "branch") + branch, err := requiredParam[string](request, "branch") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - message, err := requiredStringParam(request, "message") + message, err := requiredParam[string](request, "message") if err != nil { return mcp.NewToolResultError(err.Error()), nil } diff --git a/pkg/github/search.go b/pkg/github/search.go index d7ea49049..904dc7372 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -28,15 +28,15 @@ func searchRepositories(client *github.Client, t translations.TranslationHelperF ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query, err := requiredStringParam(request, "query") + query, err := requiredParam[string](request, "query") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - page, err := optionalNumberParamWithDefault(request, "page", 1) + page, err := optionalIntParamWithDefault(request, "page", 1) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -93,23 +93,23 @@ func searchCode(client *github.Client, t translations.TranslationHelperFunc) (to ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query, err := requiredStringParam(request, "q") + query, err := requiredParam[string](request, "q") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - sort, err := optionalStringParam(request, "sort") + sort, err := optionalParam[string](request, "sort") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - order, err := optionalStringParam(request, "order") + order, err := optionalParam[string](request, "order") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - page, err := optionalNumberParamWithDefault(request, "page", 1) + page, err := optionalIntParamWithDefault(request, "page", 1) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -168,23 +168,23 @@ func searchUsers(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query, err := requiredStringParam(request, "q") + query, err := requiredParam[string](request, "q") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - sort, err := optionalStringParam(request, "sort") + sort, err := optionalParam[string](request, "sort") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - order, err := optionalStringParam(request, "order") + order, err := optionalParam[string](request, "order") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - page, err := optionalNumberParamWithDefault(request, "page", 1) + page, err := optionalIntParamWithDefault(request, "page", 1) if err != nil { return mcp.NewToolResultError(err.Error()), nil } diff --git a/pkg/github/server.go b/pkg/github/server.go index 42e230833..829994f19 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -139,76 +139,81 @@ func parseCommaSeparatedList(input string) []string { return result } -// requiredStringParam checks if the parameter is present in the request and is of type string. -func requiredStringParam(r mcp.CallToolRequest, p string) (string, error) { +// requiredParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type. +// 3. Checks if the parameter is not empty, i.e: non-zero value +func requiredParam[T comparable](r mcp.CallToolRequest, p string) (T, error) { + var zero T + // Check if the parameter is present in the request if _, ok := r.Params.Arguments[p]; !ok { - return "", fmt.Errorf("missing required parameter: %s", p) + return zero, fmt.Errorf("missing required parameter: %s", p) } // Check if the parameter is of the expected type - if _, ok := r.Params.Arguments[p].(string); !ok { - return "", fmt.Errorf("parameter %s is not of type string", p) + if _, ok := r.Params.Arguments[p].(T); !ok { + return zero, fmt.Errorf("parameter %s is not of type %T", p, zero) } - // Check if the parameter is not the zero value - v := r.Params.Arguments[p].(string) - if v == "" { - return v, fmt.Errorf("missing required parameter: %s", p) + if r.Params.Arguments[p].(T) == zero { + return zero, fmt.Errorf("missing required parameter: %s", p) + } - return v, nil + return r.Params.Arguments[p].(T), nil } -// requiredNumberParam checks if the parameter is present in the request and is of type number. -func requiredNumberParam(r mcp.CallToolRequest, p string) (int, error) { - // Check if the parameter is present in the request - if _, ok := r.Params.Arguments[p]; !ok { - return 0, fmt.Errorf("missing required parameter: %s", p) - } - - // Check if the parameter is of the expected type - if _, ok := r.Params.Arguments[p].(float64); !ok { - return 0, fmt.Errorf("parameter %s is not of type number", p) +// requiredInt is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type. +// 3. Checks if the parameter is not empty, i.e: non-zero value +func requiredInt(r mcp.CallToolRequest, p string) (int, error) { + v, err := requiredParam[float64](r, p) + if err != nil { + return 0, err } - - return int(r.Params.Arguments[p].(float64)), nil + return int(v), nil } -// optionalStringParam checks if an optional parameter is present in the request and is of type string. -func optionalStringParam(r mcp.CallToolRequest, p string) (value string, err error) { +// optionalParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type and returns it +func optionalParam[T any](r mcp.CallToolRequest, p string) (T, error) { + var zero T + // Check if the parameter is present in the request if _, ok := r.Params.Arguments[p]; !ok { - return "", nil + return zero, nil } // Check if the parameter is of the expected type - if _, ok := r.Params.Arguments[p].(string); !ok { - return "", fmt.Errorf("parameter %s is not of type string", p) + if _, ok := r.Params.Arguments[p].(T); !ok { + return zero, fmt.Errorf("parameter %s is not of type %T", p, zero) } - return r.Params.Arguments[p].(string), nil + return r.Params.Arguments[p].(T), nil } -// optionalNumberParam checks if an optional parameter is present in the request and is of type number. -func optionalNumberParam(r mcp.CallToolRequest, p string) (int, error) { - // Check if the parameter is present in the request - if _, ok := r.Params.Arguments[p]; !ok { - return 0, nil - } - - // Check if the parameter is of the expected type - if _, ok := r.Params.Arguments[p].(float64); !ok { - return 0, fmt.Errorf("parameter %s is not of type number", p) +// optionalIntParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type and returns it +func optionalIntParam(r mcp.CallToolRequest, p string) (int, error) { + v, err := optionalParam[float64](r, p) + if err != nil { + return 0, err } - - return int(r.Params.Arguments[p].(float64)), nil + return int(v), nil } -// optionalNumberParamWithDefault checks if an optional parameter is present in the request and is of type number. -// If the parameter is not present or is zero, it returns the default value. -func optionalNumberParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) { - v, err := optionalNumberParam(r, p) +// optionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalIntParam, but it also takes a default value. +func optionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) { + v, err := optionalIntParam(r, p) if err != nil { return 0, err } @@ -218,38 +223,19 @@ func optionalNumberParamWithDefault(r mcp.CallToolRequest, p string, d int) (int return v, nil } -// optionalCommaSeparatedListParam checks if an optional parameter is present in the request and is of type string. -// If the parameter is presents, it uses parseCommaSeparatedList to parse the string into a list of strings. -// If the parameter is not present or is empty, it returns an empty list. +// optionalCommaSeparatedListParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following: +// 1. Checks if the parameter is present in the request, if not, it returns an empty list +// 2. If it is present, it checks if the parameter is of the expected type and uses parseCommaSeparatedList to parse it +// and return the list of strings func optionalCommaSeparatedListParam(r mcp.CallToolRequest, p string) ([]string, error) { - // Check if the parameter is present in the request - if _, ok := r.Params.Arguments[p]; !ok { - return []string{}, nil //default to empty list, not nil - } - - // Check if the parameter is of the expected type - if _, ok := r.Params.Arguments[p].(string); !ok { - return nil, fmt.Errorf("parameter %s is not of type string", p) + v, err := optionalParam[string](r, p) + if err != nil { + return []string{}, err } - - l := parseCommaSeparatedList(r.Params.Arguments[p].(string)) + l := parseCommaSeparatedList(v) if len(l) == 0 { - return []string{}, nil // default to empty list, not nil + return []string{}, nil } return l, nil } - -// optionalBooleanParam checks if an optional parameter is present in the request and is of type boolean. -func optionalBooleanParam(r mcp.CallToolRequest, p string) (bool, error) { - // Check if the parameter is present in the request - if _, ok := r.Params.Arguments[p]; !ok { - return false, nil - } - - // Check if the parameter is of the expected type - if _, ok := r.Params.Arguments[p].(bool); !ok { - return false, fmt.Errorf("parameter %s is not of type bool", p) - } - - return r.Params.Arguments[p].(bool), nil -} diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index a081d31de..5e7ac9d4d 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -270,7 +270,7 @@ func Test_RequiredStringParam(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := requiredStringParam(request, tc.paramName) + result, err := requiredParam[string](request, tc.paramName) if tc.expectError { assert.Error(t, err) @@ -323,7 +323,7 @@ func Test_OptionalStringParam(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := optionalStringParam(request, tc.paramName) + result, err := optionalParam[string](request, tc.paramName) if tc.expectError { assert.Error(t, err) @@ -369,7 +369,7 @@ func Test_RequiredNumberParam(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := requiredNumberParam(request, tc.paramName) + result, err := requiredInt(request, tc.paramName) if tc.expectError { assert.Error(t, err) @@ -422,7 +422,7 @@ func Test_OptionalNumberParam(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := optionalNumberParam(request, tc.paramName) + result, err := optionalIntParam(request, tc.paramName) if tc.expectError { assert.Error(t, err) @@ -480,7 +480,7 @@ func Test_OptionalNumberParamWithDefault(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := optionalNumberParamWithDefault(request, tc.paramName, tc.defaultVal) + result, err := optionalIntParamWithDefault(request, tc.paramName, tc.defaultVal) if tc.expectError { assert.Error(t, err) @@ -586,7 +586,7 @@ func Test_OptionalBooleanParam(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := optionalBooleanParam(request, tc.paramName) + result, err := optionalParam[bool](request, tc.paramName) if tc.expectError { assert.Error(t, err) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy