Skip to content

Commit d1a344a

Browse files
committed
add support for create_pull_request
1 parent 51ccba0 commit d1a344a

File tree

4 files changed

+273
-0
lines changed

4 files changed

+273
-0
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,17 @@ The flag `--gh-host` and the environment variable `GH_HOST` can be used to set t
255255
- `commit_id`: SHA of commit to review (string, optional)
256256
- `comments`: Line-specific comments array of objects, each object with path (string), position (number), and body (string) (array, optional)
257257

258+
- **create_pull_request** - Create a new pull request
259+
260+
- `owner`: Repository owner (string, required)
261+
- `repo`: Repository name (string, required)
262+
- `title`: PR title (string, required)
263+
- `body`: PR description (string, optional)
264+
- `head`: Branch containing changes (string, required)
265+
- `base`: Branch to merge into (string, required)
266+
- `draft`: Create as draft PR (boolean, optional)
267+
- `maintainer_can_modify`: Allow maintainer edits (boolean, optional)
268+
258269
### Repositories
259270

260271
- **create_or_update_file** - Create or update a single file in a repository

pkg/github/pullrequests.go

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,3 +712,115 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe
712712
return mcp.NewToolResultText(string(r)), nil
713713
}
714714
}
715+
716+
// createPullRequest creates a tool to create a new pull request.
717+
func createPullRequest(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
718+
return mcp.NewTool("create_pull_request",
719+
mcp.WithDescription(t("TOOL_CREATE_PULL_REQUEST_DESCRIPTION", "Create a new pull request in a GitHub repository")),
720+
mcp.WithString("owner",
721+
mcp.Required(),
722+
mcp.Description("Repository owner"),
723+
),
724+
mcp.WithString("repo",
725+
mcp.Required(),
726+
mcp.Description("Repository name"),
727+
),
728+
mcp.WithString("title",
729+
mcp.Required(),
730+
mcp.Description("PR title"),
731+
),
732+
mcp.WithString("body",
733+
mcp.Description("PR description"),
734+
),
735+
mcp.WithString("head",
736+
mcp.Required(),
737+
mcp.Description("Branch containing changes"),
738+
),
739+
mcp.WithString("base",
740+
mcp.Required(),
741+
mcp.Description("Branch to merge into"),
742+
),
743+
mcp.WithBoolean("draft",
744+
mcp.Description("Create as draft PR"),
745+
),
746+
mcp.WithBoolean("maintainer_can_modify",
747+
mcp.Description("Allow maintainer edits"),
748+
),
749+
),
750+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
751+
owner, err := requiredParam[string](request, "owner")
752+
if err != nil {
753+
return mcp.NewToolResultError(err.Error()), nil
754+
}
755+
repo, err := requiredParam[string](request, "repo")
756+
if err != nil {
757+
return mcp.NewToolResultError(err.Error()), nil
758+
}
759+
title, err := requiredParam[string](request, "title")
760+
if err != nil {
761+
return mcp.NewToolResultError(err.Error()), nil
762+
}
763+
head, err := requiredParam[string](request, "head")
764+
if err != nil {
765+
return mcp.NewToolResultError(err.Error()), nil
766+
}
767+
base, err := requiredParam[string](request, "base")
768+
if err != nil {
769+
return mcp.NewToolResultError(err.Error()), nil
770+
}
771+
772+
body, err := optionalParam[string](request, "body")
773+
if err != nil {
774+
return mcp.NewToolResultError(err.Error()), nil
775+
}
776+
777+
draft, err := optionalParam[bool](request, "draft")
778+
if err != nil {
779+
return mcp.NewToolResultError(err.Error()), nil
780+
}
781+
782+
maintainerCanModify, err := optionalParam[bool](request, "maintainer_can_modify")
783+
if err != nil {
784+
return mcp.NewToolResultError(err.Error()), nil
785+
}
786+
787+
newPR := &github.NewPullRequest{
788+
Title: github.Ptr(title),
789+
Head: github.Ptr(head),
790+
Base: github.Ptr(base),
791+
}
792+
793+
if body != "" {
794+
newPR.Body = github.Ptr(body)
795+
}
796+
797+
if draft {
798+
newPR.Draft = github.Ptr(draft)
799+
}
800+
801+
if maintainerCanModify {
802+
newPR.MaintainerCanModify = github.Ptr(maintainerCanModify)
803+
}
804+
805+
pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR)
806+
if err != nil {
807+
return nil, fmt.Errorf("failed to create pull request: %w", err)
808+
}
809+
defer func() { _ = resp.Body.Close() }()
810+
811+
if resp.StatusCode != http.StatusCreated {
812+
body, err := io.ReadAll(resp.Body)
813+
if err != nil {
814+
return nil, fmt.Errorf("failed to read response body: %w", err)
815+
}
816+
return mcp.NewToolResultError(fmt.Sprintf("failed to create pull request: %s", string(body))), nil
817+
}
818+
819+
r, err := json.Marshal(pr)
820+
if err != nil {
821+
return nil, fmt.Errorf("failed to marshal response: %w", err)
822+
}
823+
824+
return mcp.NewToolResultText(string(r)), nil
825+
}
826+
}

pkg/github/pullrequests_test.go

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,3 +1187,152 @@ func Test_CreatePullRequestReview(t *testing.T) {
11871187
})
11881188
}
11891189
}
1190+
1191+
func Test_CreatePullRequest(t *testing.T) {
1192+
// Verify tool definition once
1193+
mockClient := github.NewClient(nil)
1194+
tool, _ := createPullRequest(mockClient, translations.NullTranslationHelper)
1195+
1196+
assert.Equal(t, "create_pull_request", tool.Name)
1197+
assert.NotEmpty(t, tool.Description)
1198+
assert.Contains(t, tool.InputSchema.Properties, "owner")
1199+
assert.Contains(t, tool.InputSchema.Properties, "repo")
1200+
assert.Contains(t, tool.InputSchema.Properties, "title")
1201+
assert.Contains(t, tool.InputSchema.Properties, "body")
1202+
assert.Contains(t, tool.InputSchema.Properties, "head")
1203+
assert.Contains(t, tool.InputSchema.Properties, "base")
1204+
assert.Contains(t, tool.InputSchema.Properties, "draft")
1205+
assert.Contains(t, tool.InputSchema.Properties, "maintainer_can_modify")
1206+
assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "title", "head", "base"})
1207+
1208+
// Setup mock PR for success case
1209+
mockPR := &github.PullRequest{
1210+
Number: github.Ptr(42),
1211+
Title: github.Ptr("Test PR"),
1212+
State: github.Ptr("open"),
1213+
HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"),
1214+
Head: &github.PullRequestBranch{
1215+
SHA: github.Ptr("abcd1234"),
1216+
Ref: github.Ptr("feature-branch"),
1217+
},
1218+
Base: &github.PullRequestBranch{
1219+
SHA: github.Ptr("efgh5678"),
1220+
Ref: github.Ptr("main"),
1221+
},
1222+
Body: github.Ptr("This is a test PR"),
1223+
Draft: github.Ptr(false),
1224+
MaintainerCanModify: github.Ptr(true),
1225+
User: &github.User{
1226+
Login: github.Ptr("testuser"),
1227+
},
1228+
}
1229+
1230+
tests := []struct {
1231+
name string
1232+
mockedClient *http.Client
1233+
requestArgs map[string]interface{}
1234+
expectError bool
1235+
expectedPR *github.PullRequest
1236+
expectedErrMsg string
1237+
}{
1238+
{
1239+
name: "successful PR creation",
1240+
mockedClient: mock.NewMockedHTTPClient(
1241+
mock.WithRequestMatchHandler(
1242+
mock.PostReposPullsByOwnerByRepo,
1243+
mockResponse(t, http.StatusCreated, mockPR),
1244+
),
1245+
),
1246+
1247+
requestArgs: map[string]interface{}{
1248+
"owner": "owner",
1249+
"repo": "repo",
1250+
"title": "Test PR",
1251+
"body": "This is a test PR",
1252+
"head": "feature-branch",
1253+
"base": "main",
1254+
"draft": false,
1255+
"maintainer_can_modify": true,
1256+
},
1257+
expectError: false,
1258+
expectedPR: mockPR,
1259+
},
1260+
{
1261+
name: "missing required parameter",
1262+
mockedClient: mock.NewMockedHTTPClient(),
1263+
requestArgs: map[string]interface{}{
1264+
"owner": "owner",
1265+
"repo": "repo",
1266+
// missing title, head, base
1267+
},
1268+
expectError: true,
1269+
expectedErrMsg: "missing required parameter: title",
1270+
},
1271+
{
1272+
name: "PR creation fails",
1273+
mockedClient: mock.NewMockedHTTPClient(
1274+
mock.WithRequestMatchHandler(
1275+
mock.PostReposPullsByOwnerByRepo,
1276+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1277+
w.WriteHeader(http.StatusUnprocessableEntity)
1278+
_, _ = w.Write([]byte(`{"message":"Validation failed","errors":[{"resource":"PullRequest","code":"invalid"}]}`))
1279+
}),
1280+
),
1281+
),
1282+
requestArgs: map[string]interface{}{
1283+
"owner": "owner",
1284+
"repo": "repo",
1285+
"title": "Test PR",
1286+
"head": "feature-branch",
1287+
"base": "main",
1288+
},
1289+
expectError: true,
1290+
expectedErrMsg: "failed to create pull request",
1291+
},
1292+
}
1293+
1294+
for _, tc := range tests {
1295+
t.Run(tc.name, func(t *testing.T) {
1296+
// Setup client with mock
1297+
client := github.NewClient(tc.mockedClient)
1298+
_, handler := createPullRequest(client, translations.NullTranslationHelper)
1299+
1300+
// Create call request
1301+
request := createMCPRequest(tc.requestArgs)
1302+
1303+
// Call handler
1304+
result, err := handler(context.Background(), request)
1305+
1306+
// Verify results
1307+
if tc.expectError {
1308+
if err != nil {
1309+
assert.Contains(t, err.Error(), tc.expectedErrMsg)
1310+
return
1311+
}
1312+
1313+
// If no error returned but in the result
1314+
textContent := getTextResult(t, result)
1315+
assert.Contains(t, textContent.Text, tc.expectedErrMsg)
1316+
return
1317+
}
1318+
1319+
require.NoError(t, err)
1320+
1321+
// Parse the result and get the text content if no error
1322+
textContent := getTextResult(t, result)
1323+
1324+
// Unmarshal and verify the result
1325+
var returnedPR github.PullRequest
1326+
err = json.Unmarshal([]byte(textContent.Text), &returnedPR)
1327+
require.NoError(t, err)
1328+
assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number)
1329+
assert.Equal(t, *tc.expectedPR.Title, *returnedPR.Title)
1330+
assert.Equal(t, *tc.expectedPR.State, *returnedPR.State)
1331+
assert.Equal(t, *tc.expectedPR.HTMLURL, *returnedPR.HTMLURL)
1332+
assert.Equal(t, *tc.expectedPR.Head.SHA, *returnedPR.Head.SHA)
1333+
assert.Equal(t, *tc.expectedPR.Base.Ref, *returnedPR.Base.Ref)
1334+
assert.Equal(t, *tc.expectedPR.Body, *returnedPR.Body)
1335+
assert.Equal(t, *tc.expectedPR.User.Login, *returnedPR.User.Login)
1336+
})
1337+
}
1338+
}

pkg/github/server.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ func NewServer(client *github.Client, readOnly bool, t translations.TranslationH
5555
s.AddTool(mergePullRequest(client, t))
5656
s.AddTool(updatePullRequestBranch(client, t))
5757
s.AddTool(createPullRequestReview(client, t))
58+
s.AddTool(createPullRequest(client, t))
5859
}
5960

6061
// Add GitHub tools - Repositories

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy