@@ -203,7 +203,7 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
203
203
}
204
204
205
205
// UpdatePullRequest creates a tool to update an existing pull request.
206
- func UpdatePullRequest (getClient GetClientFn , t translations.TranslationHelperFunc ) (mcp.Tool , server.ToolHandlerFunc ) {
206
+ func UpdatePullRequest (getClient GetClientFn , getGQLClient GetGQLClientFn , t translations.TranslationHelperFunc ) (mcp.Tool , server.ToolHandlerFunc ) {
207
207
return mcp .NewTool ("update_pull_request" ,
208
208
mcp .WithDescription (t ("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION" , "Update an existing pull request in a GitHub repository." )),
209
209
mcp .WithToolAnnotation (mcp.ToolAnnotation {
@@ -232,6 +232,9 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
232
232
mcp .Description ("New state" ),
233
233
mcp .Enum ("open" , "closed" ),
234
234
),
235
+ mcp .WithBoolean ("draft" ,
236
+ mcp .Description ("Mark pull request as draft (true) or ready for review (false)" ),
237
+ ),
235
238
mcp .WithString ("base" ,
236
239
mcp .Description ("New base branch name" ),
237
240
),
@@ -253,74 +256,165 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
253
256
return mcp .NewToolResultError (err .Error ()), nil
254
257
}
255
258
256
- // Build the update struct only with provided fields
259
+ draftProvided := request .GetArguments ()["draft" ] != nil
260
+ var draftValue bool
261
+ if draftProvided {
262
+ draftValue , err = OptionalParam [bool ](request , "draft" )
263
+ if err != nil {
264
+ return nil , err
265
+ }
266
+ }
267
+
257
268
update := & github.PullRequest {}
258
- updateNeeded := false
269
+ restUpdateNeeded := false
259
270
260
271
if title , ok , err := OptionalParamOK [string ](request , "title" ); err != nil {
261
272
return mcp .NewToolResultError (err .Error ()), nil
262
273
} else if ok {
263
274
update .Title = github .Ptr (title )
264
- updateNeeded = true
275
+ restUpdateNeeded = true
265
276
}
266
277
267
278
if body , ok , err := OptionalParamOK [string ](request , "body" ); err != nil {
268
279
return mcp .NewToolResultError (err .Error ()), nil
269
280
} else if ok {
270
281
update .Body = github .Ptr (body )
271
- updateNeeded = true
282
+ restUpdateNeeded = true
272
283
}
273
284
274
285
if state , ok , err := OptionalParamOK [string ](request , "state" ); err != nil {
275
286
return mcp .NewToolResultError (err .Error ()), nil
276
287
} else if ok {
277
288
update .State = github .Ptr (state )
278
- updateNeeded = true
289
+ restUpdateNeeded = true
279
290
}
280
291
281
292
if base , ok , err := OptionalParamOK [string ](request , "base" ); err != nil {
282
293
return mcp .NewToolResultError (err .Error ()), nil
283
294
} else if ok {
284
295
update .Base = & github.PullRequestBranch {Ref : github .Ptr (base )}
285
- updateNeeded = true
296
+ restUpdateNeeded = true
286
297
}
287
298
288
299
if maintainerCanModify , ok , err := OptionalParamOK [bool ](request , "maintainer_can_modify" ); err != nil {
289
300
return mcp .NewToolResultError (err .Error ()), nil
290
301
} else if ok {
291
302
update .MaintainerCanModify = github .Ptr (maintainerCanModify )
292
- updateNeeded = true
303
+ restUpdateNeeded = true
293
304
}
294
305
295
- if ! updateNeeded {
306
+ if ! restUpdateNeeded && ! draftProvided {
296
307
return mcp .NewToolResultError ("No update parameters provided." ), nil
297
308
}
298
309
310
+ if restUpdateNeeded {
311
+ client , err := getClient (ctx )
312
+ if err != nil {
313
+ return nil , fmt .Errorf ("failed to get GitHub client: %w" , err )
314
+ }
315
+
316
+ _ , resp , err := client .PullRequests .Edit (ctx , owner , repo , pullNumber , update )
317
+ if err != nil {
318
+ return ghErrors .NewGitHubAPIErrorResponse (ctx ,
319
+ "failed to update pull request" ,
320
+ resp ,
321
+ err ,
322
+ ), nil
323
+ }
324
+ defer func () { _ = resp .Body .Close () }()
325
+
326
+ if resp .StatusCode != http .StatusOK {
327
+ body , err := io .ReadAll (resp .Body )
328
+ if err != nil {
329
+ return nil , fmt .Errorf ("failed to read response body: %w" , err )
330
+ }
331
+ return mcp .NewToolResultError (fmt .Sprintf ("failed to update pull request: %s" , string (body ))), nil
332
+ }
333
+ }
334
+
335
+ if draftProvided {
336
+ gqlClient , err := getGQLClient (ctx )
337
+ if err != nil {
338
+ return nil , fmt .Errorf ("failed to get GitHub GraphQL client: %w" , err )
339
+ }
340
+
341
+ var prQuery struct {
342
+ Repository struct {
343
+ PullRequest struct {
344
+ ID githubv4.ID
345
+ IsDraft githubv4.Boolean
346
+ } `graphql:"pullRequest(number: $prNum)"`
347
+ } `graphql:"repository(owner: $owner, name: $repo)"`
348
+ }
349
+
350
+ err = gqlClient .Query (ctx , & prQuery , map [string ]interface {}{
351
+ "owner" : githubv4 .String (owner ),
352
+ "repo" : githubv4 .String (repo ),
353
+ "prNum" : githubv4 .Int (pullNumber ), // #nosec G115 - pull request numbers are always small positive integers
354
+ })
355
+ if err != nil {
356
+ return ghErrors .NewGitHubGraphQLErrorResponse (ctx , "Failed to find pull request" , err ), nil
357
+ }
358
+
359
+ currentIsDraft := bool (prQuery .Repository .PullRequest .IsDraft )
360
+
361
+ if currentIsDraft != draftValue {
362
+ if draftValue {
363
+ // Convert to draft
364
+ var mutation struct {
365
+ ConvertPullRequestToDraft struct {
366
+ PullRequest struct {
367
+ ID githubv4.ID
368
+ IsDraft githubv4.Boolean
369
+ }
370
+ } `graphql:"convertPullRequestToDraft(input: $input)"`
371
+ }
372
+
373
+ err = gqlClient .Mutate (ctx , & mutation , githubv4.ConvertPullRequestToDraftInput {
374
+ PullRequestID : prQuery .Repository .PullRequest .ID ,
375
+ }, nil )
376
+ if err != nil {
377
+ return ghErrors .NewGitHubGraphQLErrorResponse (ctx , "Failed to convert pull request to draft" , err ), nil
378
+ }
379
+ } else {
380
+ // Mark as ready for review
381
+ var mutation struct {
382
+ MarkPullRequestReadyForReview struct {
383
+ PullRequest struct {
384
+ ID githubv4.ID
385
+ IsDraft githubv4.Boolean
386
+ }
387
+ } `graphql:"markPullRequestReadyForReview(input: $input)"`
388
+ }
389
+
390
+ err = gqlClient .Mutate (ctx , & mutation , githubv4.MarkPullRequestReadyForReviewInput {
391
+ PullRequestID : prQuery .Repository .PullRequest .ID ,
392
+ }, nil )
393
+ if err != nil {
394
+ return ghErrors .NewGitHubGraphQLErrorResponse (ctx , "Failed to mark pull request ready for review" , err ), nil
395
+ }
396
+ }
397
+ }
398
+ }
399
+
299
400
client , err := getClient (ctx )
300
401
if err != nil {
301
- return nil , fmt . Errorf ( "failed to get GitHub client: %w" , err )
402
+ return nil , err
302
403
}
303
- pr , resp , err := client .PullRequests .Edit (ctx , owner , repo , pullNumber , update )
404
+
405
+ finalPR , resp , err := client .PullRequests .Get (ctx , owner , repo , pullNumber )
304
406
if err != nil {
305
- return ghErrors .NewGitHubAPIErrorResponse (ctx ,
306
- "failed to update pull request" ,
307
- resp ,
308
- err ,
309
- ), nil
407
+ return ghErrors .NewGitHubAPIErrorResponse (ctx , "Failed to get pull request" , resp , err ), nil
310
408
}
311
- defer func () { _ = resp .Body .Close () }()
312
-
313
- if resp .StatusCode != http .StatusOK {
314
- body , err := io .ReadAll (resp .Body )
315
- if err != nil {
316
- return nil , fmt .Errorf ("failed to read response body: %w" , err )
409
+ defer func () {
410
+ if resp != nil && resp .Body != nil {
411
+ _ = resp .Body .Close ()
317
412
}
318
- return mcp .NewToolResultError (fmt .Sprintf ("failed to update pull request: %s" , string (body ))), nil
319
- }
413
+ }()
320
414
321
- r , err := json .Marshal (pr )
415
+ r , err := json .Marshal (finalPR )
322
416
if err != nil {
323
- return nil , fmt .Errorf ( "failed to marshal response: %w " , err )
417
+ return mcp . NewToolResultError ( fmt .Sprintf ( "Failed to marshal response: %v " , err )), nil
324
418
}
325
419
326
420
return mcp .NewToolResultText (string (r )), nil
0 commit comments