diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 6b15c0c45..9e160759a 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -167,3 +167,171 @@ func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHel return mcp.NewToolResultText(string(r)), nil } } + +func ListOrgCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("list_org_code_scanning_alerts", + mcp.WithDescription(t("TOOL_LIST_ORG_CODE_SCANNING_ALERTS_DESCRIPTION", "List code scanning alerts for a GitHub organization.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_LIST_ORG_CODE_SCANNING_ALERTS_USER_TITLE", "List org code scanning alerts"), + ReadOnlyHint: toBoolPtr(true), + }), + mcp.WithString("org", + mcp.Required(), + mcp.Description("The organization of the repository."), + ), + mcp.WithString("sort", + mcp.Description("Sort by"), + mcp.Enum("created", "updated"), + ), + mcp.WithString("severity", + mcp.Description("Filter code scanning alerts by severity"), + mcp.Enum("critical", "high", "medium", "low", "warning", "note", "error"), + ), + mcp.WithString("tool_name", + mcp.Description("The name of the tool used for code scanning."), + ), + mcp.WithString("state", + mcp.Description("Filter code scanning alerts by state. Defaults to open"), + mcp.DefaultString("open"), + mcp.Enum("open", "closed", "dismissed", "fixed"), + ), + WithPagination(), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + org, err := requiredParam[string](request, "org") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + sort, err := OptionalParam[string](request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + severity, err := OptionalParam[string](request, "severity") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + toolName, err := OptionalParam[string](request, "tool_name") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + state, err := OptionalParam[string](request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + alerts, resp, err := client.CodeScanning.ListAlertsForOrg(ctx, org, &github.AlertListOptions{Sort: sort, State: state, Severity: severity, ToolName: toolName}) + if err != nil { + return nil, fmt.Errorf("failed to list organization alerts: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to list organization alerts: %s", string(body))), nil + } + + r, err := json.Marshal(alerts) + if err != nil { + return nil, fmt.Errorf("failed to marshal alerts: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +func UpdateCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("update_code_scanning_alert", + mcp.WithDescription(t("TOOL_UPDATE_CODE_SCANNING_ALERT_DESCRIPTION", "Update details of a specific code scanning alert in a GitHub repository.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_UPDATE_CODE_SCANNING_ALERT_USER_TITLE", "Update code scanning alert"), + ReadOnlyHint: toBoolPtr(false), + }), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("The owner of the repository."), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("The name of the repository."), + ), + mcp.WithNumber("alertNumber", + mcp.Required(), + mcp.Description("The number of the alert."), + ), + mcp.WithString("state", + mcp.Required(), + mcp.Description("State of the alert"), + mcp.Enum("open", "dismissed"), + ), + mcp.WithString("dismissed_reason", + mcp.Description("Reason for dismissing or closing the alert"), + mcp.Enum("false positive", "won't fix", "used in tests"), + ), + mcp.WithString("dismissed_comment", + mcp.Description("Dismissal comment associated with the dismissal of the alert"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + alertNumber, err := RequiredInt(request, "alertNumber") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + state, err := requiredParam[string](request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + dismissed_reason, err := OptionalParam[string](request, "dismissed_reason") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + dismissed_comment, err := OptionalParam[string](request, "dismissed_comment") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + if state == "dismissed" && dismissed_reason == "" { + return nil, fmt.Errorf("dismissed_reason required for 'dismissed' state ") + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + alert, resp, err := client.CodeScanning.UpdateAlert(ctx, owner, repo, int64(alertNumber), &github.CodeScanningAlertState{State: state, DismissedReason: &dismissed_reason, DismissedComment: &dismissed_comment}) + if err != nil { + return nil, fmt.Errorf("failed to update alert: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to update alert: %s", string(body))), nil + } + + r, err := json.Marshal(alert) + if err != nil { + return nil, fmt.Errorf("failed to marshal alert: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} diff --git a/pkg/github/code_scanning_test.go b/pkg/github/code_scanning_test.go index 66f6fd6cc..3d53619c8 100644 --- a/pkg/github/code_scanning_test.go +++ b/pkg/github/code_scanning_test.go @@ -247,3 +247,222 @@ func Test_ListCodeScanningAlerts(t *testing.T) { }) } } + +func Test_UpdateCodeScanningAlert(t *testing.T) { + // Verify tool definition + mockClient := github.NewClient(nil) + tool, _ := UpdateCodeScanningAlert(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "update_code_scanning_alert", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "alertNumber") + assert.Contains(t, tool.InputSchema.Properties, "state") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "alertNumber", "state"}) + + // Mock alert for success + mockAlert := &github.Alert{ + Number: github.Ptr(42), + State: github.Ptr("open"), + Rule: &github.Rule{ID: github.Ptr("rule-id"), Description: github.Ptr("desc")}, + HTMLURL: github.Ptr("https://github.com/owner/repo/security/code-scanning/42"), + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedAlert *github.Alert + expectedErrMsg string + }{ + { + name: "successful alert update", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PatchReposCodeScanningAlertsByOwnerByRepoByAlertNumber, + mockAlert, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "alertNumber": float64(42), + "state": "open", + }, + expectError: false, + expectedAlert: mockAlert, + }, + { + name: "update fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchReposCodeScanningAlertsByOwnerByRepoByAlertNumber, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message": "Invalid request"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "alertNumber": float64(9999), + "state": "open", + }, + expectError: true, + expectedErrMsg: "failed to update alert", + }, + { + name: "error when dismissed_reason not provided", + mockedClient: nil, // early exit happens before any HTTP call + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "alertNumber": float64(42), + "state": "dismissed", + "dismissed_reason": "", + }, + expectError: true, + expectedErrMsg: "dismissed_reason required for 'dismissed' state", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + _, handler := UpdateCodeScanningAlert(stubGetClientFn(client), translations.NullTranslationHelper) + request := createMCPRequest(tc.requestArgs) + + result, err := handler(context.Background(), request) + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + text := getTextResult(t, result) + var got github.Alert + require.NoError(t, json.Unmarshal([]byte(text.Text), &got)) + + assert.Equal(t, *tc.expectedAlert.Number, *got.Number) + assert.Equal(t, *tc.expectedAlert.State, *got.State) + assert.Equal(t, *tc.expectedAlert.Rule.ID, *got.Rule.ID) + assert.Equal(t, *tc.expectedAlert.HTMLURL, *got.HTMLURL) + }) + } +} + +func Test_ListOrgCodeScanningAlerts(t *testing.T) { + // Verify tool definition + mockClient := github.NewClient(nil) + tool, _ := ListOrgCodeScanningAlerts(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "list_org_code_scanning_alerts", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "org") + assert.Contains(t, tool.InputSchema.Properties, "sort") + assert.Contains(t, tool.InputSchema.Properties, "severity") + assert.Contains(t, tool.InputSchema.Properties, "tool_name") + assert.Contains(t, tool.InputSchema.Properties, "state") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"org"}) + + // Mock alerts for success + mockAlerts := []*github.Alert{ + { + Number: github.Ptr(100), + State: github.Ptr("open"), + Rule: &github.Rule{ID: github.Ptr("org-rule-1"), Description: github.Ptr("desc1")}, + HTMLURL: github.Ptr("https://github.com/org/repo/security/code-scanning/100"), + }, + { + Number: github.Ptr(101), + State: github.Ptr("dismissed"), + Rule: &github.Rule{ID: github.Ptr("org-rule-2"), Description: github.Ptr("desc2")}, + HTMLURL: github.Ptr("https://github.com/org/repo/security/code-scanning/101"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedAlerts []*github.Alert + expectedErrMsg string + }{ + { + name: "successful org alerts listing", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetOrgsCodeScanningAlertsByOrg, + expectQueryParams(t, map[string]string{ + "state": "open", + "severity": "high", + "tool_name": "codeql", + "sort": "created", + }).andThen( + mockResponse(t, http.StatusOK, mockAlerts), + ), + ), + ), + requestArgs: map[string]interface{}{ + "org": "org", + "state": "open", + "severity": "high", + "tool_name": "codeql", + "sort": "created", + }, + expectError: false, + expectedAlerts: mockAlerts, + }, + { + name: "org alerts listing fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetOrgsCodeScanningAlertsByOrg, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"message":"Forbidden"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "org": "org", + }, + expectError: true, + expectedErrMsg: "failed to list organization alerts", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + _, handler := ListOrgCodeScanningAlerts(stubGetClientFn(client), translations.NullTranslationHelper) + request := createMCPRequest(tc.requestArgs) + + result, err := handler(context.Background(), request) + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + text := getTextResult(t, result) + + var got []*github.Alert + require.NoError(t, json.Unmarshal([]byte(text.Text), &got)) + assert.Len(t, got, len(tc.expectedAlerts)) + + for i := range got { + assert.Equal(t, *tc.expectedAlerts[i].Number, *got[i].Number) + assert.Equal(t, *tc.expectedAlerts[i].State, *got[i].State) + assert.Equal(t, *tc.expectedAlerts[i].Rule.ID, *got[i].Rule.ID) + assert.Equal(t, *tc.expectedAlerts[i].HTMLURL, *got[i].HTMLURL) + } + }) + } +} diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 77a1ccd3b..52fea9dcc 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -97,7 +97,12 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG AddReadTools( toolsets.NewServerTool(GetCodeScanningAlert(getClient, t)), toolsets.NewServerTool(ListCodeScanningAlerts(getClient, t)), + toolsets.NewServerTool(ListOrgCodeScanningAlerts(getClient, t)), + ). + AddWriteTools( + toolsets.NewServerTool(UpdateCodeScanningAlert(getClient, t)), ) + secretProtection := toolsets.NewToolset("secret_protection", "Secret protection related tools, such as GitHub Secret Scanning"). AddReadTools( toolsets.NewServerTool(GetSecretScanningAlert(getClient, t)), pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy