Content-Length: 96179 | pFad | http://github.com/github/github-mcp-server/pull/35.patch
thub.com
From 4d7388ae1938f03348ae742bd008872e2a9a2a8c Mon Sep 17 00:00:00 2001
From: Javier Uruen Val
Date: Mon, 24 Mar 2025 07:39:01 +0100
Subject: [PATCH 1/2] validate tools params
---
pkg/github/code_scanning.go | 40 +++-
pkg/github/issues.go | 197 +++++++++++++------
pkg/github/issues_test.go | 26 ++-
pkg/github/pullrequests.go | 215 ++++++++++++++-------
pkg/github/repositories.go | 174 +++++++++++------
pkg/github/search.go | 81 ++++----
pkg/github/server.go | 115 +++++++++++
pkg/github/server_test.go | 369 ++++++++++++++++++++++++++++++++++++
8 files changed, 989 insertions(+), 228 deletions(-)
diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go
index 6fc0936a4..e7c8a4e22 100644
--- a/pkg/github/code_scanning.go
+++ b/pkg/github/code_scanning.go
@@ -30,9 +30,18 @@ func getCodeScanningAlert(client *github.Client, t translations.TranslationHelpe
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, _ := request.Params.Arguments["owner"].(string)
- repo, _ := request.Params.Arguments["repo"].(string)
- alertNumber, _ := request.Params.Arguments["alert_number"].(float64)
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ alertNumber, err := requiredNumberParam(request, "alert_number")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber))
if err != nil {
@@ -80,11 +89,26 @@ func listCodeScanningAlerts(client *github.Client, t translations.TranslationHel
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, _ := request.Params.Arguments["owner"].(string)
- repo, _ := request.Params.Arguments["repo"].(string)
- ref, _ := request.Params.Arguments["ref"].(string)
- state, _ := request.Params.Arguments["state"].(string)
- severity, _ := request.Params.Arguments["severity"].(string)
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ ref, err := optionalStringParam(request, "ref")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ state, err := optionalStringParam(request, "state")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ severity, err := optionalStringParam(request, "severity")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity})
if err != nil {
diff --git a/pkg/github/issues.go b/pkg/github/issues.go
index 36130b985..521acfdaf 100644
--- a/pkg/github/issues.go
+++ b/pkg/github/issues.go
@@ -32,9 +32,18 @@ func getIssue(client *github.Client, t translations.TranslationHelperFunc) (tool
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- issueNumber := int(request.Params.Arguments["issue_number"].(float64))
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ issueNumber, err := requiredNumberParam(request, "issue_number")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber)
if err != nil {
@@ -81,10 +90,22 @@ func addIssueComment(client *github.Client, t translations.TranslationHelperFunc
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- issueNumber := int(request.Params.Arguments["issue_number"].(float64))
- body := request.Params.Arguments["body"].(string)
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ issueNumber, err := requiredNumberParam(request, "issue_number")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ body, err := requiredStringParam(request, "body")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
comment := &github.IssueComment{
Body: github.Ptr(body),
@@ -135,22 +156,25 @@ func searchIssues(client *github.Client, t translations.TranslationHelperFunc) (
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- query := request.Params.Arguments["q"].(string)
- sort := ""
- if s, ok := request.Params.Arguments["sort"].(string); ok {
- sort = s
+ query, err := requiredStringParam(request, "q")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- order := ""
- if o, ok := request.Params.Arguments["order"].(string); ok {
- order = o
+ sort, err := optionalStringParam(request, "sort")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- perPage := 30
- if pp, ok := request.Params.Arguments["per_page"].(float64); ok {
- perPage = int(pp)
+ order, err := optionalStringParam(request, "order")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- page := 1
- if p, ok := request.Params.Arguments["page"].(float64); ok {
- page = int(p)
+ perPage, err := optionalNumberParamWithDefault(request, "per_page", 30)
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ page, err := optionalNumberParamWithDefault(request, "page", 1)
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
opts := &github.SearchOptions{
@@ -212,26 +236,34 @@ func createIssue(client *github.Client, t translations.TranslationHelperFunc) (t
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- title := request.Params.Arguments["title"].(string)
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ title, err := requiredStringParam(request, "title")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
// Optional parameters
- var body string
- if b, ok := request.Params.Arguments["body"].(string); ok {
- body = b
+ body, err := optionalStringParam(request, "body")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- // Parse assignees if present
- assignees := []string{} // default to empty slice, can't be nil
- if a, ok := request.Params.Arguments["assignees"].(string); ok && a != "" {
- assignees = parseCommaSeparatedList(a)
+ // Get assignees
+ assignees, err := optionalCommaSeparatedListParam(request, "assignees")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
-
- // Parse labels if present
- labels := []string{} // default to empty slice, can't be nil
- if l, ok := request.Params.Arguments["labels"].(string); ok && l != "" {
- labels = parseCommaSeparatedList(l)
+ // Get labels
+ labels, err := optionalCommaSeparatedListParam(request, "labels")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
// Create the issue request
@@ -300,29 +332,43 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
opts := &github.IssueListByRepoOptions{}
// Set optional parameters if provided
- if state, ok := request.Params.Arguments["state"].(string); ok && state != "" {
- opts.State = state
+ opts.State, err = optionalStringParam(request, "state")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- if labels, ok := request.Params.Arguments["labels"].(string); ok && labels != "" {
- opts.Labels = parseCommaSeparatedList(labels)
+ opts.Labels, err = optionalCommaSeparatedListParam(request, "labels")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- if sort, ok := request.Params.Arguments["sort"].(string); ok && sort != "" {
- opts.Sort = sort
+ opts.Sort, err = optionalStringParam(request, "sort")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- if direction, ok := request.Params.Arguments["direction"].(string); ok && direction != "" {
- opts.Direction = direction
+ opts.Direction, err = optionalStringParam(request, "direction")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- if since, ok := request.Params.Arguments["since"].(string); ok && since != "" {
+ since, err := optionalStringParam(request, "since")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ if since != "" {
timestamp, err := parseISOTimestamp(since)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to list issues: %s", err.Error())), nil
@@ -397,38 +443,69 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- issueNumber := int(request.Params.Arguments["issue_number"].(float64))
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ issueNumber, err := requiredNumberParam(request, "issue_number")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
// Create the issue request with only provided fields
issueRequest := &github.IssueRequest{}
// Set optional parameters if provided
- if title, ok := request.Params.Arguments["title"].(string); ok && title != "" {
+ title, err := optionalStringParam(request, "title")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ if title != "" {
issueRequest.Title = github.Ptr(title)
}
- if body, ok := request.Params.Arguments["body"].(string); ok && body != "" {
+ body, err := optionalStringParam(request, "body")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ if body != "" {
issueRequest.Body = github.Ptr(body)
}
- if state, ok := request.Params.Arguments["state"].(string); ok && state != "" {
+ state, err := optionalStringParam(request, "state")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ if state != "" {
issueRequest.State = github.Ptr(state)
}
- if labels, ok := request.Params.Arguments["labels"].(string); ok && labels != "" {
- labelsList := parseCommaSeparatedList(labels)
- issueRequest.Labels = &labelsList
+ labels, err := optionalCommaSeparatedListParam(request, "labels")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ if len(labels) > 0 {
+ issueRequest.Labels = &labels
}
- if assignees, ok := request.Params.Arguments["assignees"].(string); ok && assignees != "" {
- assigneesList := parseCommaSeparatedList(assignees)
- issueRequest.Assignees = &assigneesList
+ assignees, err := optionalCommaSeparatedListParam(request, "assignees")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ if len(assignees) > 0 {
+ issueRequest.Assignees = &assignees
}
- if milestone, ok := request.Params.Arguments["milestone"].(float64); ok {
- milestoneNum := int(milestone)
+ milestone, err := optionalNumberParam(request, "milestone")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ if milestone != 0 {
+ milestoneNum := milestone
issueRequest.Milestone = &milestoneNum
}
diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go
index 4e8250fd2..c2de65797 100644
--- a/pkg/github/issues_test.go
+++ b/pkg/github/issues_test.go
@@ -176,8 +176,8 @@ func Test_AddIssueComment(t *testing.T) {
"issue_number": float64(42),
"body": "",
},
- expectError: true,
- expectedErrMsg: "failed to create comment",
+ expectError: false,
+ expectedErrMsg: "missing required parameter: body",
},
}
@@ -210,6 +210,13 @@ func Test_AddIssueComment(t *testing.T) {
return
}
+ if tc.expectedErrMsg != "" {
+ require.NotNil(t, result)
+ textContent := getTextResult(t, result)
+ assert.Contains(t, textContent.Text, tc.expectedErrMsg)
+ return
+ }
+
require.NoError(t, err)
// Parse the result and get the text content if no error
@@ -419,8 +426,8 @@ func Test_CreateIssue(t *testing.T) {
"repo": "repo",
"title": "Test Issue",
"body": "This is a test issue",
- "assignees": []interface{}{"user1", "user2"},
- "labels": []interface{}{"bug", "help wanted"},
+ "assignees": "user1, user2",
+ "labels": "bug, help wanted",
},
expectError: false,
expectedIssue: mockIssue,
@@ -467,8 +474,8 @@ func Test_CreateIssue(t *testing.T) {
"repo": "repo",
"title": "",
},
- expectError: true,
- expectedErrMsg: "failed to create issue",
+ expectError: false,
+ expectedErrMsg: "missing required parameter: title",
},
}
@@ -491,6 +498,13 @@ func Test_CreateIssue(t *testing.T) {
return
}
+ if tc.expectedErrMsg != "" {
+ require.NotNil(t, result)
+ textContent := getTextResult(t, result)
+ assert.Contains(t, textContent.Text, tc.expectedErrMsg)
+ return
+ }
+
require.NoError(t, err)
textContent := getTextResult(t, result)
diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go
index e0414394a..d5caab97c 100644
--- a/pkg/github/pullrequests.go
+++ b/pkg/github/pullrequests.go
@@ -31,9 +31,18 @@ func getPullRequest(client *github.Client, t translations.TranslationHelperFunc)
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- pullNumber := int(request.Params.Arguments["pull_number"].(float64))
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ pullNumber, err := requiredNumberParam(request, "pull_number")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber)
if err != nil {
@@ -93,35 +102,41 @@ func listPullRequests(client *github.Client, t translations.TranslationHelperFun
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- state := ""
- if s, ok := request.Params.Arguments["state"].(string); ok {
- state = s
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- head := ""
- if h, ok := request.Params.Arguments["head"].(string); ok {
- head = h
+ state, err := optionalStringParam(request, "state")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- base := ""
- if b, ok := request.Params.Arguments["base"].(string); ok {
- base = b
+ head, err := optionalStringParam(request, "head")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- sort := ""
- if s, ok := request.Params.Arguments["sort"].(string); ok {
- sort = s
+ base, err := optionalStringParam(request, "base")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- direction := ""
- if d, ok := request.Params.Arguments["direction"].(string); ok {
- direction = d
+ sort, err := optionalStringParam(request, "sort")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- perPage := 30
- if pp, ok := request.Params.Arguments["per_page"].(float64); ok {
- perPage = int(pp)
+ direction, err := optionalStringParam(request, "direction")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- page := 1
- if p, ok := request.Params.Arguments["page"].(float64); ok {
- page = int(p)
+ perPage, err := optionalNumberParamWithDefault(request, "per_page", 30)
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ page, err := optionalNumberParamWithDefault(request, "page", 1)
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
opts := &github.PullRequestListOptions{
@@ -186,20 +201,29 @@ func mergePullRequest(client *github.Client, t translations.TranslationHelperFun
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- pullNumber := int(request.Params.Arguments["pull_number"].(float64))
- commitTitle := ""
- if ct, ok := request.Params.Arguments["commit_title"].(string); ok {
- commitTitle = ct
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- commitMessage := ""
- if cm, ok := request.Params.Arguments["commit_message"].(string); ok {
- commitMessage = cm
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- mergeMethod := ""
- if mm, ok := request.Params.Arguments["merge_method"].(string); ok {
- mergeMethod = mm
+ pullNumber, err := requiredNumberParam(request, "pull_number")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ commitTitle, err := optionalStringParam(request, "commit_title")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ commitMessage, err := optionalStringParam(request, "commit_message")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ mergeMethod, err := optionalStringParam(request, "merge_method")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
options := &github.PullRequestOptions{
@@ -248,9 +272,18 @@ func getPullRequestFiles(client *github.Client, t translations.TranslationHelper
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- pullNumber := int(request.Params.Arguments["pull_number"].(float64))
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ pullNumber, err := requiredNumberParam(request, "pull_number")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
opts := &github.ListOptions{}
files, resp, err := client.PullRequests.ListFiles(ctx, owner, repo, pullNumber, opts)
@@ -294,10 +327,18 @@ func getPullRequestStatus(client *github.Client, t translations.TranslationHelpe
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- pullNumber := int(request.Params.Arguments["pull_number"].(float64))
-
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ pullNumber, err := requiredNumberParam(request, "pull_number")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
// First get the PR to find the head SHA
pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber)
if err != nil {
@@ -358,14 +399,22 @@ func updatePullRequestBranch(client *github.Client, t translations.TranslationHe
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- pullNumber := int(request.Params.Arguments["pull_number"].(float64))
- expectedHeadSHA := ""
- if sha, ok := request.Params.Arguments["expected_head_sha"].(string); ok {
- expectedHeadSHA = sha
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ pullNumber, err := requiredNumberParam(request, "pull_number")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ expectedHeadSHA, err := optionalStringParam(request, "expected_head_sha")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
-
opts := &github.PullRequestBranchUpdateOptions{}
if expectedHeadSHA != "" {
opts.ExpectedHeadSHA = github.Ptr(expectedHeadSHA)
@@ -417,9 +466,18 @@ func getPullRequestComments(client *github.Client, t translations.TranslationHel
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- pullNumber := int(request.Params.Arguments["pull_number"].(float64))
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ pullNumber, err := requiredNumberParam(request, "pull_number")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
opts := &github.PullRequestListCommentsOptions{
ListOptions: github.ListOptions{
@@ -468,9 +526,18 @@ func getPullRequestReviews(client *github.Client, t translations.TranslationHelp
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- pullNumber := int(request.Params.Arguments["pull_number"].(float64))
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ pullNumber, err := requiredNumberParam(request, "pull_number")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil)
if err != nil {
@@ -526,10 +593,22 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- pullNumber := int(request.Params.Arguments["pull_number"].(float64))
- event := request.Params.Arguments["event"].(string)
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ pullNumber, err := requiredNumberParam(request, "pull_number")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ event, err := requiredStringParam(request, "event")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
// Create review request
reviewRequest := &github.PullRequestReviewRequest{
@@ -537,12 +616,20 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe
}
// Add body if provided
- if body, ok := request.Params.Arguments["body"].(string); ok && body != "" {
+ body, err := optionalStringParam(request, "body")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ if body != "" {
reviewRequest.Body = github.Ptr(body)
}
// Add commit ID if provided
- if commitID, ok := request.Params.Arguments["commit_id"].(string); ok && commitID != "" {
+ commitID, err := optionalStringParam(request, "commit_id")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ if commitID != "" {
reviewRequest.CommitID = github.Ptr(commitID)
}
diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go
index 6e3b176df..f222b1f80 100644
--- a/pkg/github/repositories.go
+++ b/pkg/github/repositories.go
@@ -37,19 +37,25 @@ func listCommits(client *github.Client, t translations.TranslationHelperFunc) (t
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- sha := ""
- if s, ok := request.Params.Arguments["sha"].(string); ok {
- sha = s
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- page := 1
- if p, ok := request.Params.Arguments["page"].(float64); ok {
- page = int(p)
+ sha, err := optionalStringParam(request, "sha")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- perPage := 30
- if pp, ok := request.Params.Arguments["per_page"].(float64); ok {
- perPage = int(pp)
+ page, err := optionalNumberParamWithDefault(request, "page", 1)
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ perPage, err := optionalNumberParamWithDefault(request, "per_page", 30)
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
opts := &github.CommitsListOptions{
@@ -116,12 +122,30 @@ func createOrUpdateFile(client *github.Client, t translations.TranslationHelperF
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- path := request.Params.Arguments["path"].(string)
- content := request.Params.Arguments["content"].(string)
- message := request.Params.Arguments["message"].(string)
- branch := request.Params.Arguments["branch"].(string)
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ path, err := requiredStringParam(request, "path")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ content, err := requiredStringParam(request, "content")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ message, err := requiredStringParam(request, "message")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ branch, err := requiredStringParam(request, "branch")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
// Convert content to base64
contentBytes := []byte(content)
@@ -134,7 +158,11 @@ func createOrUpdateFile(client *github.Client, t translations.TranslationHelperF
}
// If SHA is provided, set it (for updates)
- if sha, ok := request.Params.Arguments["sha"].(string); ok && sha != "" {
+ sha, err := optionalStringParam(request, "sha")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ if sha != "" {
opts.SHA = ptr.String(sha)
}
@@ -181,25 +209,28 @@ func createRepository(client *github.Client, t translations.TranslationHelperFun
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- name := request.Params.Arguments["name"].(string)
- description := ""
- if desc, ok := request.Params.Arguments["description"].(string); ok {
- description = desc
+ name, err := requiredStringParam(request, "name")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ description, err := optionalStringParam(request, "description")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- private := false
- if priv, ok := request.Params.Arguments["private"].(bool); ok {
- private = priv
+ private, err := optionalBooleanParam(request, "private")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- autoInit := false
- if init, ok := request.Params.Arguments["auto_init"].(bool); ok {
- autoInit = init
+ autoInit, err := optionalBooleanParam(request, "auto_init")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
repo := &github.Repository{
- Name: github.String(name),
- Description: github.String(description),
- Private: github.Bool(private),
- AutoInit: github.Bool(autoInit),
+ Name: github.Ptr(name),
+ Description: github.Ptr(description),
+ Private: github.Ptr(private),
+ AutoInit: github.Ptr(autoInit),
}
createdRepo, resp, err := client.Repositories.Create(ctx, "", repo)
@@ -246,12 +277,21 @@ func getFileContents(client *github.Client, t translations.TranslationHelperFunc
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- path := request.Params.Arguments["path"].(string)
- branch := ""
- if b, ok := request.Params.Arguments["branch"].(string); ok {
- branch = b
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ path, err := requiredStringParam(request, "path")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ branch, err := optionalStringParam(request, "branch")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
opts := &github.RepositoryContentGetOptions{Ref: branch}
@@ -302,11 +342,17 @@ func forkRepository(client *github.Client, t translations.TranslationHelperFunc)
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- org := ""
- if o, ok := request.Params.Arguments["organization"].(string); ok {
- org = o
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ org, err := optionalStringParam(request, "organization")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
opts := &github.RepositoryCreateForkOptions{}
@@ -363,17 +409,25 @@ func createBranch(client *github.Client, t translations.TranslationHelperFunc) (
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- branch := request.Params.Arguments["branch"].(string)
- fromBranch := ""
- if fb, ok := request.Params.Arguments["from_branch"].(string); ok {
- fromBranch = fb
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ branch, err := requiredStringParam(request, "branch")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ fromBranch, err := optionalStringParam(request, "from_branch")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
// Get the source branch SHA
var ref *github.Reference
- var err error
if fromBranch == "" {
// Get default branch if from_branch not specified
@@ -440,10 +494,22 @@ func pushFiles(client *github.Client, t translations.TranslationHelperFunc) (too
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner := request.Params.Arguments["owner"].(string)
- repo := request.Params.Arguments["repo"].(string)
- branch := request.Params.Arguments["branch"].(string)
- message := request.Params.Arguments["message"].(string)
+ owner, err := requiredStringParam(request, "owner")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ repo, err := requiredStringParam(request, "repo")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ branch, err := requiredStringParam(request, "branch")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ message, err := requiredStringParam(request, "message")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
// Parse files parameter - this should be an array of objects with path and content
filesObj, ok := request.Params.Arguments["files"].([]interface{})
diff --git a/pkg/github/search.go b/pkg/github/search.go
index 353c6fb21..d7ea49049 100644
--- a/pkg/github/search.go
+++ b/pkg/github/search.go
@@ -28,14 +28,17 @@ func searchRepositories(client *github.Client, t translations.TranslationHelperF
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- query := request.Params.Arguments["query"].(string)
- page := 1
- if p, ok := request.Params.Arguments["page"].(float64); ok {
- page = int(p)
+ query, err := requiredStringParam(request, "query")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- perPage := 30
- if pp, ok := request.Params.Arguments["per_page"].(float64); ok {
- perPage = int(pp)
+ page, err := optionalNumberParamWithDefault(request, "page", 1)
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ perPage, err := optionalNumberParamWithDefault(request, "per_page", 30)
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
opts := &github.SearchOptions{
@@ -90,22 +93,25 @@ func searchCode(client *github.Client, t translations.TranslationHelperFunc) (to
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- query := request.Params.Arguments["q"].(string)
- sort := ""
- if s, ok := request.Params.Arguments["sort"].(string); ok {
- sort = s
+ query, err := requiredStringParam(request, "q")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- order := ""
- if o, ok := request.Params.Arguments["order"].(string); ok {
- order = o
+ sort, err := optionalStringParam(request, "sort")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- perPage := 30
- if pp, ok := request.Params.Arguments["per_page"].(float64); ok {
- perPage = int(pp)
+ order, err := optionalStringParam(request, "order")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
- page := 1
- if p, ok := request.Params.Arguments["page"].(float64); ok {
- page = int(p)
+ perPage, err := optionalNumberParamWithDefault(request, "per_page", 30)
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ page, err := optionalNumberParamWithDefault(request, "page", 1)
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
opts := &github.SearchOptions{
@@ -162,22 +168,25 @@ func searchUsers(client *github.Client, t translations.TranslationHelperFunc) (t
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- query := request.Params.Arguments["q"].(string)
- sort := ""
- if s, ok := request.Params.Arguments["sort"].(string); ok {
- sort = s
- }
- order := ""
- if o, ok := request.Params.Arguments["order"].(string); ok {
- order = o
- }
- perPage := 30
- if pp, ok := request.Params.Arguments["per_page"].(float64); ok {
- perPage = int(pp)
- }
- page := 1
- if p, ok := request.Params.Arguments["page"].(float64); ok {
- page = int(p)
+ query, err := requiredStringParam(request, "q")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ sort, err := optionalStringParam(request, "sort")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ order, err := optionalStringParam(request, "order")
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ perPage, err := optionalNumberParamWithDefault(request, "per_page", 30)
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+ page, err := optionalNumberParamWithDefault(request, "page", 1)
+ if err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
}
opts := &github.SearchOptions{
diff --git a/pkg/github/server.go b/pkg/github/server.go
index a0993e2f3..42e230833 100644
--- a/pkg/github/server.go
+++ b/pkg/github/server.go
@@ -138,3 +138,118 @@ func parseCommaSeparatedList(input string) []string {
return result
}
+
+// requiredStringParam checks if the parameter is present in the request and is of type string.
+func requiredStringParam(r mcp.CallToolRequest, p string) (string, error) {
+ // Check if the parameter is present in the request
+ if _, ok := r.Params.Arguments[p]; !ok {
+ return "", fmt.Errorf("missing required parameter: %s", p)
+ }
+
+ // Check if the parameter is of the expected type
+ if _, ok := r.Params.Arguments[p].(string); !ok {
+ return "", fmt.Errorf("parameter %s is not of type string", p)
+ }
+
+ // Check if the parameter is not the zero value
+ v := r.Params.Arguments[p].(string)
+ if v == "" {
+ return v, fmt.Errorf("missing required parameter: %s", p)
+ }
+
+ return v, nil
+}
+
+// requiredNumberParam checks if the parameter is present in the request and is of type number.
+func requiredNumberParam(r mcp.CallToolRequest, p string) (int, error) {
+ // Check if the parameter is present in the request
+ if _, ok := r.Params.Arguments[p]; !ok {
+ return 0, fmt.Errorf("missing required parameter: %s", p)
+ }
+
+ // Check if the parameter is of the expected type
+ if _, ok := r.Params.Arguments[p].(float64); !ok {
+ return 0, fmt.Errorf("parameter %s is not of type number", p)
+ }
+
+ return int(r.Params.Arguments[p].(float64)), nil
+}
+
+// optionalStringParam checks if an optional parameter is present in the request and is of type string.
+func optionalStringParam(r mcp.CallToolRequest, p string) (value string, err error) {
+ // Check if the parameter is present in the request
+ if _, ok := r.Params.Arguments[p]; !ok {
+ return "", nil
+ }
+
+ // Check if the parameter is of the expected type
+ if _, ok := r.Params.Arguments[p].(string); !ok {
+ return "", fmt.Errorf("parameter %s is not of type string", p)
+ }
+
+ return r.Params.Arguments[p].(string), nil
+}
+
+// optionalNumberParam checks if an optional parameter is present in the request and is of type number.
+func optionalNumberParam(r mcp.CallToolRequest, p string) (int, error) {
+ // Check if the parameter is present in the request
+ if _, ok := r.Params.Arguments[p]; !ok {
+ return 0, nil
+ }
+
+ // Check if the parameter is of the expected type
+ if _, ok := r.Params.Arguments[p].(float64); !ok {
+ return 0, fmt.Errorf("parameter %s is not of type number", p)
+ }
+
+ return int(r.Params.Arguments[p].(float64)), nil
+}
+
+// optionalNumberParamWithDefault checks if an optional parameter is present in the request and is of type number.
+// If the parameter is not present or is zero, it returns the default value.
+func optionalNumberParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) {
+ v, err := optionalNumberParam(r, p)
+ if err != nil {
+ return 0, err
+ }
+ if v == 0 {
+ return d, nil
+ }
+ return v, nil
+}
+
+// optionalCommaSeparatedListParam checks if an optional parameter is present in the request and is of type string.
+// If the parameter is presents, it uses parseCommaSeparatedList to parse the string into a list of strings.
+// If the parameter is not present or is empty, it returns an empty list.
+func optionalCommaSeparatedListParam(r mcp.CallToolRequest, p string) ([]string, error) {
+ // Check if the parameter is present in the request
+ if _, ok := r.Params.Arguments[p]; !ok {
+ return []string{}, nil //default to empty list, not nil
+ }
+
+ // Check if the parameter is of the expected type
+ if _, ok := r.Params.Arguments[p].(string); !ok {
+ return nil, fmt.Errorf("parameter %s is not of type string", p)
+ }
+
+ l := parseCommaSeparatedList(r.Params.Arguments[p].(string))
+ if len(l) == 0 {
+ return []string{}, nil // default to empty list, not nil
+ }
+ return l, nil
+}
+
+// optionalBooleanParam checks if an optional parameter is present in the request and is of type boolean.
+func optionalBooleanParam(r mcp.CallToolRequest, p string) (bool, error) {
+ // Check if the parameter is present in the request
+ if _, ok := r.Params.Arguments[p]; !ok {
+ return false, nil
+ }
+
+ // Check if the parameter is of the expected type
+ if _, ok := r.Params.Arguments[p].(bool); !ok {
+ return false, fmt.Errorf("parameter %s is not of type bool", p)
+ }
+
+ return r.Params.Arguments[p].(bool), nil
+}
diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go
index 316a0efa2..a081d31de 100644
--- a/pkg/github/server_test.go
+++ b/pkg/github/server_test.go
@@ -228,3 +228,372 @@ func Test_ParseCommaSeparatedList(t *testing.T) {
})
}
}
+
+func Test_RequiredStringParam(t *testing.T) {
+ tests := []struct {
+ name string
+ params map[string]interface{}
+ paramName string
+ expected string
+ expectError bool
+ }{
+ {
+ name: "valid string parameter",
+ params: map[string]interface{}{"name": "test-value"},
+ paramName: "name",
+ expected: "test-value",
+ expectError: false,
+ },
+ {
+ name: "missing parameter",
+ params: map[string]interface{}{},
+ paramName: "name",
+ expected: "",
+ expectError: true,
+ },
+ {
+ name: "empty string parameter",
+ params: map[string]interface{}{"name": ""},
+ paramName: "name",
+ expected: "",
+ expectError: true,
+ },
+ {
+ name: "wrong type parameter",
+ params: map[string]interface{}{"name": 123},
+ paramName: "name",
+ expected: "",
+ expectError: true,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ request := createMCPRequest(tc.params)
+ result, err := requiredStringParam(request, tc.paramName)
+
+ if tc.expectError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tc.expected, result)
+ }
+ })
+ }
+}
+
+func Test_OptionalStringParam(t *testing.T) {
+ tests := []struct {
+ name string
+ params map[string]interface{}
+ paramName string
+ expected string
+ expectError bool
+ }{
+ {
+ name: "valid string parameter",
+ params: map[string]interface{}{"name": "test-value"},
+ paramName: "name",
+ expected: "test-value",
+ expectError: false,
+ },
+ {
+ name: "missing parameter",
+ params: map[string]interface{}{},
+ paramName: "name",
+ expected: "",
+ expectError: false,
+ },
+ {
+ name: "empty string parameter",
+ params: map[string]interface{}{"name": ""},
+ paramName: "name",
+ expected: "",
+ expectError: false,
+ },
+ {
+ name: "wrong type parameter",
+ params: map[string]interface{}{"name": 123},
+ paramName: "name",
+ expected: "",
+ expectError: true,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ request := createMCPRequest(tc.params)
+ result, err := optionalStringParam(request, tc.paramName)
+
+ if tc.expectError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tc.expected, result)
+ }
+ })
+ }
+}
+
+func Test_RequiredNumberParam(t *testing.T) {
+ tests := []struct {
+ name string
+ params map[string]interface{}
+ paramName string
+ expected int
+ expectError bool
+ }{
+ {
+ name: "valid number parameter",
+ params: map[string]interface{}{"count": float64(42)},
+ paramName: "count",
+ expected: 42,
+ expectError: false,
+ },
+ {
+ name: "missing parameter",
+ params: map[string]interface{}{},
+ paramName: "count",
+ expected: 0,
+ expectError: true,
+ },
+ {
+ name: "wrong type parameter",
+ params: map[string]interface{}{"count": "not-a-number"},
+ paramName: "count",
+ expected: 0,
+ expectError: true,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ request := createMCPRequest(tc.params)
+ result, err := requiredNumberParam(request, tc.paramName)
+
+ if tc.expectError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tc.expected, result)
+ }
+ })
+ }
+}
+
+func Test_OptionalNumberParam(t *testing.T) {
+ tests := []struct {
+ name string
+ params map[string]interface{}
+ paramName string
+ expected int
+ expectError bool
+ }{
+ {
+ name: "valid number parameter",
+ params: map[string]interface{}{"count": float64(42)},
+ paramName: "count",
+ expected: 42,
+ expectError: false,
+ },
+ {
+ name: "missing parameter",
+ params: map[string]interface{}{},
+ paramName: "count",
+ expected: 0,
+ expectError: false,
+ },
+ {
+ name: "zero value",
+ params: map[string]interface{}{"count": float64(0)},
+ paramName: "count",
+ expected: 0,
+ expectError: false,
+ },
+ {
+ name: "wrong type parameter",
+ params: map[string]interface{}{"count": "not-a-number"},
+ paramName: "count",
+ expected: 0,
+ expectError: true,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ request := createMCPRequest(tc.params)
+ result, err := optionalNumberParam(request, tc.paramName)
+
+ if tc.expectError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tc.expected, result)
+ }
+ })
+ }
+}
+
+func Test_OptionalNumberParamWithDefault(t *testing.T) {
+ tests := []struct {
+ name string
+ params map[string]interface{}
+ paramName string
+ defaultVal int
+ expected int
+ expectError bool
+ }{
+ {
+ name: "valid number parameter",
+ params: map[string]interface{}{"count": float64(42)},
+ paramName: "count",
+ defaultVal: 10,
+ expected: 42,
+ expectError: false,
+ },
+ {
+ name: "missing parameter",
+ params: map[string]interface{}{},
+ paramName: "count",
+ defaultVal: 10,
+ expected: 10,
+ expectError: false,
+ },
+ {
+ name: "zero value",
+ params: map[string]interface{}{"count": float64(0)},
+ paramName: "count",
+ defaultVal: 10,
+ expected: 10,
+ expectError: false,
+ },
+ {
+ name: "wrong type parameter",
+ params: map[string]interface{}{"count": "not-a-number"},
+ paramName: "count",
+ defaultVal: 10,
+ expected: 0,
+ expectError: true,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ request := createMCPRequest(tc.params)
+ result, err := optionalNumberParamWithDefault(request, tc.paramName, tc.defaultVal)
+
+ if tc.expectError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tc.expected, result)
+ }
+ })
+ }
+}
+
+func Test_OptionalCommaSeparatedListParam(t *testing.T) {
+ tests := []struct {
+ name string
+ params map[string]interface{}
+ paramName string
+ expected []string
+ expectError bool
+ }{
+ {
+ name: "valid comma-separated list",
+ params: map[string]interface{}{"tags": "one,two,three"},
+ paramName: "tags",
+ expected: []string{"one", "two", "three"},
+ expectError: false,
+ },
+ {
+ name: "empty list",
+ params: map[string]interface{}{"tags": ""},
+ paramName: "tags",
+ expected: []string{},
+ expectError: false,
+ },
+ {
+ name: "missing parameter",
+ params: map[string]interface{}{},
+ paramName: "tags",
+ expected: []string{},
+ expectError: false,
+ },
+ {
+ name: "wrong type parameter",
+ params: map[string]interface{}{"tags": 123},
+ paramName: "tags",
+ expected: nil,
+ expectError: true,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ request := createMCPRequest(tc.params)
+ result, err := optionalCommaSeparatedListParam(request, tc.paramName)
+
+ if tc.expectError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tc.expected, result)
+ }
+ })
+ }
+}
+
+func Test_OptionalBooleanParam(t *testing.T) {
+ tests := []struct {
+ name string
+ params map[string]interface{}
+ paramName string
+ expected bool
+ expectError bool
+ }{
+ {
+ name: "true value",
+ params: map[string]interface{}{"flag": true},
+ paramName: "flag",
+ expected: true,
+ expectError: false,
+ },
+ {
+ name: "false value",
+ params: map[string]interface{}{"flag": false},
+ paramName: "flag",
+ expected: false,
+ expectError: false,
+ },
+ {
+ name: "missing parameter",
+ params: map[string]interface{}{},
+ paramName: "flag",
+ expected: false,
+ expectError: false,
+ },
+ {
+ name: "wrong type parameter",
+ params: map[string]interface{}{"flag": "not-a-boolean"},
+ paramName: "flag",
+ expected: false,
+ expectError: true,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ request := createMCPRequest(tc.params)
+ result, err := optionalBooleanParam(request, tc.paramName)
+
+ if tc.expectError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tc.expected, result)
+ }
+ })
+ }
+}
From 3c6467ffef6b2493d69533d34776d5cdd97b65ce Mon Sep 17 00:00:00 2001
From: Javier Uruen Val
Date: Mon, 24 Mar 2025 14:47:54 +0100
Subject: [PATCH 2/2] use generic for helper functions
---
pkg/github/code_scanning.go | 16 ++---
pkg/github/issues.go | 58 ++++++++--------
pkg/github/pullrequests.go | 80 +++++++++++-----------
pkg/github/repositories.go | 62 ++++++++---------
pkg/github/search.go | 26 +++----
pkg/github/server.go | 132 ++++++++++++++++--------------------
pkg/github/server_test.go | 12 ++--
7 files changed, 186 insertions(+), 200 deletions(-)
diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go
index e7c8a4e22..380dc02cf 100644
--- a/pkg/github/code_scanning.go
+++ b/pkg/github/code_scanning.go
@@ -30,15 +30,15 @@ func getCodeScanningAlert(client *github.Client, t translations.TranslationHelpe
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- alertNumber, err := requiredNumberParam(request, "alert_number")
+ alertNumber, err := requiredInt(request, "alert_number")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -89,23 +89,23 @@ func listCodeScanningAlerts(client *github.Client, t translations.TranslationHel
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- ref, err := optionalStringParam(request, "ref")
+ ref, err := optionalParam[string](request, "ref")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- state, err := optionalStringParam(request, "state")
+ state, err := optionalParam[string](request, "state")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- severity, err := optionalStringParam(request, "severity")
+ severity, err := optionalParam[string](request, "severity")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
diff --git a/pkg/github/issues.go b/pkg/github/issues.go
index 521acfdaf..a62213ea6 100644
--- a/pkg/github/issues.go
+++ b/pkg/github/issues.go
@@ -32,15 +32,15 @@ func getIssue(client *github.Client, t translations.TranslationHelperFunc) (tool
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- issueNumber, err := requiredNumberParam(request, "issue_number")
+ issueNumber, err := requiredInt(request, "issue_number")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -90,19 +90,19 @@ func addIssueComment(client *github.Client, t translations.TranslationHelperFunc
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- issueNumber, err := requiredNumberParam(request, "issue_number")
+ issueNumber, err := requiredInt(request, "issue_number")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- body, err := requiredStringParam(request, "body")
+ body, err := requiredParam[string](request, "body")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -156,23 +156,23 @@ func searchIssues(client *github.Client, t translations.TranslationHelperFunc) (
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- query, err := requiredStringParam(request, "q")
+ query, err := requiredParam[string](request, "q")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- sort, err := optionalStringParam(request, "sort")
+ sort, err := optionalParam[string](request, "sort")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- order, err := optionalStringParam(request, "order")
+ order, err := optionalParam[string](request, "order")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- perPage, err := optionalNumberParamWithDefault(request, "per_page", 30)
+ perPage, err := optionalIntParamWithDefault(request, "per_page", 30)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- page, err := optionalNumberParamWithDefault(request, "page", 1)
+ page, err := optionalIntParamWithDefault(request, "page", 1)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -236,21 +236,21 @@ func createIssue(client *github.Client, t translations.TranslationHelperFunc) (t
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- title, err := requiredStringParam(request, "title")
+ title, err := requiredParam[string](request, "title")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
// Optional parameters
- body, err := optionalStringParam(request, "body")
+ body, err := optionalParam[string](request, "body")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -332,11 +332,11 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -344,7 +344,7 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to
opts := &github.IssueListByRepoOptions{}
// Set optional parameters if provided
- opts.State, err = optionalStringParam(request, "state")
+ opts.State, err = optionalParam[string](request, "state")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -354,17 +354,17 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to
return mcp.NewToolResultError(err.Error()), nil
}
- opts.Sort, err = optionalStringParam(request, "sort")
+ opts.Sort, err = optionalParam[string](request, "sort")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- opts.Direction, err = optionalStringParam(request, "direction")
+ opts.Direction, err = optionalParam[string](request, "direction")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- since, err := optionalStringParam(request, "since")
+ since, err := optionalParam[string](request, "since")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -443,15 +443,15 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- issueNumber, err := requiredNumberParam(request, "issue_number")
+ issueNumber, err := requiredInt(request, "issue_number")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -460,7 +460,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t
issueRequest := &github.IssueRequest{}
// Set optional parameters if provided
- title, err := optionalStringParam(request, "title")
+ title, err := optionalParam[string](request, "title")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -468,7 +468,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t
issueRequest.Title = github.Ptr(title)
}
- body, err := optionalStringParam(request, "body")
+ body, err := optionalParam[string](request, "body")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -476,7 +476,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t
issueRequest.Body = github.Ptr(body)
}
- state, err := optionalStringParam(request, "state")
+ state, err := optionalParam[string](request, "state")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -500,7 +500,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t
issueRequest.Assignees = &assignees
}
- milestone, err := optionalNumberParam(request, "milestone")
+ milestone, err := optionalIntParam(request, "milestone")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go
index d5caab97c..dc8b64819 100644
--- a/pkg/github/pullrequests.go
+++ b/pkg/github/pullrequests.go
@@ -31,15 +31,15 @@ func getPullRequest(client *github.Client, t translations.TranslationHelperFunc)
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- pullNumber, err := requiredNumberParam(request, "pull_number")
+ pullNumber, err := requiredInt(request, "pull_number")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -102,39 +102,39 @@ func listPullRequests(client *github.Client, t translations.TranslationHelperFun
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- state, err := optionalStringParam(request, "state")
+ state, err := optionalParam[string](request, "state")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- head, err := optionalStringParam(request, "head")
+ head, err := optionalParam[string](request, "head")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- base, err := optionalStringParam(request, "base")
+ base, err := optionalParam[string](request, "base")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- sort, err := optionalStringParam(request, "sort")
+ sort, err := optionalParam[string](request, "sort")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- direction, err := optionalStringParam(request, "direction")
+ direction, err := optionalParam[string](request, "direction")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- perPage, err := optionalNumberParamWithDefault(request, "per_page", 30)
+ perPage, err := optionalIntParamWithDefault(request, "per_page", 30)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- page, err := optionalNumberParamWithDefault(request, "page", 1)
+ page, err := optionalIntParamWithDefault(request, "page", 1)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -201,27 +201,27 @@ func mergePullRequest(client *github.Client, t translations.TranslationHelperFun
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- pullNumber, err := requiredNumberParam(request, "pull_number")
+ pullNumber, err := requiredInt(request, "pull_number")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- commitTitle, err := optionalStringParam(request, "commit_title")
+ commitTitle, err := optionalParam[string](request, "commit_title")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- commitMessage, err := optionalStringParam(request, "commit_message")
+ commitMessage, err := optionalParam[string](request, "commit_message")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- mergeMethod, err := optionalStringParam(request, "merge_method")
+ mergeMethod, err := optionalParam[string](request, "merge_method")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -272,15 +272,15 @@ func getPullRequestFiles(client *github.Client, t translations.TranslationHelper
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- pullNumber, err := requiredNumberParam(request, "pull_number")
+ pullNumber, err := requiredInt(request, "pull_number")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -327,15 +327,15 @@ func getPullRequestStatus(client *github.Client, t translations.TranslationHelpe
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- pullNumber, err := requiredNumberParam(request, "pull_number")
+ pullNumber, err := requiredInt(request, "pull_number")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -399,19 +399,19 @@ func updatePullRequestBranch(client *github.Client, t translations.TranslationHe
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- pullNumber, err := requiredNumberParam(request, "pull_number")
+ pullNumber, err := requiredInt(request, "pull_number")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- expectedHeadSHA, err := optionalStringParam(request, "expected_head_sha")
+ expectedHeadSHA, err := optionalParam[string](request, "expected_head_sha")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -466,15 +466,15 @@ func getPullRequestComments(client *github.Client, t translations.TranslationHel
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- pullNumber, err := requiredNumberParam(request, "pull_number")
+ pullNumber, err := requiredInt(request, "pull_number")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -526,15 +526,15 @@ func getPullRequestReviews(client *github.Client, t translations.TranslationHelp
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- pullNumber, err := requiredNumberParam(request, "pull_number")
+ pullNumber, err := requiredInt(request, "pull_number")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -593,19 +593,19 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- pullNumber, err := requiredNumberParam(request, "pull_number")
+ pullNumber, err := requiredInt(request, "pull_number")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- event, err := requiredStringParam(request, "event")
+ event, err := requiredParam[string](request, "event")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -616,7 +616,7 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe
}
// Add body if provided
- body, err := optionalStringParam(request, "body")
+ body, err := optionalParam[string](request, "body")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -625,7 +625,7 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe
}
// Add commit ID if provided
- commitID, err := optionalStringParam(request, "commit_id")
+ commitID, err := optionalParam[string](request, "commit_id")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go
index f222b1f80..f507b8973 100644
--- a/pkg/github/repositories.go
+++ b/pkg/github/repositories.go
@@ -37,23 +37,23 @@ func listCommits(client *github.Client, t translations.TranslationHelperFunc) (t
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- sha, err := optionalStringParam(request, "sha")
+ sha, err := optionalParam[string](request, "sha")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- page, err := optionalNumberParamWithDefault(request, "page", 1)
+ page, err := optionalIntParamWithDefault(request, "page", 1)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- perPage, err := optionalNumberParamWithDefault(request, "per_page", 30)
+ perPage, err := optionalIntParamWithDefault(request, "per_page", 30)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -122,27 +122,27 @@ func createOrUpdateFile(client *github.Client, t translations.TranslationHelperF
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- path, err := requiredStringParam(request, "path")
+ path, err := requiredParam[string](request, "path")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- content, err := requiredStringParam(request, "content")
+ content, err := requiredParam[string](request, "content")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- message, err := requiredStringParam(request, "message")
+ message, err := requiredParam[string](request, "message")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- branch, err := requiredStringParam(request, "branch")
+ branch, err := requiredParam[string](request, "branch")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -158,7 +158,7 @@ func createOrUpdateFile(client *github.Client, t translations.TranslationHelperF
}
// If SHA is provided, set it (for updates)
- sha, err := optionalStringParam(request, "sha")
+ sha, err := optionalParam[string](request, "sha")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -209,19 +209,19 @@ func createRepository(client *github.Client, t translations.TranslationHelperFun
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- name, err := requiredStringParam(request, "name")
+ name, err := requiredParam[string](request, "name")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- description, err := optionalStringParam(request, "description")
+ description, err := optionalParam[string](request, "description")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- private, err := optionalBooleanParam(request, "private")
+ private, err := optionalParam[bool](request, "private")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- autoInit, err := optionalBooleanParam(request, "auto_init")
+ autoInit, err := optionalParam[bool](request, "auto_init")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -277,19 +277,19 @@ func getFileContents(client *github.Client, t translations.TranslationHelperFunc
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- path, err := requiredStringParam(request, "path")
+ path, err := requiredParam[string](request, "path")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- branch, err := optionalStringParam(request, "branch")
+ branch, err := optionalParam[string](request, "branch")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -342,15 +342,15 @@ func forkRepository(client *github.Client, t translations.TranslationHelperFunc)
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- org, err := optionalStringParam(request, "organization")
+ org, err := optionalParam[string](request, "organization")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -409,19 +409,19 @@ func createBranch(client *github.Client, t translations.TranslationHelperFunc) (
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- branch, err := requiredStringParam(request, "branch")
+ branch, err := requiredParam[string](request, "branch")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- fromBranch, err := optionalStringParam(request, "from_branch")
+ fromBranch, err := optionalParam[string](request, "from_branch")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -494,19 +494,19 @@ func pushFiles(client *github.Client, t translations.TranslationHelperFunc) (too
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- owner, err := requiredStringParam(request, "owner")
+ owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- repo, err := requiredStringParam(request, "repo")
+ repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- branch, err := requiredStringParam(request, "branch")
+ branch, err := requiredParam[string](request, "branch")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- message, err := requiredStringParam(request, "message")
+ message, err := requiredParam[string](request, "message")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
diff --git a/pkg/github/search.go b/pkg/github/search.go
index d7ea49049..904dc7372 100644
--- a/pkg/github/search.go
+++ b/pkg/github/search.go
@@ -28,15 +28,15 @@ func searchRepositories(client *github.Client, t translations.TranslationHelperF
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- query, err := requiredStringParam(request, "query")
+ query, err := requiredParam[string](request, "query")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- page, err := optionalNumberParamWithDefault(request, "page", 1)
+ page, err := optionalIntParamWithDefault(request, "page", 1)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- perPage, err := optionalNumberParamWithDefault(request, "per_page", 30)
+ perPage, err := optionalIntParamWithDefault(request, "per_page", 30)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -93,23 +93,23 @@ func searchCode(client *github.Client, t translations.TranslationHelperFunc) (to
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- query, err := requiredStringParam(request, "q")
+ query, err := requiredParam[string](request, "q")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- sort, err := optionalStringParam(request, "sort")
+ sort, err := optionalParam[string](request, "sort")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- order, err := optionalStringParam(request, "order")
+ order, err := optionalParam[string](request, "order")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- perPage, err := optionalNumberParamWithDefault(request, "per_page", 30)
+ perPage, err := optionalIntParamWithDefault(request, "per_page", 30)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- page, err := optionalNumberParamWithDefault(request, "page", 1)
+ page, err := optionalIntParamWithDefault(request, "page", 1)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
@@ -168,23 +168,23 @@ func searchUsers(client *github.Client, t translations.TranslationHelperFunc) (t
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- query, err := requiredStringParam(request, "q")
+ query, err := requiredParam[string](request, "q")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- sort, err := optionalStringParam(request, "sort")
+ sort, err := optionalParam[string](request, "sort")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- order, err := optionalStringParam(request, "order")
+ order, err := optionalParam[string](request, "order")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- perPage, err := optionalNumberParamWithDefault(request, "per_page", 30)
+ perPage, err := optionalIntParamWithDefault(request, "per_page", 30)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- page, err := optionalNumberParamWithDefault(request, "page", 1)
+ page, err := optionalIntParamWithDefault(request, "page", 1)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
diff --git a/pkg/github/server.go b/pkg/github/server.go
index 42e230833..829994f19 100644
--- a/pkg/github/server.go
+++ b/pkg/github/server.go
@@ -139,76 +139,81 @@ func parseCommaSeparatedList(input string) []string {
return result
}
-// requiredStringParam checks if the parameter is present in the request and is of type string.
-func requiredStringParam(r mcp.CallToolRequest, p string) (string, error) {
+// requiredParam is a helper function that can be used to fetch a requested parameter from the request.
+// It does the following checks:
+// 1. Checks if the parameter is present in the request.
+// 2. Checks if the parameter is of the expected type.
+// 3. Checks if the parameter is not empty, i.e: non-zero value
+func requiredParam[T comparable](r mcp.CallToolRequest, p string) (T, error) {
+ var zero T
+
// Check if the parameter is present in the request
if _, ok := r.Params.Arguments[p]; !ok {
- return "", fmt.Errorf("missing required parameter: %s", p)
+ return zero, fmt.Errorf("missing required parameter: %s", p)
}
// Check if the parameter is of the expected type
- if _, ok := r.Params.Arguments[p].(string); !ok {
- return "", fmt.Errorf("parameter %s is not of type string", p)
+ if _, ok := r.Params.Arguments[p].(T); !ok {
+ return zero, fmt.Errorf("parameter %s is not of type %T", p, zero)
}
- // Check if the parameter is not the zero value
- v := r.Params.Arguments[p].(string)
- if v == "" {
- return v, fmt.Errorf("missing required parameter: %s", p)
+ if r.Params.Arguments[p].(T) == zero {
+ return zero, fmt.Errorf("missing required parameter: %s", p)
+
}
- return v, nil
+ return r.Params.Arguments[p].(T), nil
}
-// requiredNumberParam checks if the parameter is present in the request and is of type number.
-func requiredNumberParam(r mcp.CallToolRequest, p string) (int, error) {
- // Check if the parameter is present in the request
- if _, ok := r.Params.Arguments[p]; !ok {
- return 0, fmt.Errorf("missing required parameter: %s", p)
- }
-
- // Check if the parameter is of the expected type
- if _, ok := r.Params.Arguments[p].(float64); !ok {
- return 0, fmt.Errorf("parameter %s is not of type number", p)
+// requiredInt is a helper function that can be used to fetch a requested parameter from the request.
+// It does the following checks:
+// 1. Checks if the parameter is present in the request.
+// 2. Checks if the parameter is of the expected type.
+// 3. Checks if the parameter is not empty, i.e: non-zero value
+func requiredInt(r mcp.CallToolRequest, p string) (int, error) {
+ v, err := requiredParam[float64](r, p)
+ if err != nil {
+ return 0, err
}
-
- return int(r.Params.Arguments[p].(float64)), nil
+ return int(v), nil
}
-// optionalStringParam checks if an optional parameter is present in the request and is of type string.
-func optionalStringParam(r mcp.CallToolRequest, p string) (value string, err error) {
+// optionalParam is a helper function that can be used to fetch a requested parameter from the request.
+// It does the following checks:
+// 1. Checks if the parameter is present in the request, if not, it returns its zero-value
+// 2. If it is present, it checks if the parameter is of the expected type and returns it
+func optionalParam[T any](r mcp.CallToolRequest, p string) (T, error) {
+ var zero T
+
// Check if the parameter is present in the request
if _, ok := r.Params.Arguments[p]; !ok {
- return "", nil
+ return zero, nil
}
// Check if the parameter is of the expected type
- if _, ok := r.Params.Arguments[p].(string); !ok {
- return "", fmt.Errorf("parameter %s is not of type string", p)
+ if _, ok := r.Params.Arguments[p].(T); !ok {
+ return zero, fmt.Errorf("parameter %s is not of type %T", p, zero)
}
- return r.Params.Arguments[p].(string), nil
+ return r.Params.Arguments[p].(T), nil
}
-// optionalNumberParam checks if an optional parameter is present in the request and is of type number.
-func optionalNumberParam(r mcp.CallToolRequest, p string) (int, error) {
- // Check if the parameter is present in the request
- if _, ok := r.Params.Arguments[p]; !ok {
- return 0, nil
- }
-
- // Check if the parameter is of the expected type
- if _, ok := r.Params.Arguments[p].(float64); !ok {
- return 0, fmt.Errorf("parameter %s is not of type number", p)
+// optionalIntParam is a helper function that can be used to fetch a requested parameter from the request.
+// It does the following checks:
+// 1. Checks if the parameter is present in the request, if not, it returns its zero-value
+// 2. If it is present, it checks if the parameter is of the expected type and returns it
+func optionalIntParam(r mcp.CallToolRequest, p string) (int, error) {
+ v, err := optionalParam[float64](r, p)
+ if err != nil {
+ return 0, err
}
-
- return int(r.Params.Arguments[p].(float64)), nil
+ return int(v), nil
}
-// optionalNumberParamWithDefault checks if an optional parameter is present in the request and is of type number.
-// If the parameter is not present or is zero, it returns the default value.
-func optionalNumberParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) {
- v, err := optionalNumberParam(r, p)
+// optionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
+// similar to optionalIntParam, but it also takes a default value.
+func optionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) {
+ v, err := optionalIntParam(r, p)
if err != nil {
return 0, err
}
@@ -218,38 +223,19 @@ func optionalNumberParamWithDefault(r mcp.CallToolRequest, p string, d int) (int
return v, nil
}
-// optionalCommaSeparatedListParam checks if an optional parameter is present in the request and is of type string.
-// If the parameter is presents, it uses parseCommaSeparatedList to parse the string into a list of strings.
-// If the parameter is not present or is empty, it returns an empty list.
+// optionalCommaSeparatedListParam is a helper function that can be used to fetch a requested parameter from the request.
+// It does the following:
+// 1. Checks if the parameter is present in the request, if not, it returns an empty list
+// 2. If it is present, it checks if the parameter is of the expected type and uses parseCommaSeparatedList to parse it
+// and return the list of strings
func optionalCommaSeparatedListParam(r mcp.CallToolRequest, p string) ([]string, error) {
- // Check if the parameter is present in the request
- if _, ok := r.Params.Arguments[p]; !ok {
- return []string{}, nil //default to empty list, not nil
- }
-
- // Check if the parameter is of the expected type
- if _, ok := r.Params.Arguments[p].(string); !ok {
- return nil, fmt.Errorf("parameter %s is not of type string", p)
+ v, err := optionalParam[string](r, p)
+ if err != nil {
+ return []string{}, err
}
-
- l := parseCommaSeparatedList(r.Params.Arguments[p].(string))
+ l := parseCommaSeparatedList(v)
if len(l) == 0 {
- return []string{}, nil // default to empty list, not nil
+ return []string{}, nil
}
return l, nil
}
-
-// optionalBooleanParam checks if an optional parameter is present in the request and is of type boolean.
-func optionalBooleanParam(r mcp.CallToolRequest, p string) (bool, error) {
- // Check if the parameter is present in the request
- if _, ok := r.Params.Arguments[p]; !ok {
- return false, nil
- }
-
- // Check if the parameter is of the expected type
- if _, ok := r.Params.Arguments[p].(bool); !ok {
- return false, fmt.Errorf("parameter %s is not of type bool", p)
- }
-
- return r.Params.Arguments[p].(bool), nil
-}
diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go
index a081d31de..5e7ac9d4d 100644
--- a/pkg/github/server_test.go
+++ b/pkg/github/server_test.go
@@ -270,7 +270,7 @@ func Test_RequiredStringParam(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
request := createMCPRequest(tc.params)
- result, err := requiredStringParam(request, tc.paramName)
+ result, err := requiredParam[string](request, tc.paramName)
if tc.expectError {
assert.Error(t, err)
@@ -323,7 +323,7 @@ func Test_OptionalStringParam(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
request := createMCPRequest(tc.params)
- result, err := optionalStringParam(request, tc.paramName)
+ result, err := optionalParam[string](request, tc.paramName)
if tc.expectError {
assert.Error(t, err)
@@ -369,7 +369,7 @@ func Test_RequiredNumberParam(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
request := createMCPRequest(tc.params)
- result, err := requiredNumberParam(request, tc.paramName)
+ result, err := requiredInt(request, tc.paramName)
if tc.expectError {
assert.Error(t, err)
@@ -422,7 +422,7 @@ func Test_OptionalNumberParam(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
request := createMCPRequest(tc.params)
- result, err := optionalNumberParam(request, tc.paramName)
+ result, err := optionalIntParam(request, tc.paramName)
if tc.expectError {
assert.Error(t, err)
@@ -480,7 +480,7 @@ func Test_OptionalNumberParamWithDefault(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
request := createMCPRequest(tc.params)
- result, err := optionalNumberParamWithDefault(request, tc.paramName, tc.defaultVal)
+ result, err := optionalIntParamWithDefault(request, tc.paramName, tc.defaultVal)
if tc.expectError {
assert.Error(t, err)
@@ -586,7 +586,7 @@ func Test_OptionalBooleanParam(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
request := createMCPRequest(tc.params)
- result, err := optionalBooleanParam(request, tc.paramName)
+ result, err := optionalParam[bool](request, tc.paramName)
if tc.expectError {
assert.Error(t, err)
--- a PPN by Garber Painting Akron. With Image Size Reduction included!Fetched URL: http://github.com/github/github-mcp-server/pull/35.patch
Alternative Proxies:
Alternative Proxy
pFad Proxy
pFad v3 Proxy
pFad v4 Proxy