From 786c342c104d4f5ce8f9d8401e8012996b14bf37 Mon Sep 17 00:00:00 2001 From: Juan Broullon Date: Mon, 9 Jun 2025 01:24:57 +0200 Subject: [PATCH] feat: add UpdateCodeScanningAlert and ListOrgCodeScanningAlerts tools to code_security --- pkg/github/code_scanning.go | 168 ++++++++++++++++++++++++ pkg/github/code_scanning_test.go | 219 +++++++++++++++++++++++++++++++ pkg/github/tools.go | 5 + 3 files changed, 392 insertions(+) diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 1886b6342..d4d11d22a 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -158,3 +158,171 @@ func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHel return mcp.NewToolResultText(string(r)), nil } } + +func ListOrgCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("list_org_code_scanning_alerts", + mcp.WithDescription(t("TOOL_LIST_ORG_CODE_SCANNING_ALERTS_DESCRIPTION", "List code scanning alerts for a GitHub organization.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_LIST_ORG_CODE_SCANNING_ALERTS_USER_TITLE", "List org code scanning alerts"), + ReadOnlyHint: toBoolPtr(true), + }), + mcp.WithString("org", + mcp.Required(), + mcp.Description("The organization of the repository."), + ), + mcp.WithString("sort", + mcp.Description("Sort by"), + mcp.Enum("created", "updated"), + ), + mcp.WithString("severity", + mcp.Description("Filter code scanning alerts by severity"), + mcp.Enum("critical", "high", "medium", "low", "warning", "note", "error"), + ), + mcp.WithString("tool_name", + mcp.Description("The name of the tool used for code scanning."), + ), + mcp.WithString("state", + mcp.Description("Filter code scanning alerts by state. Defaults to open"), + mcp.DefaultString("open"), + mcp.Enum("open", "closed", "dismissed", "fixed"), + ), + WithPagination(), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + org, err := requiredParam[string](request, "org") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + sort, err := OptionalParam[string](request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + severity, err := OptionalParam[string](request, "severity") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + toolName, err := OptionalParam[string](request, "tool_name") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + state, err := OptionalParam[string](request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + alerts, resp, err := client.CodeScanning.ListAlertsForOrg(ctx, org, &github.AlertListOptions{Sort: sort, State: state, Severity: severity, ToolName: toolName}) + if err != nil { + return nil, fmt.Errorf("failed to list organization alerts: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to list organization alerts: %s", string(body))), nil + } + + r, err := json.Marshal(alerts) + if err != nil { + return nil, fmt.Errorf("failed to marshal alerts: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +func UpdateCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("update_code_scanning_alert", + mcp.WithDescription(t("TOOL_UPDATE_CODE_SCANNING_ALERT_DESCRIPTION", "Update details of a specific code scanning alert in a GitHub repository.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_UPDATE_CODE_SCANNING_ALERT_USER_TITLE", "Update code scanning alert"), + ReadOnlyHint: toBoolPtr(false), + }), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("The owner of the repository."), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("The name of the repository."), + ), + mcp.WithNumber("alertNumber", + mcp.Required(), + mcp.Description("The number of the alert."), + ), + mcp.WithString("state", + mcp.Required(), + mcp.Description("State of the alert"), + mcp.Enum("open", "dismissed"), + ), + mcp.WithString("dismissed_reason", + mcp.Description("Reason for dismissing or closing the alert"), + mcp.Enum("false positive", "won't fix", "used in tests"), + ), + mcp.WithString("dismissed_comment", + mcp.Description("Dismissal comment associated with the dismissal of the alert"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + alertNumber, err := RequiredInt(request, "alertNumber") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + state, err := requiredParam[string](request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + dismissed_reason, err := OptionalParam[string](request, "dismissed_reason") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + dismissed_comment, err := OptionalParam[string](request, "dismissed_comment") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + if state == "dismissed" && dismissed_reason == "" { + return nil, fmt.Errorf("dismissed_reason required for 'dismissed' state ") + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + alert, resp, err := client.CodeScanning.UpdateAlert(ctx, owner, repo, int64(alertNumber), &github.CodeScanningAlertState{State: state, DismissedReason: &dismissed_reason, DismissedComment: &dismissed_comment}) + if err != nil { + return nil, fmt.Errorf("failed to update alert: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to update alert: %s", string(body))), nil + } + + r, err := json.Marshal(alert) + if err != nil { + return nil, fmt.Errorf("failed to marshal alert: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} diff --git a/pkg/github/code_scanning_test.go b/pkg/github/code_scanning_test.go index b5facbf6b..92bb4ce6a 100644 --- a/pkg/github/code_scanning_test.go +++ b/pkg/github/code_scanning_test.go @@ -238,3 +238,222 @@ func Test_ListCodeScanningAlerts(t *testing.T) { }) } } + +func Test_UpdateCodeScanningAlert(t *testing.T) { + // Verify tool definition + mockClient := github.NewClient(nil) + tool, _ := UpdateCodeScanningAlert(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "update_code_scanning_alert", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "alertNumber") + assert.Contains(t, tool.InputSchema.Properties, "state") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "alertNumber", "state"}) + + // Mock alert for success + mockAlert := &github.Alert{ + Number: github.Ptr(42), + State: github.Ptr("open"), + Rule: &github.Rule{ID: github.Ptr("rule-id"), Description: github.Ptr("desc")}, + HTMLURL: github.Ptr("https://github.com/owner/repo/security/code-scanning/42"), + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedAlert *github.Alert + expectedErrMsg string + }{ + { + name: "successful alert update", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PatchReposCodeScanningAlertsByOwnerByRepoByAlertNumber, + mockAlert, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "alertNumber": float64(42), + "state": "open", + }, + expectError: false, + expectedAlert: mockAlert, + }, + { + name: "update fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchReposCodeScanningAlertsByOwnerByRepoByAlertNumber, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message": "Invalid request"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "alertNumber": float64(9999), + "state": "open", + }, + expectError: true, + expectedErrMsg: "failed to update alert", + }, + { + name: "error when dismissed_reason not provided", + mockedClient: nil, // early exit happens before any HTTP call + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "alertNumber": float64(42), + "state": "dismissed", + "dismissed_reason": "", + }, + expectError: true, + expectedErrMsg: "dismissed_reason required for 'dismissed' state", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + _, handler := UpdateCodeScanningAlert(stubGetClientFn(client), translations.NullTranslationHelper) + request := createMCPRequest(tc.requestArgs) + + result, err := handler(context.Background(), request) + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + text := getTextResult(t, result) + var got github.Alert + require.NoError(t, json.Unmarshal([]byte(text.Text), &got)) + + assert.Equal(t, *tc.expectedAlert.Number, *got.Number) + assert.Equal(t, *tc.expectedAlert.State, *got.State) + assert.Equal(t, *tc.expectedAlert.Rule.ID, *got.Rule.ID) + assert.Equal(t, *tc.expectedAlert.HTMLURL, *got.HTMLURL) + }) + } +} + +func Test_ListOrgCodeScanningAlerts(t *testing.T) { + // Verify tool definition + mockClient := github.NewClient(nil) + tool, _ := ListOrgCodeScanningAlerts(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "list_org_code_scanning_alerts", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "org") + assert.Contains(t, tool.InputSchema.Properties, "sort") + assert.Contains(t, tool.InputSchema.Properties, "severity") + assert.Contains(t, tool.InputSchema.Properties, "tool_name") + assert.Contains(t, tool.InputSchema.Properties, "state") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"org"}) + + // Mock alerts for success + mockAlerts := []*github.Alert{ + { + Number: github.Ptr(100), + State: github.Ptr("open"), + Rule: &github.Rule{ID: github.Ptr("org-rule-1"), Description: github.Ptr("desc1")}, + HTMLURL: github.Ptr("https://github.com/org/repo/security/code-scanning/100"), + }, + { + Number: github.Ptr(101), + State: github.Ptr("dismissed"), + Rule: &github.Rule{ID: github.Ptr("org-rule-2"), Description: github.Ptr("desc2")}, + HTMLURL: github.Ptr("https://github.com/org/repo/security/code-scanning/101"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedAlerts []*github.Alert + expectedErrMsg string + }{ + { + name: "successful org alerts listing", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetOrgsCodeScanningAlertsByOrg, + expectQueryParams(t, map[string]string{ + "state": "open", + "severity": "high", + "tool_name": "codeql", + "sort": "created", + }).andThen( + mockResponse(t, http.StatusOK, mockAlerts), + ), + ), + ), + requestArgs: map[string]interface{}{ + "org": "org", + "state": "open", + "severity": "high", + "tool_name": "codeql", + "sort": "created", + }, + expectError: false, + expectedAlerts: mockAlerts, + }, + { + name: "org alerts listing fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetOrgsCodeScanningAlertsByOrg, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"message":"Forbidden"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "org": "org", + }, + expectError: true, + expectedErrMsg: "failed to list organization alerts", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + _, handler := ListOrgCodeScanningAlerts(stubGetClientFn(client), translations.NullTranslationHelper) + request := createMCPRequest(tc.requestArgs) + + result, err := handler(context.Background(), request) + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + text := getTextResult(t, result) + + var got []*github.Alert + require.NoError(t, json.Unmarshal([]byte(text.Text), &got)) + assert.Len(t, got, len(tc.expectedAlerts)) + + for i := range got { + assert.Equal(t, *tc.expectedAlerts[i].Number, *got[i].Number) + assert.Equal(t, *tc.expectedAlerts[i].State, *got[i].State) + assert.Equal(t, *tc.expectedAlerts[i].Rule.ID, *got[i].Rule.ID) + assert.Equal(t, *tc.expectedAlerts[i].HTMLURL, *got[i].HTMLURL) + } + }) + } +} diff --git a/pkg/github/tools.go b/pkg/github/tools.go index f8e05fc85..687053eb5 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -84,7 +84,12 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG AddReadTools( toolsets.NewServerTool(GetCodeScanningAlert(getClient, t)), toolsets.NewServerTool(ListCodeScanningAlerts(getClient, t)), + toolsets.NewServerTool(ListOrgCodeScanningAlerts(getClient, t)), + ). + AddWriteTools( + toolsets.NewServerTool(UpdateCodeScanningAlert(getClient, t)), ) + secretProtection := toolsets.NewToolset("secret_protection", "Secret protection related tools, such as GitHub Secret Scanning"). AddReadTools( toolsets.NewServerTool(GetSecretScanningAlert(getClient, t)), pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy