Skip to content

Commit 786c342

Browse files
committed
feat: add UpdateCodeScanningAlert and ListOrgCodeScanningAlerts tools to code_security
1 parent c17ebfe commit 786c342

File tree

3 files changed

+392
-0
lines changed

3 files changed

+392
-0
lines changed

pkg/github/code_scanning.go

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,171 @@ func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHel
158158
return mcp.NewToolResultText(string(r)), nil
159159
}
160160
}
161+
162+
func ListOrgCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
163+
return mcp.NewTool("list_org_code_scanning_alerts",
164+
mcp.WithDescription(t("TOOL_LIST_ORG_CODE_SCANNING_ALERTS_DESCRIPTION", "List code scanning alerts for a GitHub organization.")),
165+
mcp.WithToolAnnotation(mcp.ToolAnnotation{
166+
Title: t("TOOL_LIST_ORG_CODE_SCANNING_ALERTS_USER_TITLE", "List org code scanning alerts"),
167+
ReadOnlyHint: toBoolPtr(true),
168+
}),
169+
mcp.WithString("org",
170+
mcp.Required(),
171+
mcp.Description("The organization of the repository."),
172+
),
173+
mcp.WithString("sort",
174+
mcp.Description("Sort by"),
175+
mcp.Enum("created", "updated"),
176+
),
177+
mcp.WithString("severity",
178+
mcp.Description("Filter code scanning alerts by severity"),
179+
mcp.Enum("critical", "high", "medium", "low", "warning", "note", "error"),
180+
),
181+
mcp.WithString("tool_name",
182+
mcp.Description("The name of the tool used for code scanning."),
183+
),
184+
mcp.WithString("state",
185+
mcp.Description("Filter code scanning alerts by state. Defaults to open"),
186+
mcp.DefaultString("open"),
187+
mcp.Enum("open", "closed", "dismissed", "fixed"),
188+
),
189+
WithPagination(),
190+
),
191+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
192+
org, err := requiredParam[string](request, "org")
193+
if err != nil {
194+
return mcp.NewToolResultError(err.Error()), nil
195+
}
196+
sort, err := OptionalParam[string](request, "sort")
197+
if err != nil {
198+
return mcp.NewToolResultError(err.Error()), nil
199+
}
200+
severity, err := OptionalParam[string](request, "severity")
201+
if err != nil {
202+
return mcp.NewToolResultError(err.Error()), nil
203+
}
204+
toolName, err := OptionalParam[string](request, "tool_name")
205+
if err != nil {
206+
return mcp.NewToolResultError(err.Error()), nil
207+
}
208+
state, err := OptionalParam[string](request, "state")
209+
if err != nil {
210+
return mcp.NewToolResultError(err.Error()), nil
211+
}
212+
213+
client, err := getClient(ctx)
214+
if err != nil {
215+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
216+
}
217+
alerts, resp, err := client.CodeScanning.ListAlertsForOrg(ctx, org, &github.AlertListOptions{Sort: sort, State: state, Severity: severity, ToolName: toolName})
218+
if err != nil {
219+
return nil, fmt.Errorf("failed to list organization alerts: %w", err)
220+
}
221+
defer func() { _ = resp.Body.Close() }()
222+
223+
if resp.StatusCode != http.StatusOK {
224+
body, err := io.ReadAll(resp.Body)
225+
if err != nil {
226+
return nil, fmt.Errorf("failed to read response body: %w", err)
227+
}
228+
return mcp.NewToolResultError(fmt.Sprintf("failed to list organization alerts: %s", string(body))), nil
229+
}
230+
231+
r, err := json.Marshal(alerts)
232+
if err != nil {
233+
return nil, fmt.Errorf("failed to marshal alerts: %w", err)
234+
}
235+
236+
return mcp.NewToolResultText(string(r)), nil
237+
}
238+
}
239+
240+
func UpdateCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
241+
return mcp.NewTool("update_code_scanning_alert",
242+
mcp.WithDescription(t("TOOL_UPDATE_CODE_SCANNING_ALERT_DESCRIPTION", "Update details of a specific code scanning alert in a GitHub repository.")),
243+
mcp.WithToolAnnotation(mcp.ToolAnnotation{
244+
Title: t("TOOL_UPDATE_CODE_SCANNING_ALERT_USER_TITLE", "Update code scanning alert"),
245+
ReadOnlyHint: toBoolPtr(false),
246+
}),
247+
mcp.WithString("owner",
248+
mcp.Required(),
249+
mcp.Description("The owner of the repository."),
250+
),
251+
mcp.WithString("repo",
252+
mcp.Required(),
253+
mcp.Description("The name of the repository."),
254+
),
255+
mcp.WithNumber("alertNumber",
256+
mcp.Required(),
257+
mcp.Description("The number of the alert."),
258+
),
259+
mcp.WithString("state",
260+
mcp.Required(),
261+
mcp.Description("State of the alert"),
262+
mcp.Enum("open", "dismissed"),
263+
),
264+
mcp.WithString("dismissed_reason",
265+
mcp.Description("Reason for dismissing or closing the alert"),
266+
mcp.Enum("false positive", "won't fix", "used in tests"),
267+
),
268+
mcp.WithString("dismissed_comment",
269+
mcp.Description("Dismissal comment associated with the dismissal of the alert"),
270+
),
271+
),
272+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
273+
owner, err := requiredParam[string](request, "owner")
274+
if err != nil {
275+
return mcp.NewToolResultError(err.Error()), nil
276+
}
277+
repo, err := requiredParam[string](request, "repo")
278+
if err != nil {
279+
return mcp.NewToolResultError(err.Error()), nil
280+
}
281+
alertNumber, err := RequiredInt(request, "alertNumber")
282+
if err != nil {
283+
return mcp.NewToolResultError(err.Error()), nil
284+
}
285+
state, err := requiredParam[string](request, "state")
286+
if err != nil {
287+
return mcp.NewToolResultError(err.Error()), nil
288+
}
289+
dismissed_reason, err := OptionalParam[string](request, "dismissed_reason")
290+
if err != nil {
291+
return mcp.NewToolResultError(err.Error()), nil
292+
}
293+
dismissed_comment, err := OptionalParam[string](request, "dismissed_comment")
294+
if err != nil {
295+
return mcp.NewToolResultError(err.Error()), nil
296+
}
297+
298+
if state == "dismissed" && dismissed_reason == "" {
299+
return nil, fmt.Errorf("dismissed_reason required for 'dismissed' state ")
300+
}
301+
302+
client, err := getClient(ctx)
303+
if err != nil {
304+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
305+
}
306+
307+
alert, resp, err := client.CodeScanning.UpdateAlert(ctx, owner, repo, int64(alertNumber), &github.CodeScanningAlertState{State: state, DismissedReason: &dismissed_reason, DismissedComment: &dismissed_comment})
308+
if err != nil {
309+
return nil, fmt.Errorf("failed to update alert: %w", err)
310+
}
311+
defer func() { _ = resp.Body.Close() }()
312+
313+
if resp.StatusCode != http.StatusOK {
314+
body, err := io.ReadAll(resp.Body)
315+
if err != nil {
316+
return nil, fmt.Errorf("failed to read response body: %w", err)
317+
}
318+
return mcp.NewToolResultError(fmt.Sprintf("failed to update alert: %s", string(body))), nil
319+
}
320+
321+
r, err := json.Marshal(alert)
322+
if err != nil {
323+
return nil, fmt.Errorf("failed to marshal alert: %w", err)
324+
}
325+
326+
return mcp.NewToolResultText(string(r)), nil
327+
}
328+
}

pkg/github/code_scanning_test.go

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,222 @@ func Test_ListCodeScanningAlerts(t *testing.T) {
238238
})
239239
}
240240
}
241+
242+
func Test_UpdateCodeScanningAlert(t *testing.T) {
243+
// Verify tool definition
244+
mockClient := github.NewClient(nil)
245+
tool, _ := UpdateCodeScanningAlert(stubGetClientFn(mockClient), translations.NullTranslationHelper)
246+
247+
assert.Equal(t, "update_code_scanning_alert", tool.Name)
248+
assert.NotEmpty(t, tool.Description)
249+
assert.Contains(t, tool.InputSchema.Properties, "owner")
250+
assert.Contains(t, tool.InputSchema.Properties, "repo")
251+
assert.Contains(t, tool.InputSchema.Properties, "alertNumber")
252+
assert.Contains(t, tool.InputSchema.Properties, "state")
253+
assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "alertNumber", "state"})
254+
255+
// Mock alert for success
256+
mockAlert := &github.Alert{
257+
Number: github.Ptr(42),
258+
State: github.Ptr("open"),
259+
Rule: &github.Rule{ID: github.Ptr("rule-id"), Description: github.Ptr("desc")},
260+
HTMLURL: github.Ptr("https://github.com/owner/repo/security/code-scanning/42"),
261+
}
262+
263+
tests := []struct {
264+
name string
265+
mockedClient *http.Client
266+
requestArgs map[string]interface{}
267+
expectError bool
268+
expectedAlert *github.Alert
269+
expectedErrMsg string
270+
}{
271+
{
272+
name: "successful alert update",
273+
mockedClient: mock.NewMockedHTTPClient(
274+
mock.WithRequestMatch(
275+
mock.PatchReposCodeScanningAlertsByOwnerByRepoByAlertNumber,
276+
mockAlert,
277+
),
278+
),
279+
requestArgs: map[string]interface{}{
280+
"owner": "owner",
281+
"repo": "repo",
282+
"alertNumber": float64(42),
283+
"state": "open",
284+
},
285+
expectError: false,
286+
expectedAlert: mockAlert,
287+
},
288+
{
289+
name: "update fails",
290+
mockedClient: mock.NewMockedHTTPClient(
291+
mock.WithRequestMatchHandler(
292+
mock.PatchReposCodeScanningAlertsByOwnerByRepoByAlertNumber,
293+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
294+
w.WriteHeader(http.StatusBadRequest)
295+
_, _ = w.Write([]byte(`{"message": "Invalid request"}`))
296+
}),
297+
),
298+
),
299+
requestArgs: map[string]interface{}{
300+
"owner": "owner",
301+
"repo": "repo",
302+
"alertNumber": float64(9999),
303+
"state": "open",
304+
},
305+
expectError: true,
306+
expectedErrMsg: "failed to update alert",
307+
},
308+
{
309+
name: "error when dismissed_reason not provided",
310+
mockedClient: nil, // early exit happens before any HTTP call
311+
requestArgs: map[string]interface{}{
312+
"owner": "owner",
313+
"repo": "repo",
314+
"alertNumber": float64(42),
315+
"state": "dismissed",
316+
"dismissed_reason": "",
317+
},
318+
expectError: true,
319+
expectedErrMsg: "dismissed_reason required for 'dismissed' state",
320+
},
321+
}
322+
323+
for _, tc := range tests {
324+
t.Run(tc.name, func(t *testing.T) {
325+
client := github.NewClient(tc.mockedClient)
326+
_, handler := UpdateCodeScanningAlert(stubGetClientFn(client), translations.NullTranslationHelper)
327+
request := createMCPRequest(tc.requestArgs)
328+
329+
result, err := handler(context.Background(), request)
330+
if tc.expectError {
331+
require.Error(t, err)
332+
assert.Contains(t, err.Error(), tc.expectedErrMsg)
333+
return
334+
}
335+
336+
require.NoError(t, err)
337+
text := getTextResult(t, result)
338+
var got github.Alert
339+
require.NoError(t, json.Unmarshal([]byte(text.Text), &got))
340+
341+
assert.Equal(t, *tc.expectedAlert.Number, *got.Number)
342+
assert.Equal(t, *tc.expectedAlert.State, *got.State)
343+
assert.Equal(t, *tc.expectedAlert.Rule.ID, *got.Rule.ID)
344+
assert.Equal(t, *tc.expectedAlert.HTMLURL, *got.HTMLURL)
345+
})
346+
}
347+
}
348+
349+
func Test_ListOrgCodeScanningAlerts(t *testing.T) {
350+
// Verify tool definition
351+
mockClient := github.NewClient(nil)
352+
tool, _ := ListOrgCodeScanningAlerts(stubGetClientFn(mockClient), translations.NullTranslationHelper)
353+
354+
assert.Equal(t, "list_org_code_scanning_alerts", tool.Name)
355+
assert.NotEmpty(t, tool.Description)
356+
assert.Contains(t, tool.InputSchema.Properties, "org")
357+
assert.Contains(t, tool.InputSchema.Properties, "sort")
358+
assert.Contains(t, tool.InputSchema.Properties, "severity")
359+
assert.Contains(t, tool.InputSchema.Properties, "tool_name")
360+
assert.Contains(t, tool.InputSchema.Properties, "state")
361+
assert.ElementsMatch(t, tool.InputSchema.Required, []string{"org"})
362+
363+
// Mock alerts for success
364+
mockAlerts := []*github.Alert{
365+
{
366+
Number: github.Ptr(100),
367+
State: github.Ptr("open"),
368+
Rule: &github.Rule{ID: github.Ptr("org-rule-1"), Description: github.Ptr("desc1")},
369+
HTMLURL: github.Ptr("https://github.com/org/repo/security/code-scanning/100"),
370+
},
371+
{
372+
Number: github.Ptr(101),
373+
State: github.Ptr("dismissed"),
374+
Rule: &github.Rule{ID: github.Ptr("org-rule-2"), Description: github.Ptr("desc2")},
375+
HTMLURL: github.Ptr("https://github.com/org/repo/security/code-scanning/101"),
376+
},
377+
}
378+
379+
tests := []struct {
380+
name string
381+
mockedClient *http.Client
382+
requestArgs map[string]interface{}
383+
expectError bool
384+
expectedAlerts []*github.Alert
385+
expectedErrMsg string
386+
}{
387+
{
388+
name: "successful org alerts listing",
389+
mockedClient: mock.NewMockedHTTPClient(
390+
mock.WithRequestMatchHandler(
391+
mock.GetOrgsCodeScanningAlertsByOrg,
392+
expectQueryParams(t, map[string]string{
393+
"state": "open",
394+
"severity": "high",
395+
"tool_name": "codeql",
396+
"sort": "created",
397+
}).andThen(
398+
mockResponse(t, http.StatusOK, mockAlerts),
399+
),
400+
),
401+
),
402+
requestArgs: map[string]interface{}{
403+
"org": "org",
404+
"state": "open",
405+
"severity": "high",
406+
"tool_name": "codeql",
407+
"sort": "created",
408+
},
409+
expectError: false,
410+
expectedAlerts: mockAlerts,
411+
},
412+
{
413+
name: "org alerts listing fails",
414+
mockedClient: mock.NewMockedHTTPClient(
415+
mock.WithRequestMatchHandler(
416+
mock.GetOrgsCodeScanningAlertsByOrg,
417+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
418+
w.WriteHeader(http.StatusForbidden)
419+
_, _ = w.Write([]byte(`{"message":"Forbidden"}`))
420+
}),
421+
),
422+
),
423+
requestArgs: map[string]interface{}{
424+
"org": "org",
425+
},
426+
expectError: true,
427+
expectedErrMsg: "failed to list organization alerts",
428+
},
429+
}
430+
431+
for _, tc := range tests {
432+
t.Run(tc.name, func(t *testing.T) {
433+
client := github.NewClient(tc.mockedClient)
434+
_, handler := ListOrgCodeScanningAlerts(stubGetClientFn(client), translations.NullTranslationHelper)
435+
request := createMCPRequest(tc.requestArgs)
436+
437+
result, err := handler(context.Background(), request)
438+
if tc.expectError {
439+
require.Error(t, err)
440+
assert.Contains(t, err.Error(), tc.expectedErrMsg)
441+
return
442+
}
443+
444+
require.NoError(t, err)
445+
text := getTextResult(t, result)
446+
447+
var got []*github.Alert
448+
require.NoError(t, json.Unmarshal([]byte(text.Text), &got))
449+
assert.Len(t, got, len(tc.expectedAlerts))
450+
451+
for i := range got {
452+
assert.Equal(t, *tc.expectedAlerts[i].Number, *got[i].Number)
453+
assert.Equal(t, *tc.expectedAlerts[i].State, *got[i].State)
454+
assert.Equal(t, *tc.expectedAlerts[i].Rule.ID, *got[i].Rule.ID)
455+
assert.Equal(t, *tc.expectedAlerts[i].HTMLURL, *got[i].HTMLURL)
456+
}
457+
})
458+
}
459+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy