Skip to content

Commit 65a4558

Browse files
committed
feat: add update_pull_request tool
1 parent 270bbf7 commit 65a4558

File tree

3 files changed

+316
-0
lines changed

3 files changed

+316
-0
lines changed

pkg/github/pullrequests.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,119 @@ func getPullRequest(client *github.Client, t translations.TranslationHelperFunc)
6767
}
6868
}
6969

70+
// updatePullRequest creates a tool to update an existing pull request.
71+
func updatePullRequest(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
72+
return mcp.NewTool("update_pull_request",
73+
mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository")),
74+
mcp.WithString("owner",
75+
mcp.Required(),
76+
mcp.Description("Repository owner"),
77+
),
78+
mcp.WithString("repo",
79+
mcp.Required(),
80+
mcp.Description("Repository name"),
81+
),
82+
mcp.WithNumber("pullNumber",
83+
mcp.Required(),
84+
mcp.Description("Pull request number to update"),
85+
),
86+
mcp.WithString("title",
87+
mcp.Description("New title"),
88+
),
89+
mcp.WithString("body",
90+
mcp.Description("New description"),
91+
),
92+
mcp.WithString("state",
93+
mcp.Description("New state ('open' or 'closed')"),
94+
mcp.Enum("open", "closed"),
95+
),
96+
mcp.WithString("base",
97+
mcp.Description("New base branch name"),
98+
),
99+
mcp.WithBoolean("maintainer_can_modify",
100+
mcp.Description("Allow maintainer edits"),
101+
),
102+
),
103+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
104+
owner, err := requiredParam[string](request, "owner")
105+
if err != nil {
106+
return mcp.NewToolResultError(err.Error()), nil
107+
}
108+
repo, err := requiredParam[string](request, "repo")
109+
if err != nil {
110+
return mcp.NewToolResultError(err.Error()), nil
111+
}
112+
pullNumber, err := requiredInt(request, "pullNumber")
113+
if err != nil {
114+
return mcp.NewToolResultError(err.Error()), nil
115+
}
116+
117+
// Build the update struct only with provided fields
118+
update := &github.PullRequest{}
119+
updateNeeded := false
120+
121+
if title, ok, err := optionalParamOk[string](request, "title"); err != nil {
122+
return mcp.NewToolResultError(err.Error()), nil
123+
} else if ok {
124+
update.Title = github.Ptr(title)
125+
updateNeeded = true
126+
}
127+
128+
if body, ok, err := optionalParamOk[string](request, "body"); err != nil {
129+
return mcp.NewToolResultError(err.Error()), nil
130+
} else if ok {
131+
update.Body = github.Ptr(body)
132+
updateNeeded = true
133+
}
134+
135+
if state, ok, err := optionalParamOk[string](request, "state"); err != nil {
136+
return mcp.NewToolResultError(err.Error()), nil
137+
} else if ok {
138+
update.State = github.Ptr(state)
139+
updateNeeded = true
140+
}
141+
142+
if base, ok, err := optionalParamOk[string](request, "base"); err != nil {
143+
return mcp.NewToolResultError(err.Error()), nil
144+
} else if ok {
145+
update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)}
146+
updateNeeded = true
147+
}
148+
149+
if maintainerCanModify, ok, err := optionalParamOk[bool](request, "maintainer_can_modify"); err != nil {
150+
return mcp.NewToolResultError(err.Error()), nil
151+
} else if ok {
152+
update.MaintainerCanModify = github.Ptr(maintainerCanModify)
153+
updateNeeded = true
154+
}
155+
156+
if !updateNeeded {
157+
return mcp.NewToolResultError("No update parameters provided."), nil
158+
}
159+
160+
pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update)
161+
if err != nil {
162+
return nil, fmt.Errorf("failed to update pull request: %w", err)
163+
}
164+
defer func() { _ = resp.Body.Close() }()
165+
166+
if resp.StatusCode != http.StatusOK {
167+
body, err := io.ReadAll(resp.Body)
168+
if err != nil {
169+
return nil, fmt.Errorf("failed to read response body: %w", err)
170+
}
171+
return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil
172+
}
173+
174+
r, err := json.Marshal(pr)
175+
if err != nil {
176+
return nil, fmt.Errorf("failed to marshal response: %w", err)
177+
}
178+
179+
return mcp.NewToolResultText(string(r)), nil
180+
}
181+
}
182+
70183
// listPullRequests creates a tool to list and filter repository pull requests.
71184
func listPullRequests(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
72185
return mcp.NewTool("list_pull_requests",

pkg/github/pullrequests_test.go

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,188 @@ func Test_GetPullRequest(t *testing.T) {
126126
}
127127
}
128128

129+
func Test_UpdatePullRequest(t *testing.T) {
130+
// Verify tool definition once
131+
mockClient := github.NewClient(nil)
132+
tool, _ := updatePullRequest(mockClient, translations.NullTranslationHelper)
133+
134+
assert.Equal(t, "update_pull_request", tool.Name)
135+
assert.NotEmpty(t, tool.Description)
136+
assert.Contains(t, tool.InputSchema.Properties, "owner")
137+
assert.Contains(t, tool.InputSchema.Properties, "repo")
138+
assert.Contains(t, tool.InputSchema.Properties, "pullNumber")
139+
assert.Contains(t, tool.InputSchema.Properties, "title")
140+
assert.Contains(t, tool.InputSchema.Properties, "body")
141+
assert.Contains(t, tool.InputSchema.Properties, "state")
142+
assert.Contains(t, tool.InputSchema.Properties, "base")
143+
assert.Contains(t, tool.InputSchema.Properties, "maintainer_can_modify")
144+
assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"})
145+
146+
// Setup mock PR for success case
147+
mockUpdatedPR := &github.PullRequest{
148+
Number: github.Ptr(42),
149+
Title: github.Ptr("Updated Test PR Title"),
150+
State: github.Ptr("open"),
151+
HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"),
152+
Body: github.Ptr("Updated test PR body."),
153+
MaintainerCanModify: github.Ptr(false),
154+
Base: &github.PullRequestBranch{
155+
Ref: github.Ptr("develop"),
156+
},
157+
}
158+
159+
mockClosedPR := &github.PullRequest{
160+
Number: github.Ptr(42),
161+
Title: github.Ptr("Test PR"),
162+
State: github.Ptr("closed"), // State updated
163+
}
164+
165+
tests := []struct {
166+
name string
167+
mockedClient *http.Client
168+
requestArgs map[string]interface{}
169+
expectError bool
170+
expectedPR *github.PullRequest
171+
expectedErrMsg string
172+
}{
173+
{
174+
name: "successful PR update (title, body, base, maintainer_can_modify)",
175+
mockedClient: mock.NewMockedHTTPClient(
176+
mock.WithRequestMatchHandler(
177+
mock.PatchReposPullsByOwnerByRepoByPullNumber,
178+
// Expect the flat string based on previous test failure output and API docs
179+
expectRequestBody(t, map[string]interface{}{
180+
"title": "Updated Test PR Title",
181+
"body": "Updated test PR body.",
182+
"base": "develop",
183+
"maintainer_can_modify": false,
184+
}).andThen(
185+
mockResponse(t, http.StatusOK, mockUpdatedPR),
186+
),
187+
),
188+
),
189+
requestArgs: map[string]interface{}{
190+
"owner": "owner",
191+
"repo": "repo",
192+
"pullNumber": float64(42),
193+
"title": "Updated Test PR Title",
194+
"body": "Updated test PR body.",
195+
"base": "develop",
196+
"maintainer_can_modify": false,
197+
},
198+
expectError: false,
199+
expectedPR: mockUpdatedPR,
200+
},
201+
{
202+
name: "successful PR update (state)",
203+
mockedClient: mock.NewMockedHTTPClient(
204+
mock.WithRequestMatchHandler(
205+
mock.PatchReposPullsByOwnerByRepoByPullNumber,
206+
expectRequestBody(t, map[string]interface{}{
207+
"state": "closed",
208+
}).andThen(
209+
mockResponse(t, http.StatusOK, mockClosedPR),
210+
),
211+
),
212+
),
213+
requestArgs: map[string]interface{}{
214+
"owner": "owner",
215+
"repo": "repo",
216+
"pullNumber": float64(42),
217+
"state": "closed",
218+
},
219+
expectError: false,
220+
expectedPR: mockClosedPR,
221+
},
222+
{
223+
name: "no update parameters provided",
224+
mockedClient: mock.NewMockedHTTPClient(), // No API call expected
225+
requestArgs: map[string]interface{}{
226+
"owner": "owner",
227+
"repo": "repo",
228+
"pullNumber": float64(42),
229+
// No update fields
230+
},
231+
expectError: false, // Error is returned in the result, not as Go error
232+
expectedErrMsg: "No update parameters provided",
233+
},
234+
{
235+
name: "PR update fails (API error)",
236+
mockedClient: mock.NewMockedHTTPClient(
237+
mock.WithRequestMatchHandler(
238+
mock.PatchReposPullsByOwnerByRepoByPullNumber,
239+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
240+
w.WriteHeader(http.StatusUnprocessableEntity)
241+
_, _ = w.Write([]byte(`{"message": "Validation Failed"}`))
242+
}),
243+
),
244+
),
245+
requestArgs: map[string]interface{}{
246+
"owner": "owner",
247+
"repo": "repo",
248+
"pullNumber": float64(42),
249+
"title": "Invalid Title Causing Error",
250+
},
251+
expectError: true,
252+
expectedErrMsg: "failed to update pull request",
253+
},
254+
}
255+
256+
for _, tc := range tests {
257+
t.Run(tc.name, func(t *testing.T) {
258+
// Setup client with mock
259+
client := github.NewClient(tc.mockedClient)
260+
_, handler := updatePullRequest(client, translations.NullTranslationHelper)
261+
262+
// Create call request
263+
request := createMCPRequest(tc.requestArgs)
264+
265+
// Call handler
266+
result, err := handler(context.Background(), request)
267+
268+
// Verify results
269+
if tc.expectError {
270+
require.Error(t, err)
271+
assert.Contains(t, err.Error(), tc.expectedErrMsg)
272+
return
273+
}
274+
275+
require.NoError(t, err)
276+
277+
// Parse the result and get the text content
278+
textContent := getTextResult(t, result)
279+
280+
// Check for expected error message within the result text
281+
if tc.expectedErrMsg != "" {
282+
assert.Contains(t, textContent.Text, tc.expectedErrMsg)
283+
return
284+
}
285+
286+
// Unmarshal and verify the successful result
287+
var returnedPR github.PullRequest
288+
err = json.Unmarshal([]byte(textContent.Text), &returnedPR)
289+
require.NoError(t, err)
290+
assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number)
291+
if tc.expectedPR.Title != nil {
292+
assert.Equal(t, *tc.expectedPR.Title, *returnedPR.Title)
293+
}
294+
if tc.expectedPR.Body != nil {
295+
assert.Equal(t, *tc.expectedPR.Body, *returnedPR.Body)
296+
}
297+
if tc.expectedPR.State != nil {
298+
assert.Equal(t, *tc.expectedPR.State, *returnedPR.State)
299+
}
300+
if tc.expectedPR.Base != nil && tc.expectedPR.Base.Ref != nil {
301+
assert.NotNil(t, returnedPR.Base)
302+
assert.Equal(t, *tc.expectedPR.Base.Ref, *returnedPR.Base.Ref)
303+
}
304+
if tc.expectedPR.MaintainerCanModify != nil {
305+
assert.Equal(t, *tc.expectedPR.MaintainerCanModify, *returnedPR.MaintainerCanModify)
306+
}
307+
})
308+
}
309+
}
310+
129311
func Test_ListPullRequests(t *testing.T) {
130312
// Verify tool definition once
131313
mockClient := github.NewClient(nil)

pkg/github/server.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ func NewServer(client *github.Client, readOnly bool, t translations.TranslationH
5353
s.AddTool(updatePullRequestBranch(client, t))
5454
s.AddTool(createPullRequestReview(client, t))
5555
s.AddTool(createPullRequest(client, t))
56+
s.AddTool(updatePullRequest(client, t))
5657
}
5758

5859
// Add GitHub tools - Repositories
@@ -112,6 +113,26 @@ func getMe(client *github.Client, t translations.TranslationHelperFunc) (tool mc
112113
}
113114
}
114115

116+
// optionalParamOk is a helper function that can be used to fetch a requested parameter from the request.
117+
// It returns the value, a boolean indicating if the parameter was present, and an error if the type is wrong.
118+
func optionalParamOk[T any](r mcp.CallToolRequest, p string) (T, bool, error) {
119+
var zero T
120+
121+
// Check if the parameter is present in the request
122+
val, ok := r.Params.Arguments[p]
123+
if !ok {
124+
return zero, false, nil // Not present, return zero value, false, no error
125+
}
126+
127+
// Check if the parameter is of the expected type
128+
typedVal, ok := val.(T)
129+
if !ok {
130+
return zero, true, fmt.Errorf("parameter %s is not of type %T, is %T", p, zero, val) // Present but wrong type
131+
}
132+
133+
return typedVal, true, nil // Present and correct type
134+
}
135+
115136
// isAcceptedError checks if the error is an accepted error.
116137
func isAcceptedError(err error) bool {
117138
var acceptedError *github.AcceptedError

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy