Skip to content

Commit d5e1f48

Browse files
Add updating draft state to update_pull_request tool (#774)
* initial impl of pull request draft state update * appease linter * update README * add nosec * fixed err return type for json marshalling * add gql test
1 parent efef8ae commit d5e1f48

File tree

5 files changed

+348
-29
lines changed

5 files changed

+348
-29
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,7 @@ The following sets of tools are available (all are on by default):
736736
- **update_pull_request** - Edit pull request
737737
- `base`: New base branch name (string, optional)
738738
- `body`: New description (string, optional)
739+
- `draft`: Mark pull request as draft (true) or ready for review (false) (boolean, optional)
739740
- `maintainer_can_modify`: Allow maintainer edits (boolean, optional)
740741
- `owner`: Repository owner (string, required)
741742
- `pullNumber`: Pull request number to update (number, required)

pkg/github/__toolsnaps__/update_pull_request.snap

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
"description": "New description",
1515
"type": "string"
1616
},
17+
"draft": {
18+
"description": "Mark pull request as draft (true) or ready for review (false)",
19+
"type": "boolean"
20+
},
1721
"maintainer_can_modify": {
1822
"description": "Allow maintainer edits",
1923
"type": "boolean"

pkg/github/pullrequests.go

Lines changed: 120 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
203203
}
204204

205205
// UpdatePullRequest creates a tool to update an existing pull request.
206-
func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
206+
func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
207207
return mcp.NewTool("update_pull_request",
208208
mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository.")),
209209
mcp.WithToolAnnotation(mcp.ToolAnnotation{
@@ -232,6 +232,9 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
232232
mcp.Description("New state"),
233233
mcp.Enum("open", "closed"),
234234
),
235+
mcp.WithBoolean("draft",
236+
mcp.Description("Mark pull request as draft (true) or ready for review (false)"),
237+
),
235238
mcp.WithString("base",
236239
mcp.Description("New base branch name"),
237240
),
@@ -253,74 +256,165 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
253256
return mcp.NewToolResultError(err.Error()), nil
254257
}
255258

256-
// Build the update struct only with provided fields
259+
draftProvided := request.GetArguments()["draft"] != nil
260+
var draftValue bool
261+
if draftProvided {
262+
draftValue, err = OptionalParam[bool](request, "draft")
263+
if err != nil {
264+
return nil, err
265+
}
266+
}
267+
257268
update := &github.PullRequest{}
258-
updateNeeded := false
269+
restUpdateNeeded := false
259270

260271
if title, ok, err := OptionalParamOK[string](request, "title"); err != nil {
261272
return mcp.NewToolResultError(err.Error()), nil
262273
} else if ok {
263274
update.Title = github.Ptr(title)
264-
updateNeeded = true
275+
restUpdateNeeded = true
265276
}
266277

267278
if body, ok, err := OptionalParamOK[string](request, "body"); err != nil {
268279
return mcp.NewToolResultError(err.Error()), nil
269280
} else if ok {
270281
update.Body = github.Ptr(body)
271-
updateNeeded = true
282+
restUpdateNeeded = true
272283
}
273284

274285
if state, ok, err := OptionalParamOK[string](request, "state"); err != nil {
275286
return mcp.NewToolResultError(err.Error()), nil
276287
} else if ok {
277288
update.State = github.Ptr(state)
278-
updateNeeded = true
289+
restUpdateNeeded = true
279290
}
280291

281292
if base, ok, err := OptionalParamOK[string](request, "base"); err != nil {
282293
return mcp.NewToolResultError(err.Error()), nil
283294
} else if ok {
284295
update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)}
285-
updateNeeded = true
296+
restUpdateNeeded = true
286297
}
287298

288299
if maintainerCanModify, ok, err := OptionalParamOK[bool](request, "maintainer_can_modify"); err != nil {
289300
return mcp.NewToolResultError(err.Error()), nil
290301
} else if ok {
291302
update.MaintainerCanModify = github.Ptr(maintainerCanModify)
292-
updateNeeded = true
303+
restUpdateNeeded = true
293304
}
294305

295-
if !updateNeeded {
306+
if !restUpdateNeeded && !draftProvided {
296307
return mcp.NewToolResultError("No update parameters provided."), nil
297308
}
298309

310+
if restUpdateNeeded {
311+
client, err := getClient(ctx)
312+
if err != nil {
313+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
314+
}
315+
316+
_, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update)
317+
if err != nil {
318+
return ghErrors.NewGitHubAPIErrorResponse(ctx,
319+
"failed to update pull request",
320+
resp,
321+
err,
322+
), nil
323+
}
324+
defer func() { _ = resp.Body.Close() }()
325+
326+
if resp.StatusCode != http.StatusOK {
327+
body, err := io.ReadAll(resp.Body)
328+
if err != nil {
329+
return nil, fmt.Errorf("failed to read response body: %w", err)
330+
}
331+
return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil
332+
}
333+
}
334+
335+
if draftProvided {
336+
gqlClient, err := getGQLClient(ctx)
337+
if err != nil {
338+
return nil, fmt.Errorf("failed to get GitHub GraphQL client: %w", err)
339+
}
340+
341+
var prQuery struct {
342+
Repository struct {
343+
PullRequest struct {
344+
ID githubv4.ID
345+
IsDraft githubv4.Boolean
346+
} `graphql:"pullRequest(number: $prNum)"`
347+
} `graphql:"repository(owner: $owner, name: $repo)"`
348+
}
349+
350+
err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{
351+
"owner": githubv4.String(owner),
352+
"repo": githubv4.String(repo),
353+
"prNum": githubv4.Int(pullNumber), // #nosec G115 - pull request numbers are always small positive integers
354+
})
355+
if err != nil {
356+
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil
357+
}
358+
359+
currentIsDraft := bool(prQuery.Repository.PullRequest.IsDraft)
360+
361+
if currentIsDraft != draftValue {
362+
if draftValue {
363+
// Convert to draft
364+
var mutation struct {
365+
ConvertPullRequestToDraft struct {
366+
PullRequest struct {
367+
ID githubv4.ID
368+
IsDraft githubv4.Boolean
369+
}
370+
} `graphql:"convertPullRequestToDraft(input: $input)"`
371+
}
372+
373+
err = gqlClient.Mutate(ctx, &mutation, githubv4.ConvertPullRequestToDraftInput{
374+
PullRequestID: prQuery.Repository.PullRequest.ID,
375+
}, nil)
376+
if err != nil {
377+
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to convert pull request to draft", err), nil
378+
}
379+
} else {
380+
// Mark as ready for review
381+
var mutation struct {
382+
MarkPullRequestReadyForReview struct {
383+
PullRequest struct {
384+
ID githubv4.ID
385+
IsDraft githubv4.Boolean
386+
}
387+
} `graphql:"markPullRequestReadyForReview(input: $input)"`
388+
}
389+
390+
err = gqlClient.Mutate(ctx, &mutation, githubv4.MarkPullRequestReadyForReviewInput{
391+
PullRequestID: prQuery.Repository.PullRequest.ID,
392+
}, nil)
393+
if err != nil {
394+
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to mark pull request ready for review", err), nil
395+
}
396+
}
397+
}
398+
}
399+
299400
client, err := getClient(ctx)
300401
if err != nil {
301-
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
402+
return nil, err
302403
}
303-
pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update)
404+
405+
finalPR, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber)
304406
if err != nil {
305-
return ghErrors.NewGitHubAPIErrorResponse(ctx,
306-
"failed to update pull request",
307-
resp,
308-
err,
309-
), nil
407+
return ghErrors.NewGitHubAPIErrorResponse(ctx, "Failed to get pull request", resp, err), nil
310408
}
311-
defer func() { _ = resp.Body.Close() }()
312-
313-
if resp.StatusCode != http.StatusOK {
314-
body, err := io.ReadAll(resp.Body)
315-
if err != nil {
316-
return nil, fmt.Errorf("failed to read response body: %w", err)
409+
defer func() {
410+
if resp != nil && resp.Body != nil {
411+
_ = resp.Body.Close()
317412
}
318-
return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil
319-
}
413+
}()
320414

321-
r, err := json.Marshal(pr)
415+
r, err := json.Marshal(finalPR)
322416
if err != nil {
323-
return nil, fmt.Errorf("failed to marshal response: %w", err)
417+
return mcp.NewToolResultError(fmt.Sprintf("Failed to marshal response: %v", err)), nil
324418
}
325419

326420
return mcp.NewToolResultText(string(r)), nil

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy