diff --git a/README.md b/README.md index 39aa0157b..b40974e20 100644 --- a/README.md +++ b/README.md @@ -287,6 +287,7 @@ The following sets of tools are available (all are on by default): | `dependabot` | Dependabot tools | | `discussions` | GitHub Discussions related tools | | `experiments` | Experimental features that are not considered stable yet | +| `gists` | GitHub Gist related tools | | `issues` | GitHub Issues related tools | | `notifications` | GitHub Notifications related tools | | `orgs` | GitHub Organization related tools | @@ -462,9 +463,35 @@ The following sets of tools are available (all are on by default): - **list_discussions** - List discussions - `after`: Cursor for pagination. Use the endCursor from the previous page's PageInfo for GraphQL APIs. (string, optional) - `category`: Optional filter by discussion category ID. If provided, only discussions with this category are listed. (string, optional) + - `direction`: Order direction. (string, optional) + - `orderBy`: Order discussions by field. If provided, the 'direction' also needs to be provided. (string, optional) - `owner`: Repository owner (string, required) - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) - - `repo`: Repository name (string, required) + - `repo`: Repository name. If not provided, discussions will be queried at the organisation level. (string, optional) + + + +
+ +Gists + +- **create_gist** - Create Gist + - `content`: Content for simple single-file gist creation (string, required) + - `description`: Description of the gist (string, optional) + - `filename`: Filename for simple single-file gist creation (string, required) + - `public`: Whether the gist is public (boolean, optional) + +- **list_gists** - List Gists + - `page`: Page number for pagination (min 1) (number, optional) + - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) + - `since`: Only gists updated after this time (ISO 8601 timestamp) (string, optional) + - `username`: GitHub username (omit for authenticated user's gists) (string, optional) + +- **update_gist** - Update Gist + - `content`: Content for the file (string, required) + - `description`: Updated description of the gist (string, optional) + - `filename`: Filename to update or create (string, required) + - `gist_id`: ID of the gist to update (string, required)
@@ -609,7 +636,7 @@ The following sets of tools are available (all are on by default): - `order`: Sort order (string, optional) - `page`: Page number for pagination (min 1) (number, optional) - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) - - `query`: Search query using GitHub organizations search syntax scoped to type:org (string, required) + - `query`: Organization search query. Examples: 'microsoft', 'location:california', 'created:>=2025-01-01'. Search is automatically scoped to type:org. (string, required) - `sort`: Sort field by category (string, optional) @@ -734,10 +761,12 @@ The following sets of tools are available (all are on by default): - **update_pull_request** - Edit pull request - `base`: New base branch name (string, optional) - `body`: New description (string, optional) + - `draft`: Mark pull request as draft (true) or ready for review (false) (boolean, optional) - `maintainer_can_modify`: Allow maintainer edits (boolean, optional) - `owner`: Repository owner (string, required) - `pullNumber`: Pull request number to update (number, required) - `repo`: Repository name (string, required) + - `reviewers`: GitHub usernames to request reviews from (string[], optional) - `state`: New state (string, optional) - `title`: New title (string, optional) @@ -833,16 +862,16 @@ The following sets of tools are available (all are on by default): - `repo`: Repository name (string, required) - **search_code** - Search code - - `order`: Sort order (string, optional) + - `order`: Sort order for results (string, optional) - `page`: Page number for pagination (min 1) (number, optional) - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) - - `q`: Search query using GitHub code search syntax (string, required) + - `query`: Search query using GitHub's powerful code search syntax. Examples: 'content:Skill language:Java org:github', 'NOT is:archived language:Python OR language:go', 'repo:github/github-mcp-server'. Supports exact matching, language filters, path filters, and more. (string, required) - `sort`: Sort field ('indexed' only) (string, optional) - **search_repositories** - Search repositories - `page`: Page number for pagination (min 1) (number, optional) - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) - - `query`: Search query (string, required) + - `query`: Repository search query. Examples: 'machine learning in:name stars:>1000 language:python', 'topic:react', 'user:facebook'. Supports advanced search syntax for precise filtering. (string, required) @@ -872,8 +901,8 @@ The following sets of tools are available (all are on by default): - `order`: Sort order (string, optional) - `page`: Page number for pagination (min 1) (number, optional) - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) - - `query`: Search query using GitHub users search syntax scoped to type:user (string, required) - - `sort`: Sort field by category (string, optional) + - `query`: User search query. Examples: 'john smith', 'location:seattle', 'followers:>100'. Search is automatically scoped to type:user. (string, required) + - `sort`: Sort users by number of followers or repositories, or when the person joined GitHub. (string, optional) @@ -1046,4 +1075,4 @@ The exported Go API of this module should currently be considered unstable, and ## License -This project is licensed under the terms of the MIT open source license. Please refer to [MIT](./LICENSE) for the full terms. +This project is licensed under the terms of the MIT open source license. Please refer to [MIT](./LICENSE) for the full terms. \ No newline at end of file diff --git a/docs/remote-server.md b/docs/remote-server.md index 49794c605..5f57f4961 100644 --- a/docs/remote-server.md +++ b/docs/remote-server.md @@ -25,6 +25,7 @@ Below is a table of available toolsets for the remote GitHub MCP Server. Each to | Dependabot | Dependabot tools | https://api.githubcopilot.com/mcp/x/dependabot | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-dependabot&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdependabot%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/dependabot/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-dependabot&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdependabot%2Freadonly%22%7D) | | Discussions | GitHub Discussions related tools | https://api.githubcopilot.com/mcp/x/discussions | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-discussions&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdiscussions%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/discussions/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-discussions&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdiscussions%2Freadonly%22%7D) | | Experiments | Experimental features that are not considered stable yet | https://api.githubcopilot.com/mcp/x/experiments | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-experiments&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fexperiments%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/experiments/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-experiments&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fexperiments%2Freadonly%22%7D) | +| Gists | GitHub Gist related tools | https://api.githubcopilot.com/mcp/x/gists | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-gists&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgists%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/gists/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-gists&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgists%2Freadonly%22%7D) | | Issues | GitHub Issues related tools | https://api.githubcopilot.com/mcp/x/issues | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-issues&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fissues%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/issues/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-issues&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fissues%2Freadonly%22%7D) | | Notifications | GitHub Notifications related tools | https://api.githubcopilot.com/mcp/x/notifications | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-notifications&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fnotifications%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/notifications/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-notifications&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fnotifications%2Freadonly%22%7D) | | Organizations | GitHub Organization related tools | https://api.githubcopilot.com/mcp/x/orgs | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-orgs&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Forgs%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/orgs/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-orgs&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Forgs%2Freadonly%22%7D) | diff --git a/docs/testing.md b/docs/testing.md index dbdc3e080..226660e9d 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -22,7 +22,7 @@ This project uses a combination of unit tests and end-to-end (e2e) tests to ensu ## toolsnaps: Tool Schema Snapshots - The `toolsnaps` utility ensures that the JSON schema for each tool does not change unexpectedly. -- Snapshots are stored in `__toolsnaps__/*.snap` files , where `*` represents the name of the tool +- Snapshots are stored in `__toolsnaps__/*.snap` files, where `*` represents the name of the tool - When running tests, the current tool schema is compared to the snapshot. If there is a difference, the test will fail and show a diff. - If you intentionally change a tool's schema, update the snapshots by running tests with the environment variable: `UPDATE_TOOLSNAPS=true go test ./...` - In CI (when `GITHUB_ACTIONS=true`), missing snapshots will cause a test failure to ensure snapshots are always diff --git a/pkg/github/__toolsnaps__/search_code.snap b/pkg/github/__toolsnaps__/search_code.snap index c85d6674d..4ef40c5f8 100644 --- a/pkg/github/__toolsnaps__/search_code.snap +++ b/pkg/github/__toolsnaps__/search_code.snap @@ -3,11 +3,11 @@ "title": "Search code", "readOnlyHint": true }, - "description": "Search for code across GitHub repositories", + "description": "Fast and precise code search across ALL GitHub repositories using GitHub's native search engine. Best for finding exact symbols, functions, classes, or specific code patterns.", "inputSchema": { "properties": { "order": { - "description": "Sort order", + "description": "Sort order for results", "enum": [ "asc", "desc" @@ -25,8 +25,8 @@ "minimum": 1, "type": "number" }, - "q": { - "description": "Search query using GitHub code search syntax", + "query": { + "description": "Search query using GitHub's powerful code search syntax. Examples: 'content:Skill language:Java org:github', 'NOT is:archived language:Python OR language:go', 'repo:github/github-mcp-server'. Supports exact matching, language filters, path filters, and more.", "type": "string" }, "sort": { @@ -35,7 +35,7 @@ } }, "required": [ - "q" + "query" ], "type": "object" }, diff --git a/pkg/github/__toolsnaps__/search_repositories.snap b/pkg/github/__toolsnaps__/search_repositories.snap index b6b6d1d44..d283a2cc0 100644 --- a/pkg/github/__toolsnaps__/search_repositories.snap +++ b/pkg/github/__toolsnaps__/search_repositories.snap @@ -3,7 +3,7 @@ "title": "Search repositories", "readOnlyHint": true }, - "description": "Search for GitHub repositories", + "description": "Find GitHub repositories by name, description, readme, topics, or other metadata. Perfect for discovering projects, finding examples, or locating specific repositories across GitHub.", "inputSchema": { "properties": { "page": { @@ -18,7 +18,7 @@ "type": "number" }, "query": { - "description": "Search query", + "description": "Repository search query. Examples: 'machine learning in:name stars:\u003e1000 language:python', 'topic:react', 'user:facebook'. Supports advanced search syntax for precise filtering.", "type": "string" } }, diff --git a/pkg/github/__toolsnaps__/search_users.snap b/pkg/github/__toolsnaps__/search_users.snap index 5cf9796f2..73ff7a43c 100644 --- a/pkg/github/__toolsnaps__/search_users.snap +++ b/pkg/github/__toolsnaps__/search_users.snap @@ -3,7 +3,7 @@ "title": "Search users", "readOnlyHint": true }, - "description": "Search for GitHub users exclusively", + "description": "Find GitHub users by username, real name, or other profile information. Useful for locating developers, contributors, or team members.", "inputSchema": { "properties": { "order": { @@ -26,11 +26,11 @@ "type": "number" }, "query": { - "description": "Search query using GitHub users search syntax scoped to type:user", + "description": "User search query. Examples: 'john smith', 'location:seattle', 'followers:\u003e100'. Search is automatically scoped to type:user.", "type": "string" }, "sort": { - "description": "Sort field by category", + "description": "Sort users by number of followers or repositories, or when the person joined GitHub.", "enum": [ "followers", "repositories", diff --git a/pkg/github/__toolsnaps__/update_pull_request.snap b/pkg/github/__toolsnaps__/update_pull_request.snap index 765983afd..25170ed5f 100644 --- a/pkg/github/__toolsnaps__/update_pull_request.snap +++ b/pkg/github/__toolsnaps__/update_pull_request.snap @@ -14,6 +14,10 @@ "description": "New description", "type": "string" }, + "draft": { + "description": "Mark pull request as draft (true) or ready for review (false)", + "type": "boolean" + }, "maintainer_can_modify": { "description": "Allow maintainer edits", "type": "boolean" @@ -30,6 +34,13 @@ "description": "Repository name", "type": "string" }, + "reviewers": { + "description": "GitHub usernames to request reviews from", + "items": { + "type": "string" + }, + "type": "array" + }, "state": { "description": "New state", "enum": [ diff --git a/pkg/github/discussions.go b/pkg/github/discussions.go index 2b8ccfb0b..91487f7aa 100644 --- a/pkg/github/discussions.go +++ b/pkg/github/discussions.go @@ -15,9 +15,111 @@ import ( const DefaultGraphQLPageSize = 30 +// Common interface for all discussion query types +type DiscussionQueryResult interface { + GetDiscussionFragment() DiscussionFragment +} + +// Implement the interface for all query types +func (q *BasicNoOrder) GetDiscussionFragment() DiscussionFragment { + return q.Repository.Discussions +} + +func (q *BasicWithOrder) GetDiscussionFragment() DiscussionFragment { + return q.Repository.Discussions +} + +func (q *WithCategoryAndOrder) GetDiscussionFragment() DiscussionFragment { + return q.Repository.Discussions +} + +func (q *WithCategoryNoOrder) GetDiscussionFragment() DiscussionFragment { + return q.Repository.Discussions +} + +type DiscussionFragment struct { + Nodes []NodeFragment + PageInfo PageInfoFragment + TotalCount githubv4.Int +} + +type NodeFragment struct { + Number githubv4.Int + Title githubv4.String + CreatedAt githubv4.DateTime + UpdatedAt githubv4.DateTime + Author struct { + Login githubv4.String + } + Category struct { + Name githubv4.String + } `graphql:"category"` + URL githubv4.String `graphql:"url"` +} + +type PageInfoFragment struct { + HasNextPage bool + HasPreviousPage bool + StartCursor githubv4.String + EndCursor githubv4.String +} + +type BasicNoOrder struct { + Repository struct { + Discussions DiscussionFragment `graphql:"discussions(first: $first, after: $after)"` + } `graphql:"repository(owner: $owner, name: $repo)"` +} + +type BasicWithOrder struct { + Repository struct { + Discussions DiscussionFragment `graphql:"discussions(first: $first, after: $after, orderBy: { field: $orderByField, direction: $orderByDirection })"` + } `graphql:"repository(owner: $owner, name: $repo)"` +} + +type WithCategoryAndOrder struct { + Repository struct { + Discussions DiscussionFragment `graphql:"discussions(first: $first, after: $after, categoryId: $categoryId, orderBy: { field: $orderByField, direction: $orderByDirection })"` + } `graphql:"repository(owner: $owner, name: $repo)"` +} + +type WithCategoryNoOrder struct { + Repository struct { + Discussions DiscussionFragment `graphql:"discussions(first: $first, after: $after, categoryId: $categoryId)"` + } `graphql:"repository(owner: $owner, name: $repo)"` +} + +func fragmentToDiscussion(fragment NodeFragment) *github.Discussion { + return &github.Discussion{ + Number: github.Ptr(int(fragment.Number)), + Title: github.Ptr(string(fragment.Title)), + HTMLURL: github.Ptr(string(fragment.URL)), + CreatedAt: &github.Timestamp{Time: fragment.CreatedAt.Time}, + UpdatedAt: &github.Timestamp{Time: fragment.UpdatedAt.Time}, + User: &github.User{ + Login: github.Ptr(string(fragment.Author.Login)), + }, + DiscussionCategory: &github.DiscussionCategory{ + Name: github.Ptr(string(fragment.Category.Name)), + }, + } +} + +func getQueryType(useOrdering bool, categoryID *githubv4.ID) any { + if categoryID != nil && useOrdering { + return &WithCategoryAndOrder{} + } + if categoryID != nil && !useOrdering { + return &WithCategoryNoOrder{} + } + if categoryID == nil && useOrdering { + return &BasicWithOrder{} + } + return &BasicNoOrder{} +} + func ListDiscussions(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("list_discussions", - mcp.WithDescription(t("TOOL_LIST_DISCUSSIONS_DESCRIPTION", "List discussions for a repository")), + mcp.WithDescription(t("TOOL_LIST_DISCUSSIONS_DESCRIPTION", "List discussions for a repository or organisation.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ Title: t("TOOL_LIST_DISCUSSIONS_USER_TITLE", "List discussions"), ReadOnlyHint: ToBoolPtr(true), @@ -27,31 +129,51 @@ func ListDiscussions(getGQLClient GetGQLClientFn, t translations.TranslationHelp mcp.Description("Repository owner"), ), mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), + mcp.Description("Repository name. If not provided, discussions will be queried at the organisation level."), ), mcp.WithString("category", mcp.Description("Optional filter by discussion category ID. If provided, only discussions with this category are listed."), ), + mcp.WithString("orderBy", + mcp.Description("Order discussions by field. If provided, the 'direction' also needs to be provided."), + mcp.Enum("CREATED_AT", "UPDATED_AT"), + ), + mcp.WithString("direction", + mcp.Description("Order direction."), + mcp.Enum("ASC", "DESC"), + ), WithCursorPagination(), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - // Required params owner, err := RequiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := RequiredParam[string](request, "repo") + repo, err := OptionalParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } + // when not provided, default to the .github repository + // this will query discussions at the organisation level + if repo == "" { + repo = ".github" + } - // Optional params category, err := OptionalParam[string](request, "category") if err != nil { return mcp.NewToolResultError(err.Error()), nil } + orderBy, err := OptionalParam[string](request, "orderBy") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + direction, err := OptionalParam[string](request, "direction") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + // Get pagination parameters and convert to GraphQL format pagination, err := OptionalCursorPaginationParams(request) if err != nil { @@ -67,155 +189,69 @@ func ListDiscussions(getGQLClient GetGQLClientFn, t translations.TranslationHelp return mcp.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil } - // If category filter is specified, use it as the category ID for server-side filtering var categoryID *githubv4.ID if category != "" { id := githubv4.ID(category) categoryID = &id } - var out []byte - - var discussions []*github.Discussion - if categoryID != nil { - // Query with category filter (server-side filtering) - var query struct { - Repository struct { - Discussions struct { - Nodes []struct { - Number githubv4.Int - Title githubv4.String - CreatedAt githubv4.DateTime - Category struct { - Name githubv4.String - } `graphql:"category"` - URL githubv4.String `graphql:"url"` - } - PageInfo struct { - HasNextPage bool - HasPreviousPage bool - StartCursor string - EndCursor string - } - TotalCount int - } `graphql:"discussions(first: $first, after: $after, categoryId: $categoryId)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - } - vars := map[string]interface{}{ - "owner": githubv4.String(owner), - "repo": githubv4.String(repo), - "categoryId": *categoryID, - "first": githubv4.Int(*paginationParams.First), - } - if paginationParams.After != nil { - vars["after"] = githubv4.String(*paginationParams.After) - } else { - vars["after"] = (*githubv4.String)(nil) - } - if err := client.Query(ctx, &query, vars); err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - - // Map nodes to GitHub Discussion objects - for _, n := range query.Repository.Discussions.Nodes { - di := &github.Discussion{ - Number: github.Ptr(int(n.Number)), - Title: github.Ptr(string(n.Title)), - HTMLURL: github.Ptr(string(n.URL)), - CreatedAt: &github.Timestamp{Time: n.CreatedAt.Time}, - DiscussionCategory: &github.DiscussionCategory{ - Name: github.Ptr(string(n.Category.Name)), - }, - } - discussions = append(discussions, di) - } + vars := map[string]interface{}{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "first": githubv4.Int(*paginationParams.First), + } + if paginationParams.After != nil { + vars["after"] = githubv4.String(*paginationParams.After) + } else { + vars["after"] = (*githubv4.String)(nil) + } - // Create response with pagination info - response := map[string]interface{}{ - "discussions": discussions, - "pageInfo": map[string]interface{}{ - "hasNextPage": query.Repository.Discussions.PageInfo.HasNextPage, - "hasPreviousPage": query.Repository.Discussions.PageInfo.HasPreviousPage, - "startCursor": query.Repository.Discussions.PageInfo.StartCursor, - "endCursor": query.Repository.Discussions.PageInfo.EndCursor, - }, - "totalCount": query.Repository.Discussions.TotalCount, - } + // this is an extra check in case the tool description is misinterpreted, because + // we shouldn't use ordering unless both a 'field' and 'direction' are provided + useOrdering := orderBy != "" && direction != "" + if useOrdering { + vars["orderByField"] = githubv4.DiscussionOrderField(orderBy) + vars["orderByDirection"] = githubv4.OrderDirection(direction) + } - out, err = json.Marshal(response) - if err != nil { - return nil, fmt.Errorf("failed to marshal discussions: %w", err) - } - } else { - // Query without category filter - var query struct { - Repository struct { - Discussions struct { - Nodes []struct { - Number githubv4.Int - Title githubv4.String - CreatedAt githubv4.DateTime - Category struct { - Name githubv4.String - } `graphql:"category"` - URL githubv4.String `graphql:"url"` - } - PageInfo struct { - HasNextPage bool - HasPreviousPage bool - StartCursor string - EndCursor string - } - TotalCount int - } `graphql:"discussions(first: $first, after: $after)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - } - vars := map[string]interface{}{ - "owner": githubv4.String(owner), - "repo": githubv4.String(repo), - "first": githubv4.Int(*paginationParams.First), - } - if paginationParams.After != nil { - vars["after"] = githubv4.String(*paginationParams.After) - } else { - vars["after"] = (*githubv4.String)(nil) - } - if err := client.Query(ctx, &query, vars); err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + if categoryID != nil { + vars["categoryId"] = *categoryID + } - // Map nodes to GitHub Discussion objects - for _, n := range query.Repository.Discussions.Nodes { - di := &github.Discussion{ - Number: github.Ptr(int(n.Number)), - Title: github.Ptr(string(n.Title)), - HTMLURL: github.Ptr(string(n.URL)), - CreatedAt: &github.Timestamp{Time: n.CreatedAt.Time}, - DiscussionCategory: &github.DiscussionCategory{ - Name: github.Ptr(string(n.Category.Name)), - }, - } - discussions = append(discussions, di) - } + discussionQuery := getQueryType(useOrdering, categoryID) + if err := client.Query(ctx, discussionQuery, vars); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - // Create response with pagination info - response := map[string]interface{}{ - "discussions": discussions, - "pageInfo": map[string]interface{}{ - "hasNextPage": query.Repository.Discussions.PageInfo.HasNextPage, - "hasPreviousPage": query.Repository.Discussions.PageInfo.HasPreviousPage, - "startCursor": query.Repository.Discussions.PageInfo.StartCursor, - "endCursor": query.Repository.Discussions.PageInfo.EndCursor, - }, - "totalCount": query.Repository.Discussions.TotalCount, + // Extract and convert all discussion nodes using the common interface + var discussions []*github.Discussion + var pageInfo PageInfoFragment + var totalCount githubv4.Int + if queryResult, ok := discussionQuery.(DiscussionQueryResult); ok { + fragment := queryResult.GetDiscussionFragment() + for _, node := range fragment.Nodes { + discussions = append(discussions, fragmentToDiscussion(node)) } + pageInfo = fragment.PageInfo + totalCount = fragment.TotalCount + } - out, err = json.Marshal(response) - if err != nil { - return nil, fmt.Errorf("failed to marshal discussions: %w", err) - } + // Create response with pagination info + response := map[string]interface{}{ + "discussions": discussions, + "pageInfo": map[string]interface{}{ + "hasNextPage": pageInfo.HasNextPage, + "hasPreviousPage": pageInfo.HasPreviousPage, + "startCursor": string(pageInfo.StartCursor), + "endCursor": string(pageInfo.EndCursor), + }, + "totalCount": totalCount, } + out, err := json.Marshal(response) + if err != nil { + return nil, fmt.Errorf("failed to marshal discussions: %w", err) + } return mcp.NewToolResultText(string(out)), nil } } @@ -259,6 +295,7 @@ func GetDiscussion(getGQLClient GetGQLClientFn, t translations.TranslationHelper Repository struct { Discussion struct { Number githubv4.Int + Title githubv4.String Body githubv4.String CreatedAt githubv4.DateTime URL githubv4.String `graphql:"url"` @@ -279,6 +316,7 @@ func GetDiscussion(getGQLClient GetGQLClientFn, t translations.TranslationHelper d := q.Repository.Discussion discussion := &github.Discussion{ Number: github.Ptr(int(d.Number)), + Title: github.Ptr(string(d.Title)), Body: github.Ptr(string(d.Body)), HTMLURL: github.Ptr(string(d.URL)), CreatedAt: &github.Timestamp{Time: d.CreatedAt.Time}, diff --git a/pkg/github/discussions_test.go b/pkg/github/discussions_test.go index e2e3d99ed..9458dfce0 100644 --- a/pkg/github/discussions_test.go +++ b/pkg/github/discussions_test.go @@ -17,14 +17,97 @@ import ( var ( discussionsGeneral = []map[string]any{ - {"number": 1, "title": "Discussion 1 title", "createdAt": "2023-01-01T00:00:00Z", "url": "https://github.com/owner/repo/discussions/1", "category": map[string]any{"name": "General"}}, - {"number": 3, "title": "Discussion 3 title", "createdAt": "2023-03-01T00:00:00Z", "url": "https://github.com/owner/repo/discussions/3", "category": map[string]any{"name": "General"}}, + {"number": 1, "title": "Discussion 1 title", "createdAt": "2023-01-01T00:00:00Z", "updatedAt": "2023-01-01T00:00:00Z", "author": map[string]any{"login": "user1"}, "url": "https://github.com/owner/repo/discussions/1", "category": map[string]any{"name": "General"}}, + {"number": 3, "title": "Discussion 3 title", "createdAt": "2023-03-01T00:00:00Z", "updatedAt": "2023-02-01T00:00:00Z", "author": map[string]any{"login": "user1"}, "url": "https://github.com/owner/repo/discussions/3", "category": map[string]any{"name": "General"}}, } discussionsAll = []map[string]any{ - {"number": 1, "title": "Discussion 1 title", "createdAt": "2023-01-01T00:00:00Z", "url": "https://github.com/owner/repo/discussions/1", "category": map[string]any{"name": "General"}}, - {"number": 2, "title": "Discussion 2 title", "createdAt": "2023-02-01T00:00:00Z", "url": "https://github.com/owner/repo/discussions/2", "category": map[string]any{"name": "Questions"}}, - {"number": 3, "title": "Discussion 3 title", "createdAt": "2023-03-01T00:00:00Z", "url": "https://github.com/owner/repo/discussions/3", "category": map[string]any{"name": "General"}}, + { + "number": 1, + "title": "Discussion 1 title", + "createdAt": "2023-01-01T00:00:00Z", + "updatedAt": "2023-01-01T00:00:00Z", + "author": map[string]any{"login": "user1"}, + "url": "https://github.com/owner/repo/discussions/1", + "category": map[string]any{"name": "General"}, + }, + { + "number": 2, + "title": "Discussion 2 title", + "createdAt": "2023-02-01T00:00:00Z", + "updatedAt": "2023-02-01T00:00:00Z", + "author": map[string]any{"login": "user2"}, + "url": "https://github.com/owner/repo/discussions/2", + "category": map[string]any{"name": "Questions"}, + }, + { + "number": 3, + "title": "Discussion 3 title", + "createdAt": "2023-03-01T00:00:00Z", + "updatedAt": "2023-03-01T00:00:00Z", + "author": map[string]any{"login": "user3"}, + "url": "https://github.com/owner/repo/discussions/3", + "category": map[string]any{"name": "General"}, + }, + } + + discussionsOrgLevel = []map[string]any{ + { + "number": 1, + "title": "Org Discussion 1 - Community Guidelines", + "createdAt": "2023-01-15T00:00:00Z", + "updatedAt": "2023-01-15T00:00:00Z", + "author": map[string]any{"login": "org-admin"}, + "url": "https://github.com/owner/.github/discussions/1", + "category": map[string]any{"name": "Announcements"}, + }, + { + "number": 2, + "title": "Org Discussion 2 - Roadmap 2023", + "createdAt": "2023-02-20T00:00:00Z", + "updatedAt": "2023-02-20T00:00:00Z", + "author": map[string]any{"login": "org-admin"}, + "url": "https://github.com/owner/.github/discussions/2", + "category": map[string]any{"name": "General"}, + }, + { + "number": 3, + "title": "Org Discussion 3 - Roadmap 2024", + "createdAt": "2023-02-20T00:00:00Z", + "updatedAt": "2023-02-20T00:00:00Z", + "author": map[string]any{"login": "org-admin"}, + "url": "https://github.com/owner/.github/discussions/3", + "category": map[string]any{"name": "General"}, + }, + { + "number": 4, + "title": "Org Discussion 4 - Roadmap 2025", + "createdAt": "2023-02-20T00:00:00Z", + "updatedAt": "2023-02-20T00:00:00Z", + "author": map[string]any{"login": "org-admin"}, + "url": "https://github.com/owner/.github/discussions/4", + "category": map[string]any{"name": "General"}, + }, + } + + // Ordered mock responses + discussionsOrderedCreatedAsc = []map[string]any{ + discussionsAll[0], // Discussion 1 (created 2023-01-01) + discussionsAll[1], // Discussion 2 (created 2023-02-01) + discussionsAll[2], // Discussion 3 (created 2023-03-01) } + + discussionsOrderedUpdatedDesc = []map[string]any{ + discussionsAll[2], // Discussion 3 (updated 2023-03-01) + discussionsAll[1], // Discussion 2 (updated 2023-02-01) + discussionsAll[0], // Discussion 1 (updated 2023-01-01) + } + + // only 'General' category discussions ordered by created date descending + discussionsGeneralOrderedDesc = []map[string]any{ + discussionsGeneral[1], // Discussion 3 (created 2023-03-01) + discussionsGeneral[0], // Discussion 1 (created 2023-01-01) + } + mockResponseListAll = githubv4mock.DataResponse(map[string]any{ "repository": map[string]any{ "discussions": map[string]any{ @@ -53,23 +136,77 @@ var ( }, }, }) + mockResponseOrderedCreatedAsc = githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "discussions": map[string]any{ + "nodes": discussionsOrderedCreatedAsc, + "pageInfo": map[string]any{ + "hasNextPage": false, + "hasPreviousPage": false, + "startCursor": "", + "endCursor": "", + }, + "totalCount": 3, + }, + }, + }) + mockResponseOrderedUpdatedDesc = githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "discussions": map[string]any{ + "nodes": discussionsOrderedUpdatedDesc, + "pageInfo": map[string]any{ + "hasNextPage": false, + "hasPreviousPage": false, + "startCursor": "", + "endCursor": "", + }, + "totalCount": 3, + }, + }, + }) + mockResponseGeneralOrderedDesc = githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "discussions": map[string]any{ + "nodes": discussionsGeneralOrderedDesc, + "pageInfo": map[string]any{ + "hasNextPage": false, + "hasPreviousPage": false, + "startCursor": "", + "endCursor": "", + }, + "totalCount": 2, + }, + }, + }) + + mockResponseOrgLevel = githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "discussions": map[string]any{ + "nodes": discussionsOrgLevel, + "pageInfo": map[string]any{ + "hasNextPage": false, + "hasPreviousPage": false, + "startCursor": "", + "endCursor": "", + }, + "totalCount": 4, + }, + }, + }) + mockErrorRepoNotFound = githubv4mock.ErrorResponse("repository not found") ) func Test_ListDiscussions(t *testing.T) { mockClient := githubv4.NewClient(nil) - // Verify tool definition and schema toolDef, _ := ListDiscussions(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "list_discussions", toolDef.Name) assert.NotEmpty(t, toolDef.Description) assert.Contains(t, toolDef.InputSchema.Properties, "owner") assert.Contains(t, toolDef.InputSchema.Properties, "repo") - assert.ElementsMatch(t, toolDef.InputSchema.Required, []string{"owner", "repo"}) - - // Use exact string queries that match implementation output (from error messages) - qDiscussions := "query($after:String$first:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussions(first: $first, after: $after){nodes{number,title,createdAt,category{name},url},pageInfo{hasNextPage,hasPreviousPage,startCursor,endCursor},totalCount}}}" - - qDiscussionsFiltered := "query($after:String$categoryId:ID!$first:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussions(first: $first, after: $after, categoryId: $categoryId){nodes{number,title,createdAt,category{name},url},pageInfo{hasNextPage,hasPreviousPage,startCursor,endCursor},totalCount}}}" + assert.Contains(t, toolDef.InputSchema.Properties, "orderBy") + assert.Contains(t, toolDef.InputSchema.Properties, "direction") + assert.ElementsMatch(t, toolDef.InputSchema.Required, []string{"owner"}) // Variables matching what GraphQL receives after JSON marshaling/unmarshaling varsListAll := map[string]interface{}{ @@ -94,12 +231,48 @@ func Test_ListDiscussions(t *testing.T) { "after": (*string)(nil), } + varsOrderByCreatedAsc := map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "orderByField": "CREATED_AT", + "orderByDirection": "ASC", + "first": float64(30), + "after": (*string)(nil), + } + + varsOrderByUpdatedDesc := map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "orderByField": "UPDATED_AT", + "orderByDirection": "DESC", + "first": float64(30), + "after": (*string)(nil), + } + + varsCategoryWithOrder := map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "categoryId": "DIC_kwDOABC123", + "orderByField": "CREATED_AT", + "orderByDirection": "DESC", + "first": float64(30), + "after": (*string)(nil), + } + + varsOrgLevel := map[string]interface{}{ + "owner": "owner", + "repo": ".github", // This is what gets set when repo is not provided + "first": float64(30), + "after": (*string)(nil), + } + tests := []struct { name string reqParams map[string]interface{} expectError bool errContains string expectedCount int + verifyOrder func(t *testing.T, discussions []*github.Discussion) }{ { name: "list all discussions without category filter", @@ -120,6 +293,80 @@ func Test_ListDiscussions(t *testing.T) { expectError: false, expectedCount: 2, // Only General discussions (matching the category ID) }, + { + name: "order by created at ascending", + reqParams: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "orderBy": "CREATED_AT", + "direction": "ASC", + }, + expectError: false, + expectedCount: 3, + verifyOrder: func(t *testing.T, discussions []*github.Discussion) { + // Verify discussions are ordered by created date ascending + require.Len(t, discussions, 3) + assert.Equal(t, 1, *discussions[0].Number, "First should be discussion 1 (created 2023-01-01)") + assert.Equal(t, 2, *discussions[1].Number, "Second should be discussion 2 (created 2023-02-01)") + assert.Equal(t, 3, *discussions[2].Number, "Third should be discussion 3 (created 2023-03-01)") + }, + }, + { + name: "order by updated at descending", + reqParams: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "orderBy": "UPDATED_AT", + "direction": "DESC", + }, + expectError: false, + expectedCount: 3, + verifyOrder: func(t *testing.T, discussions []*github.Discussion) { + // Verify discussions are ordered by updated date descending + require.Len(t, discussions, 3) + assert.Equal(t, 3, *discussions[0].Number, "First should be discussion 3 (updated 2023-03-01)") + assert.Equal(t, 2, *discussions[1].Number, "Second should be discussion 2 (updated 2023-02-01)") + assert.Equal(t, 1, *discussions[2].Number, "Third should be discussion 1 (updated 2023-01-01)") + }, + }, + { + name: "filter by category with order", + reqParams: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "category": "DIC_kwDOABC123", + "orderBy": "CREATED_AT", + "direction": "DESC", + }, + expectError: false, + expectedCount: 2, + verifyOrder: func(t *testing.T, discussions []*github.Discussion) { + // Verify only General discussions, ordered by created date descending + require.Len(t, discussions, 2) + assert.Equal(t, 3, *discussions[0].Number, "First should be discussion 3 (created 2023-03-01)") + assert.Equal(t, 1, *discussions[1].Number, "Second should be discussion 1 (created 2023-01-01)") + }, + }, + { + name: "order by without direction (should not use ordering)", + reqParams: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "orderBy": "CREATED_AT", + }, + expectError: false, + expectedCount: 3, + }, + { + name: "direction without order by (should not use ordering)", + reqParams: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "direction": "DESC", + }, + expectError: false, + expectedCount: 3, + }, { name: "repository not found error", reqParams: map[string]interface{}{ @@ -129,23 +376,54 @@ func Test_ListDiscussions(t *testing.T) { expectError: true, errContains: "repository not found", }, + { + name: "list org-level discussions (no repo provided)", + reqParams: map[string]interface{}{ + "owner": "owner", + // repo is not provided, it will default to ".github" + }, + expectError: false, + expectedCount: 4, + }, } + // Define the actual query strings that match the implementation + qBasicNoOrder := "query($after:String$first:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussions(first: $first, after: $after){nodes{number,title,createdAt,updatedAt,author{login},category{name},url},pageInfo{hasNextPage,hasPreviousPage,startCursor,endCursor},totalCount}}}" + qWithCategoryNoOrder := "query($after:String$categoryId:ID!$first:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussions(first: $first, after: $after, categoryId: $categoryId){nodes{number,title,createdAt,updatedAt,author{login},category{name},url},pageInfo{hasNextPage,hasPreviousPage,startCursor,endCursor},totalCount}}}" + qBasicWithOrder := "query($after:String$first:Int!$orderByDirection:OrderDirection!$orderByField:DiscussionOrderField!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussions(first: $first, after: $after, orderBy: { field: $orderByField, direction: $orderByDirection }){nodes{number,title,createdAt,updatedAt,author{login},category{name},url},pageInfo{hasNextPage,hasPreviousPage,startCursor,endCursor},totalCount}}}" + qWithCategoryAndOrder := "query($after:String$categoryId:ID!$first:Int!$orderByDirection:OrderDirection!$orderByField:DiscussionOrderField!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussions(first: $first, after: $after, categoryId: $categoryId, orderBy: { field: $orderByField, direction: $orderByDirection }){nodes{number,title,createdAt,updatedAt,author{login},category{name},url},pageInfo{hasNextPage,hasPreviousPage,startCursor,endCursor},totalCount}}}" + for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { var httpClient *http.Client switch tc.name { case "list all discussions without category filter": - // Simple case - no category filter - matcher := githubv4mock.NewQueryMatcher(qDiscussions, varsListAll, mockResponseListAll) + matcher := githubv4mock.NewQueryMatcher(qBasicNoOrder, varsListAll, mockResponseListAll) httpClient = githubv4mock.NewMockedHTTPClient(matcher) case "filter by category ID": - // Simple case - category filter using category ID directly - matcher := githubv4mock.NewQueryMatcher(qDiscussionsFiltered, varsDiscussionsFiltered, mockResponseListGeneral) + matcher := githubv4mock.NewQueryMatcher(qWithCategoryNoOrder, varsDiscussionsFiltered, mockResponseListGeneral) + httpClient = githubv4mock.NewMockedHTTPClient(matcher) + case "order by created at ascending": + matcher := githubv4mock.NewQueryMatcher(qBasicWithOrder, varsOrderByCreatedAsc, mockResponseOrderedCreatedAsc) + httpClient = githubv4mock.NewMockedHTTPClient(matcher) + case "order by updated at descending": + matcher := githubv4mock.NewQueryMatcher(qBasicWithOrder, varsOrderByUpdatedDesc, mockResponseOrderedUpdatedDesc) + httpClient = githubv4mock.NewMockedHTTPClient(matcher) + case "filter by category with order": + matcher := githubv4mock.NewQueryMatcher(qWithCategoryAndOrder, varsCategoryWithOrder, mockResponseGeneralOrderedDesc) + httpClient = githubv4mock.NewMockedHTTPClient(matcher) + case "order by without direction (should not use ordering)": + matcher := githubv4mock.NewQueryMatcher(qBasicNoOrder, varsListAll, mockResponseListAll) + httpClient = githubv4mock.NewMockedHTTPClient(matcher) + case "direction without order by (should not use ordering)": + matcher := githubv4mock.NewQueryMatcher(qBasicNoOrder, varsListAll, mockResponseListAll) httpClient = githubv4mock.NewMockedHTTPClient(matcher) case "repository not found error": - matcher := githubv4mock.NewQueryMatcher(qDiscussions, varsRepoNotFound, mockErrorRepoNotFound) + matcher := githubv4mock.NewQueryMatcher(qBasicNoOrder, varsRepoNotFound, mockErrorRepoNotFound) + httpClient = githubv4mock.NewMockedHTTPClient(matcher) + case "list org-level discussions (no repo provided)": + matcher := githubv4mock.NewQueryMatcher(qBasicNoOrder, varsOrgLevel, mockResponseOrgLevel) httpClient = githubv4mock.NewMockedHTTPClient(matcher) } @@ -179,6 +457,11 @@ func Test_ListDiscussions(t *testing.T) { assert.Len(t, response.Discussions, tc.expectedCount, "Expected %d discussions, got %d", tc.expectedCount, len(response.Discussions)) + // Verify order if verifyOrder function is provided + if tc.verifyOrder != nil { + tc.verifyOrder(t, response.Discussions) + } + // Verify that all returned discussions have a category if filtered if _, hasCategory := tc.reqParams["category"]; hasCategory { for _, discussion := range response.Discussions { @@ -201,7 +484,7 @@ func Test_GetDiscussion(t *testing.T) { assert.ElementsMatch(t, toolDef.InputSchema.Required, []string{"owner", "repo", "discussionNumber"}) // Use exact string query that matches implementation output - qGetDiscussion := "query($discussionNumber:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussion(number: $discussionNumber){number,body,createdAt,url,category{name}}}}" + qGetDiscussion := "query($discussionNumber:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussion(number: $discussionNumber){number,title,body,createdAt,url,category{name}}}}" vars := map[string]interface{}{ "owner": "owner", @@ -220,6 +503,7 @@ func Test_GetDiscussion(t *testing.T) { response: githubv4mock.DataResponse(map[string]any{ "repository": map[string]any{"discussion": map[string]any{ "number": 1, + "title": "Test Discussion Title", "body": "This is a test discussion", "url": "https://github.com/owner/repo/discussions/1", "createdAt": "2025-04-25T12:00:00Z", @@ -230,6 +514,7 @@ func Test_GetDiscussion(t *testing.T) { expected: &github.Discussion{ HTMLURL: github.Ptr("https://github.com/owner/repo/discussions/1"), Number: github.Ptr(1), + Title: github.Ptr("Test Discussion Title"), Body: github.Ptr("This is a test discussion"), CreatedAt: &github.Timestamp{Time: time.Date(2025, 4, 25, 12, 0, 0, 0, time.UTC)}, DiscussionCategory: &github.DiscussionCategory{ @@ -266,6 +551,7 @@ func Test_GetDiscussion(t *testing.T) { require.NoError(t, json.Unmarshal([]byte(text), &out)) assert.Equal(t, *tc.expected.HTMLURL, *out.HTMLURL) assert.Equal(t, *tc.expected.Number, *out.Number) + assert.Equal(t, *tc.expected.Title, *out.Title) assert.Equal(t, *tc.expected.Body, *out.Body) // Check category label assert.Equal(t, *tc.expected.DiscussionCategory.Name, *out.DiscussionCategory.Name) diff --git a/pkg/github/gists.go b/pkg/github/gists.go new file mode 100644 index 000000000..403804cad --- /dev/null +++ b/pkg/github/gists.go @@ -0,0 +1,259 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v73/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// ListGists creates a tool to list gists for a user +func ListGists(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("list_gists", + mcp.WithDescription(t("TOOL_LIST_GISTS_DESCRIPTION", "List gists for a user")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_LIST_GISTS", "List Gists"), + ReadOnlyHint: ToBoolPtr(true), + }), + mcp.WithString("username", + mcp.Description("GitHub username (omit for authenticated user's gists)"), + ), + mcp.WithString("since", + mcp.Description("Only gists updated after this time (ISO 8601 timestamp)"), + ), + WithPagination(), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + username, err := OptionalParam[string](request, "username") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + since, err := OptionalParam[string](request, "since") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + pagination, err := OptionalPaginationParams(request) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + opts := &github.GistListOptions{ + ListOptions: github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + }, + } + + // Parse since timestamp if provided + if since != "" { + sinceTime, err := parseISOTimestamp(since) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid since timestamp: %v", err)), nil + } + opts.Since = sinceTime + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + gists, resp, err := client.Gists.List(ctx, username, opts) + if err != nil { + return nil, fmt.Errorf("failed to list gists: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to list gists: %s", string(body))), nil + } + + r, err := json.Marshal(gists) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// CreateGist creates a tool to create a new gist +func CreateGist(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("create_gist", + mcp.WithDescription(t("TOOL_CREATE_GIST_DESCRIPTION", "Create a new gist")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_CREATE_GIST", "Create Gist"), + ReadOnlyHint: ToBoolPtr(false), + }), + mcp.WithString("description", + mcp.Description("Description of the gist"), + ), + mcp.WithString("filename", + mcp.Required(), + mcp.Description("Filename for simple single-file gist creation"), + ), + mcp.WithString("content", + mcp.Required(), + mcp.Description("Content for simple single-file gist creation"), + ), + mcp.WithBoolean("public", + mcp.Description("Whether the gist is public"), + mcp.DefaultBool(false), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + description, err := OptionalParam[string](request, "description") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + filename, err := RequiredParam[string](request, "filename") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + content, err := RequiredParam[string](request, "content") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + public, err := OptionalParam[bool](request, "public") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + files := make(map[github.GistFilename]github.GistFile) + files[github.GistFilename(filename)] = github.GistFile{ + Filename: github.Ptr(filename), + Content: github.Ptr(content), + } + + gist := &github.Gist{ + Files: files, + Public: github.Ptr(public), + Description: github.Ptr(description), + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + createdGist, resp, err := client.Gists.Create(ctx, gist) + if err != nil { + return nil, fmt.Errorf("failed to create gist: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to create gist: %s", string(body))), nil + } + + r, err := json.Marshal(createdGist) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// UpdateGist creates a tool to edit an existing gist +func UpdateGist(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("update_gist", + mcp.WithDescription(t("TOOL_UPDATE_GIST_DESCRIPTION", "Update an existing gist")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_UPDATE_GIST", "Update Gist"), + ReadOnlyHint: ToBoolPtr(false), + }), + mcp.WithString("gist_id", + mcp.Required(), + mcp.Description("ID of the gist to update"), + ), + mcp.WithString("description", + mcp.Description("Updated description of the gist"), + ), + mcp.WithString("filename", + mcp.Required(), + mcp.Description("Filename to update or create"), + ), + mcp.WithString("content", + mcp.Required(), + mcp.Description("Content for the file"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + gistID, err := RequiredParam[string](request, "gist_id") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + description, err := OptionalParam[string](request, "description") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + filename, err := RequiredParam[string](request, "filename") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + content, err := RequiredParam[string](request, "content") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + files := make(map[github.GistFilename]github.GistFile) + files[github.GistFilename(filename)] = github.GistFile{ + Filename: github.Ptr(filename), + Content: github.Ptr(content), + } + + gist := &github.Gist{ + Files: files, + Description: github.Ptr(description), + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + updatedGist, resp, err := client.Gists.Edit(ctx, gistID, gist) + if err != nil { + return nil, fmt.Errorf("failed to update gist: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to update gist: %s", string(body))), nil + } + + r, err := json.Marshal(updatedGist) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} diff --git a/pkg/github/gists_test.go b/pkg/github/gists_test.go new file mode 100644 index 000000000..423422925 --- /dev/null +++ b/pkg/github/gists_test.go @@ -0,0 +1,507 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + "time" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v73/github" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ListGists(t *testing.T) { + // Verify tool definition + mockClient := github.NewClient(nil) + tool, _ := ListGists(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "list_gists", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "username") + assert.Contains(t, tool.InputSchema.Properties, "since") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.Contains(t, tool.InputSchema.Properties, "perPage") + assert.Empty(t, tool.InputSchema.Required) + + // Setup mock gists for success case + mockGists := []*github.Gist{ + { + ID: github.Ptr("gist1"), + Description: github.Ptr("First Gist"), + HTMLURL: github.Ptr("https://gist.github.com/user/gist1"), + Public: github.Ptr(true), + CreatedAt: &github.Timestamp{Time: time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)}, + Owner: &github.User{Login: github.Ptr("user")}, + Files: map[github.GistFilename]github.GistFile{ + "file1.txt": { + Filename: github.Ptr("file1.txt"), + Content: github.Ptr("content of file 1"), + }, + }, + }, + { + ID: github.Ptr("gist2"), + Description: github.Ptr("Second Gist"), + HTMLURL: github.Ptr("https://gist.github.com/testuser/gist2"), + Public: github.Ptr(false), + CreatedAt: &github.Timestamp{Time: time.Date(2023, 2, 1, 0, 0, 0, 0, time.UTC)}, + Owner: &github.User{Login: github.Ptr("testuser")}, + Files: map[github.GistFilename]github.GistFile{ + "file2.js": { + Filename: github.Ptr("file2.js"), + Content: github.Ptr("console.log('hello');"), + }, + }, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedGists []*github.Gist + expectedErrMsg string + }{ + { + name: "list authenticated user's gists", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetGists, + mockGists, + ), + ), + requestArgs: map[string]interface{}{}, + expectError: false, + expectedGists: mockGists, + }, + { + name: "list specific user's gists", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetUsersGistsByUsername, + mockResponse(t, http.StatusOK, mockGists), + ), + ), + requestArgs: map[string]interface{}{ + "username": "testuser", + }, + expectError: false, + expectedGists: mockGists, + }, + { + name: "list gists with pagination and since parameter", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetGists, + expectQueryParams(t, map[string]string{ + "since": "2023-01-01T00:00:00Z", + "page": "2", + "per_page": "5", + }).andThen( + mockResponse(t, http.StatusOK, mockGists), + ), + ), + ), + requestArgs: map[string]interface{}{ + "since": "2023-01-01T00:00:00Z", + "page": float64(2), + "perPage": float64(5), + }, + expectError: false, + expectedGists: mockGists, + }, + { + name: "invalid since parameter", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetGists, + mockGists, + ), + ), + requestArgs: map[string]interface{}{ + "since": "invalid-date", + }, + expectError: true, + expectedErrMsg: "invalid since timestamp", + }, + { + name: "list gists fails with error", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetGists, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message": "Requires authentication"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{}, + expectError: true, + expectedErrMsg: "failed to list gists", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := ListGists(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + if err != nil { + assert.Contains(t, err.Error(), tc.expectedErrMsg) + } else { + // For errors returned as part of the result, not as an error + assert.NotNil(t, result) + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + } + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedGists []*github.Gist + err = json.Unmarshal([]byte(textContent.Text), &returnedGists) + require.NoError(t, err) + + assert.Len(t, returnedGists, len(tc.expectedGists)) + for i, gist := range returnedGists { + assert.Equal(t, *tc.expectedGists[i].ID, *gist.ID) + assert.Equal(t, *tc.expectedGists[i].Description, *gist.Description) + assert.Equal(t, *tc.expectedGists[i].HTMLURL, *gist.HTMLURL) + assert.Equal(t, *tc.expectedGists[i].Public, *gist.Public) + } + }) + } +} + +func Test_CreateGist(t *testing.T) { + // Verify tool definition + mockClient := github.NewClient(nil) + tool, _ := CreateGist(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "create_gist", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "description") + assert.Contains(t, tool.InputSchema.Properties, "filename") + assert.Contains(t, tool.InputSchema.Properties, "content") + assert.Contains(t, tool.InputSchema.Properties, "public") + + // Verify required parameters + assert.Contains(t, tool.InputSchema.Required, "filename") + assert.Contains(t, tool.InputSchema.Required, "content") + + // Setup mock data for test cases + createdGist := &github.Gist{ + ID: github.Ptr("new-gist-id"), + Description: github.Ptr("Test Gist"), + HTMLURL: github.Ptr("https://gist.github.com/user/new-gist-id"), + Public: github.Ptr(false), + CreatedAt: &github.Timestamp{Time: time.Now()}, + Owner: &github.User{Login: github.Ptr("user")}, + Files: map[github.GistFilename]github.GistFile{ + "test.go": { + Filename: github.Ptr("test.go"), + Content: github.Ptr("package main\n\nfunc main() {\n\tfmt.Println(\"Hello, Gist!\")\n}"), + }, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedErrMsg string + expectedGist *github.Gist + }{ + { + name: "create gist successfully", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostGists, + mockResponse(t, http.StatusCreated, createdGist), + ), + ), + requestArgs: map[string]interface{}{ + "filename": "test.go", + "content": "package main\n\nfunc main() {\n\tfmt.Println(\"Hello, Gist!\")\n}", + "description": "Test Gist", + "public": false, + }, + expectError: false, + expectedGist: createdGist, + }, + { + name: "missing required filename", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "content": "test content", + "description": "Test Gist", + }, + expectError: true, + expectedErrMsg: "missing required parameter: filename", + }, + { + name: "missing required content", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "filename": "test.go", + "description": "Test Gist", + }, + expectError: true, + expectedErrMsg: "missing required parameter: content", + }, + { + name: "api returns error", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostGists, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message": "Requires authentication"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "filename": "test.go", + "content": "package main", + "description": "Test Gist", + }, + expectError: true, + expectedErrMsg: "failed to create gist", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := CreateGist(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + if err != nil { + assert.Contains(t, err.Error(), tc.expectedErrMsg) + } else { + // For errors returned as part of the result, not as an error + assert.NotNil(t, result) + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + } + return + } + + require.NoError(t, err) + assert.NotNil(t, result) + + // Parse the result and get the text content + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var gist *github.Gist + err = json.Unmarshal([]byte(textContent.Text), &gist) + require.NoError(t, err) + + assert.Equal(t, *tc.expectedGist.ID, *gist.ID) + assert.Equal(t, *tc.expectedGist.Description, *gist.Description) + assert.Equal(t, *tc.expectedGist.HTMLURL, *gist.HTMLURL) + assert.Equal(t, *tc.expectedGist.Public, *gist.Public) + + // Verify file content + for filename, expectedFile := range tc.expectedGist.Files { + actualFile, exists := gist.Files[filename] + assert.True(t, exists) + assert.Equal(t, *expectedFile.Filename, *actualFile.Filename) + assert.Equal(t, *expectedFile.Content, *actualFile.Content) + } + }) + } +} + +func Test_UpdateGist(t *testing.T) { + // Verify tool definition + mockClient := github.NewClient(nil) + tool, _ := UpdateGist(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "update_gist", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "gist_id") + assert.Contains(t, tool.InputSchema.Properties, "description") + assert.Contains(t, tool.InputSchema.Properties, "filename") + assert.Contains(t, tool.InputSchema.Properties, "content") + + // Verify required parameters + assert.Contains(t, tool.InputSchema.Required, "gist_id") + assert.Contains(t, tool.InputSchema.Required, "filename") + assert.Contains(t, tool.InputSchema.Required, "content") + + // Setup mock data for test cases + updatedGist := &github.Gist{ + ID: github.Ptr("existing-gist-id"), + Description: github.Ptr("Updated Test Gist"), + HTMLURL: github.Ptr("https://gist.github.com/user/existing-gist-id"), + Public: github.Ptr(true), + UpdatedAt: &github.Timestamp{Time: time.Now()}, + Owner: &github.User{Login: github.Ptr("user")}, + Files: map[github.GistFilename]github.GistFile{ + "updated.go": { + Filename: github.Ptr("updated.go"), + Content: github.Ptr("package main\n\nfunc main() {\n\tfmt.Println(\"Updated Gist!\")\n}"), + }, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedErrMsg string + expectedGist *github.Gist + }{ + { + name: "update gist successfully", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchGistsByGistId, + mockResponse(t, http.StatusOK, updatedGist), + ), + ), + requestArgs: map[string]interface{}{ + "gist_id": "existing-gist-id", + "filename": "updated.go", + "content": "package main\n\nfunc main() {\n\tfmt.Println(\"Updated Gist!\")\n}", + "description": "Updated Test Gist", + }, + expectError: false, + expectedGist: updatedGist, + }, + { + name: "missing required gist_id", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "filename": "updated.go", + "content": "updated content", + "description": "Updated Test Gist", + }, + expectError: true, + expectedErrMsg: "missing required parameter: gist_id", + }, + { + name: "missing required filename", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "gist_id": "existing-gist-id", + "content": "updated content", + "description": "Updated Test Gist", + }, + expectError: true, + expectedErrMsg: "missing required parameter: filename", + }, + { + name: "missing required content", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "gist_id": "existing-gist-id", + "filename": "updated.go", + "description": "Updated Test Gist", + }, + expectError: true, + expectedErrMsg: "missing required parameter: content", + }, + { + name: "api returns error", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchGistsByGistId, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "gist_id": "nonexistent-gist-id", + "filename": "updated.go", + "content": "package main", + "description": "Updated Test Gist", + }, + expectError: true, + expectedErrMsg: "failed to update gist", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := UpdateGist(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + if err != nil { + assert.Contains(t, err.Error(), tc.expectedErrMsg) + } else { + // For errors returned as part of the result, not as an error + assert.NotNil(t, result) + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + } + return + } + + require.NoError(t, err) + assert.NotNil(t, result) + + // Parse the result and get the text content + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var gist *github.Gist + err = json.Unmarshal([]byte(textContent.Text), &gist) + require.NoError(t, err) + + assert.Equal(t, *tc.expectedGist.ID, *gist.ID) + assert.Equal(t, *tc.expectedGist.Description, *gist.Description) + assert.Equal(t, *tc.expectedGist.HTMLURL, *gist.HTMLURL) + + // Verify file content + for filename, expectedFile := range tc.expectedGist.Files { + actualFile, exists := gist.Files[filename] + assert.True(t, exists) + assert.Equal(t, *expectedFile.Filename, *actualFile.Filename) + assert.Equal(t, *expectedFile.Content, *actualFile.Content) + } + }) + } +} diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 47b7c6bd2..f82117cad 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -203,7 +203,7 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu } // UpdatePullRequest creates a tool to update an existing pull request. -func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { +func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { return mcp.NewTool("update_pull_request", mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ @@ -232,12 +232,21 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu mcp.Description("New state"), mcp.Enum("open", "closed"), ), + mcp.WithBoolean("draft", + mcp.Description("Mark pull request as draft (true) or ready for review (false)"), + ), mcp.WithString("base", mcp.Description("New base branch name"), ), mcp.WithBoolean("maintainer_can_modify", mcp.Description("Allow maintainer edits"), ), + mcp.WithArray("reviewers", + mcp.Description("GitHub usernames to request reviews from"), + mcp.Items(map[string]interface{}{ + "type": "string", + }), + ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { owner, err := RequiredParam[string](request, "owner") @@ -253,74 +262,211 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu return mcp.NewToolResultError(err.Error()), nil } + // Check if draft parameter is provided + draftProvided := request.GetArguments()["draft"] != nil + var draftValue bool + if draftProvided { + draftValue, err = OptionalParam[bool](request, "draft") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + } + // Build the update struct only with provided fields update := &github.PullRequest{} - updateNeeded := false + restUpdateNeeded := false if title, ok, err := OptionalParamOK[string](request, "title"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.Title = github.Ptr(title) - updateNeeded = true + restUpdateNeeded = true } if body, ok, err := OptionalParamOK[string](request, "body"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.Body = github.Ptr(body) - updateNeeded = true + restUpdateNeeded = true } if state, ok, err := OptionalParamOK[string](request, "state"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.State = github.Ptr(state) - updateNeeded = true + restUpdateNeeded = true } if base, ok, err := OptionalParamOK[string](request, "base"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)} - updateNeeded = true + restUpdateNeeded = true } if maintainerCanModify, ok, err := OptionalParamOK[bool](request, "maintainer_can_modify"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.MaintainerCanModify = github.Ptr(maintainerCanModify) - updateNeeded = true + restUpdateNeeded = true + } + + // Handle reviewers separately + reviewers, err := OptionalStringArrayParam(request, "reviewers") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - if !updateNeeded { + // If no updates, no draft change, and no reviewers, return error early + if !restUpdateNeeded && !draftProvided && len(reviewers) == 0 { return mcp.NewToolResultError("No update parameters provided."), nil } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + // Handle REST API updates (title, body, state, base, maintainer_can_modify) + if restUpdateNeeded { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + _, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update pull request", + resp, + err, + ), nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil + } } - pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to update pull request", - resp, - err, - ), nil + + // Handle draft status changes using GraphQL + if draftProvided { + gqlClient, err := getGQLClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub GraphQL client: %w", err) + } + + var prQuery struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + + err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "prNum": githubv4.Int(pullNumber), // #nosec G115 - pull request numbers are always small positive integers + }) + if err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil + } + + currentIsDraft := bool(prQuery.Repository.PullRequest.IsDraft) + + if currentIsDraft != draftValue { + if draftValue { + // Convert to draft + var mutation struct { + ConvertPullRequestToDraft struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"convertPullRequestToDraft(input: $input)"` + } + + err = gqlClient.Mutate(ctx, &mutation, githubv4.ConvertPullRequestToDraftInput{ + PullRequestID: prQuery.Repository.PullRequest.ID, + }, nil) + if err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to convert pull request to draft", err), nil + } + } else { + // Mark as ready for review + var mutation struct { + MarkPullRequestReadyForReview struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"markPullRequestReadyForReview(input: $input)"` + } + + err = gqlClient.Mutate(ctx, &mutation, githubv4.MarkPullRequestReadyForReviewInput{ + PullRequestID: prQuery.Repository.PullRequest.ID, + }, nil) + if err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to mark pull request ready for review", err), nil + } + } + } } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + // Handle reviewer requests + if len(reviewers) > 0 { + client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + reviewersRequest := github.ReviewersRequest{ + Reviewers: reviewers, + } + + _, resp, err := client.PullRequests.RequestReviewers(ctx, owner, repo, pullNumber, reviewersRequest) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to request reviewers", + resp, + err, + ), nil + } + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to request reviewers: %s", string(body))), nil } - return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil } - r, err := json.Marshal(pr) + // Get the final state of the PR to return + client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, err + } + + finalPR, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "Failed to get pull request", resp, err), nil + } + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + + r, err := json.Marshal(finalPR) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Failed to marshal response: %v", err)), nil } return mcp.NewToolResultText(string(r)), nil diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 42fd5bf03..3a99d9f46 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -137,7 +137,7 @@ func Test_GetPullRequest(t *testing.T) { func Test_UpdatePullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := UpdatePullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool, _ := UpdatePullRequest(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "update_pull_request", tool.Name) @@ -145,11 +145,13 @@ func Test_UpdatePullRequest(t *testing.T) { assert.Contains(t, tool.InputSchema.Properties, "owner") assert.Contains(t, tool.InputSchema.Properties, "repo") assert.Contains(t, tool.InputSchema.Properties, "pullNumber") + assert.Contains(t, tool.InputSchema.Properties, "draft") assert.Contains(t, tool.InputSchema.Properties, "title") assert.Contains(t, tool.InputSchema.Properties, "body") assert.Contains(t, tool.InputSchema.Properties, "state") assert.Contains(t, tool.InputSchema.Properties, "base") assert.Contains(t, tool.InputSchema.Properties, "maintainer_can_modify") + assert.Contains(t, tool.InputSchema.Properties, "reviewers") assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) // Setup mock PR for success case @@ -160,6 +162,7 @@ func Test_UpdatePullRequest(t *testing.T) { HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), Body: github.Ptr("Updated test PR body."), MaintainerCanModify: github.Ptr(false), + Draft: github.Ptr(false), Base: &github.PullRequestBranch{ Ref: github.Ptr("develop"), }, @@ -171,6 +174,17 @@ func Test_UpdatePullRequest(t *testing.T) { State: github.Ptr("closed"), // State updated } + // Mock PR for when there are no updates but we still need a response + mockPRWithReviewers := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR"), + State: github.Ptr("open"), + RequestedReviewers: []*github.User{ + {Login: github.Ptr("reviewer1")}, + {Login: github.Ptr("reviewer2")}, + }, + } + tests := []struct { name string mockedClient *http.Client @@ -194,6 +208,10 @@ func Test_UpdatePullRequest(t *testing.T) { mockResponse(t, http.StatusOK, mockUpdatedPR), ), ), + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockUpdatedPR, + ), ), requestArgs: map[string]interface{}{ "owner": "owner", @@ -218,6 +236,10 @@ func Test_UpdatePullRequest(t *testing.T) { mockResponse(t, http.StatusOK, mockClosedPR), ), ), + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockClosedPR, + ), ), requestArgs: map[string]interface{}{ "owner": "owner", @@ -228,6 +250,53 @@ func Test_UpdatePullRequest(t *testing.T) { expectError: false, expectedPR: mockClosedPR, }, + { + name: "successful PR update with reviewers", + mockedClient: mock.NewMockedHTTPClient( + // Mock for RequestReviewers call, returning the PR with reviewers + mock.WithRequestMatch( + mock.PostReposPullsRequestedReviewersByOwnerByRepoByPullNumber, + mockPRWithReviewers, + ), + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockPRWithReviewers, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "reviewers": []interface{}{"reviewer1", "reviewer2"}, + }, + expectError: false, + expectedPR: mockPRWithReviewers, + }, + { + name: "successful PR update (title only)", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchReposPullsByOwnerByRepoByPullNumber, + expectRequestBody(t, map[string]interface{}{ + "title": "Updated Test PR Title", + }).andThen( + mockResponse(t, http.StatusOK, mockUpdatedPR), + ), + ), + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockUpdatedPR, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "title": "Updated Test PR Title", + }, + expectError: false, + expectedPR: mockUpdatedPR, + }, { name: "no update parameters provided", mockedClient: mock.NewMockedHTTPClient(), // No API call expected @@ -260,13 +329,34 @@ func Test_UpdatePullRequest(t *testing.T) { expectError: true, expectedErrMsg: "failed to update pull request", }, + { + name: "request reviewers fails", + mockedClient: mock.NewMockedHTTPClient( + // Then reviewer request fails + mock.WithRequestMatchHandler( + mock.PostReposPullsRequestedReviewersByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message": "Invalid reviewers"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "reviewers": []interface{}{"invalid-user"}, + }, + expectError: true, + expectedErrMsg: "failed to request reviewers", + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UpdatePullRequest(stubGetClientFn(client), translations.NullTranslationHelper) + _, handler := UpdatePullRequest(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -312,6 +402,208 @@ func Test_UpdatePullRequest(t *testing.T) { if tc.expectedPR.MaintainerCanModify != nil { assert.Equal(t, *tc.expectedPR.MaintainerCanModify, *returnedPR.MaintainerCanModify) } + + // Check reviewers if they exist in the expected PR + if len(tc.expectedPR.RequestedReviewers) > 0 { + assert.NotNil(t, returnedPR.RequestedReviewers) + assert.Equal(t, len(tc.expectedPR.RequestedReviewers), len(returnedPR.RequestedReviewers)) + + // Create maps of reviewer logins for easy comparison + expectedReviewers := make(map[string]bool) + for _, reviewer := range tc.expectedPR.RequestedReviewers { + expectedReviewers[*reviewer.Login] = true + } + + actualReviewers := make(map[string]bool) + for _, reviewer := range returnedPR.RequestedReviewers { + actualReviewers[*reviewer.Login] = true + } + + // Compare the maps + assert.Equal(t, expectedReviewers, actualReviewers) + } + }) + } +} + +func Test_UpdatePullRequest_Draft(t *testing.T) { + // Setup mock PR for success case + mockUpdatedPR := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR Title"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), + Body: github.Ptr("Test PR body."), + MaintainerCanModify: github.Ptr(false), + Draft: github.Ptr(false), // Updated to ready for review + Base: &github.PullRequestBranch{ + Ref: github.Ptr("main"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedPR *github.PullRequest + expectedErrMsg string + }{ + { + name: "successful draft update to ready for review", + mockedClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(42), + }, + githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": true, // Current state is draft + }, + }, + }), + ), + githubv4mock.NewMutationMatcher( + struct { + MarkPullRequestReadyForReview struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"markPullRequestReadyForReview(input: $input)"` + }{}, + githubv4.MarkPullRequestReadyForReviewInput{ + PullRequestID: "PR_kwDOA0xdyM50BPaO", + }, + nil, + githubv4mock.DataResponse(map[string]any{ + "markPullRequestReadyForReview": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": false, + }, + }, + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "draft": false, + }, + expectError: false, + expectedPR: mockUpdatedPR, + }, + { + name: "successful convert pull request to draft", + mockedClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(42), + }, + githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": false, // Current state is draft + }, + }, + }), + ), + githubv4mock.NewMutationMatcher( + struct { + ConvertPullRequestToDraft struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"convertPullRequestToDraft(input: $input)"` + }{}, + githubv4.ConvertPullRequestToDraftInput{ + PullRequestID: "PR_kwDOA0xdyM50BPaO", + }, + nil, + githubv4mock.DataResponse(map[string]any{ + "convertPullRequestToDraft": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": true, + }, + }, + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "draft": true, + }, + expectError: false, + expectedPR: mockUpdatedPR, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // For draft-only tests, we need to mock both GraphQL and the final REST GET call + restClient := github.NewClient(mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockUpdatedPR, + ), + )) + gqlClient := githubv4.NewClient(tc.mockedClient) + + _, handler := UpdatePullRequest(stubGetClientFn(restClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + + request := createMCPRequest(tc.requestArgs) + + result, err := handler(context.Background(), request) + + if tc.expectError || tc.expectedErrMsg != "" { + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + if tc.expectedErrMsg != "" { + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) + } + return + } + + require.NoError(t, err) + require.False(t, result.IsError) + + textContent := getTextResult(t, result) + + // Unmarshal and verify the successful result + var returnedPR github.PullRequest + err = json.Unmarshal([]byte(textContent.Text), &returnedPR) + require.NoError(t, err) + assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number) }) } } diff --git a/pkg/github/search.go b/pkg/github/search.go index b11bb3bbc..cbde0f7c6 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -16,14 +16,15 @@ import ( // SearchRepositories creates a tool to search for GitHub repositories. func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("search_repositories", - mcp.WithDescription(t("TOOL_SEARCH_REPOSITORIES_DESCRIPTION", "Search for GitHub repositories")), + mcp.WithDescription(t("TOOL_SEARCH_REPOSITORIES_DESCRIPTION", "Find GitHub repositories by name, description, readme, topics, or other metadata. Perfect for discovering projects, finding examples, or locating specific repositories across GitHub.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ Title: t("TOOL_SEARCH_REPOSITORIES_USER_TITLE", "Search repositories"), ReadOnlyHint: ToBoolPtr(true), }), mcp.WithString("query", mcp.Required(), - mcp.Description("Search query"), + mcp.Description("Repository search query. Examples: 'machine learning in:name stars:>1000 language:python', 'topic:react', 'user:facebook'. Supports advanced search syntax for precise filtering."), ), WithPagination(), ), @@ -78,26 +79,26 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF // SearchCode creates a tool to search for code across GitHub repositories. func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("search_code", - mcp.WithDescription(t("TOOL_SEARCH_CODE_DESCRIPTION", "Search for code across GitHub repositories")), + mcp.WithDescription(t("TOOL_SEARCH_CODE_DESCRIPTION", "Fast and precise code search across ALL GitHub repositories using GitHub's native search engine. Best for finding exact symbols, functions, classes, or specific code patterns.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ Title: t("TOOL_SEARCH_CODE_USER_TITLE", "Search code"), ReadOnlyHint: ToBoolPtr(true), }), - mcp.WithString("q", + mcp.WithString("query", mcp.Required(), - mcp.Description("Search query using GitHub code search syntax"), + mcp.Description("Search query using GitHub's powerful code search syntax. Examples: 'content:Skill language:Java org:github', 'NOT is:archived language:Python OR language:go', 'repo:github/github-mcp-server'. Supports exact matching, language filters, path filters, and more."), ), mcp.WithString("sort", mcp.Description("Sort field ('indexed' only)"), ), mcp.WithString("order", - mcp.Description("Sort order"), + mcp.Description("Sort order for results"), mcp.Enum("asc", "desc"), ), WithPagination(), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query, err := RequiredParam[string](request, "q") + query, err := RequiredParam[string](request, "query") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -258,17 +259,17 @@ func userOrOrgHandler(accountType string, getClient GetClientFn) server.ToolHand // SearchUsers creates a tool to search for GitHub users. func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("search_users", - mcp.WithDescription(t("TOOL_SEARCH_USERS_DESCRIPTION", "Search for GitHub users exclusively")), + mcp.WithDescription(t("TOOL_SEARCH_USERS_DESCRIPTION", "Find GitHub users by username, real name, or other profile information. Useful for locating developers, contributors, or team members.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ Title: t("TOOL_SEARCH_USERS_USER_TITLE", "Search users"), ReadOnlyHint: ToBoolPtr(true), }), mcp.WithString("query", mcp.Required(), - mcp.Description("Search query using GitHub users search syntax scoped to type:user"), + mcp.Description("User search query. Examples: 'john smith', 'location:seattle', 'followers:>100'. Search is automatically scoped to type:user."), ), mcp.WithString("sort", - mcp.Description("Sort field by category"), + mcp.Description("Sort users by number of followers or repositories, or when the person joined GitHub."), mcp.Enum("followers", "repositories", "joined"), ), mcp.WithString("order", @@ -282,14 +283,15 @@ func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (t // SearchOrgs creates a tool to search for GitHub organizations. func SearchOrgs(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("search_orgs", - mcp.WithDescription(t("TOOL_SEARCH_ORGS_DESCRIPTION", "Search for GitHub organizations exclusively")), + mcp.WithDescription(t("TOOL_SEARCH_ORGS_DESCRIPTION", "Find GitHub organizations by name, location, or other organization metadata. Ideal for discovering companies, open source foundations, or teams.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ Title: t("TOOL_SEARCH_ORGS_USER_TITLE", "Search organizations"), ReadOnlyHint: ToBoolPtr(true), }), mcp.WithString("query", mcp.Required(), - mcp.Description("Search query using GitHub organizations search syntax scoped to type:org"), + mcp.Description("Organization search query. Examples: 'microsoft', 'location:california', 'created:>=2025-01-01'. Search is automatically scoped to type:org."), ), mcp.WithString("sort", mcp.Description("Sort field by category"), diff --git a/pkg/github/search_test.go b/pkg/github/search_test.go index 21f7a0ca2..9ea8e71ec 100644 --- a/pkg/github/search_test.go +++ b/pkg/github/search_test.go @@ -173,12 +173,12 @@ func Test_SearchCode(t *testing.T) { assert.Equal(t, "search_code", tool.Name) assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "q") + assert.Contains(t, tool.InputSchema.Properties, "query") assert.Contains(t, tool.InputSchema.Properties, "sort") assert.Contains(t, tool.InputSchema.Properties, "order") assert.Contains(t, tool.InputSchema.Properties, "perPage") assert.Contains(t, tool.InputSchema.Properties, "page") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"q"}) + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"query"}) // Setup mock search results mockSearchResult := &github.CodeSearchResult{ @@ -227,7 +227,7 @@ func Test_SearchCode(t *testing.T) { ), ), requestArgs: map[string]interface{}{ - "q": "fmt.Println language:go", + "query": "fmt.Println language:go", "sort": "indexed", "order": "desc", "page": float64(1), @@ -251,7 +251,7 @@ func Test_SearchCode(t *testing.T) { ), ), requestArgs: map[string]interface{}{ - "q": "fmt.Println language:go", + "query": "fmt.Println language:go", }, expectError: false, expectedResult: mockSearchResult, @@ -268,7 +268,7 @@ func Test_SearchCode(t *testing.T) { ), ), requestArgs: map[string]interface{}{ - "q": "invalid:query", + "query": "invalid:query", }, expectError: true, expectedErrMsg: "failed to search code", diff --git a/pkg/github/tools.go b/pkg/github/tools.go index e01b7cc40..7fb1d39c0 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -63,7 +63,10 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG toolsets.NewServerTool(AddSubIssue(getClient, t)), toolsets.NewServerTool(RemoveSubIssue(getClient, t)), toolsets.NewServerTool(ReprioritizeSubIssue(getClient, t)), - ).AddPrompts(toolsets.NewServerPrompt(AssignCodingAgentPrompt(t))) + ).AddPrompts( + toolsets.NewServerPrompt(AssignCodingAgentPrompt(t)), + toolsets.NewServerPrompt(IssueToFixWorkflowPrompt(t)), + ) users := toolsets.NewToolset("users", "GitHub User related tools"). AddReadTools( toolsets.NewServerTool(SearchUsers(getClient, t)), @@ -87,7 +90,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG toolsets.NewServerTool(MergePullRequest(getClient, t)), toolsets.NewServerTool(UpdatePullRequestBranch(getClient, t)), toolsets.NewServerTool(CreatePullRequest(getClient, t)), - toolsets.NewServerTool(UpdatePullRequest(getClient, t)), + toolsets.NewServerTool(UpdatePullRequest(getClient, getGQLClient, t)), toolsets.NewServerTool(RequestCopilotReview(getClient, t)), // Reviews @@ -161,6 +164,15 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG toolsets.NewServerTool(GetMe(getClient, t)), ) + gists := toolsets.NewToolset("gists", "GitHub Gist related tools"). + AddReadTools( + toolsets.NewServerTool(ListGists(getClient, t)), + ). + AddWriteTools( + toolsets.NewServerTool(CreateGist(getClient, t)), + toolsets.NewServerTool(UpdateGist(getClient, t)), + ) + // Add toolsets to the group tsg.AddToolset(contextTools) tsg.AddToolset(repos) @@ -175,6 +187,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG tsg.AddToolset(notifications) tsg.AddToolset(experiments) tsg.AddToolset(discussions) + tsg.AddToolset(gists) return tsg } diff --git a/pkg/github/workflow_prompts.go b/pkg/github/workflow_prompts.go new file mode 100644 index 000000000..8a9545a42 --- /dev/null +++ b/pkg/github/workflow_prompts.go @@ -0,0 +1,77 @@ +package github + +import ( + "context" + "fmt" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// IssueToFixWorkflowPrompt provides a guided workflow for creating an issue and then generating a PR to fix it +func IssueToFixWorkflowPrompt(t translations.TranslationHelperFunc) (tool mcp.Prompt, handler server.PromptHandlerFunc) { + return mcp.NewPrompt("IssueToFixWorkflow", + mcp.WithPromptDescription(t("PROMPT_ISSUE_TO_FIX_WORKFLOW_DESCRIPTION", "Create an issue for a problem and then generate a pull request to fix it")), + mcp.WithArgument("owner", mcp.ArgumentDescription("Repository owner"), mcp.RequiredArgument()), + mcp.WithArgument("repo", mcp.ArgumentDescription("Repository name"), mcp.RequiredArgument()), + mcp.WithArgument("title", mcp.ArgumentDescription("Issue title"), mcp.RequiredArgument()), + mcp.WithArgument("description", mcp.ArgumentDescription("Issue description"), mcp.RequiredArgument()), + mcp.WithArgument("labels", mcp.ArgumentDescription("Comma-separated list of labels to apply (optional)")), + mcp.WithArgument("assignees", mcp.ArgumentDescription("Comma-separated list of assignees (optional)")), + ), func(_ context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + owner := request.Params.Arguments["owner"] + repo := request.Params.Arguments["repo"] + title := request.Params.Arguments["title"] + description := request.Params.Arguments["description"] + + labels := "" + if l, exists := request.Params.Arguments["labels"]; exists { + labels = fmt.Sprintf("%v", l) + } + + assignees := "" + if a, exists := request.Params.Arguments["assignees"]; exists { + assignees = fmt.Sprintf("%v", a) + } + + messages := []mcp.PromptMessage{ + { + Role: "system", + Content: mcp.NewTextContent("You are a development workflow assistant helping to create GitHub issues and generate corresponding pull requests to fix them. You should: 1) Create a well-structured issue with clear problem description, 2) Assign it to Copilot coding agent to generate a solution, and 3) Monitor the PR creation process."), + }, + { + Role: "user", + Content: mcp.NewTextContent(fmt.Sprintf("I need to create an issue titled '%s' in %s/%s and then have a PR generated to fix it. The issue description is: %s%s%s", + title, owner, repo, description, + func() string { + if labels != "" { + return fmt.Sprintf("\n\nLabels to apply: %s", labels) + } + return "" + }(), + func() string { + if assignees != "" { + return fmt.Sprintf("\nAssignees: %s", assignees) + } + return "" + }())), + }, + { + Role: "assistant", + Content: mcp.NewTextContent(fmt.Sprintf("I'll help you create the issue '%s' in %s/%s and then coordinate with Copilot to generate a fix. Let me start by creating the issue with the provided details.", title, owner, repo)), + }, + { + Role: "user", + Content: mcp.NewTextContent("Perfect! Please:\n1. Create the issue with the title, description, labels, and assignees\n2. Once created, assign it to Copilot coding agent to generate a solution\n3. Monitor the process and let me know when the PR is ready for review"), + }, + { + Role: "assistant", + Content: mcp.NewTextContent("Excellent plan! Here's what I'll do:\n\n1. ✅ Create the issue with all specified details\n2. 🤖 Assign to Copilot coding agent for automated fix\n3. 📋 Monitor progress and notify when PR is created\n4. 🔍 Provide PR details for your review\n\nLet me start by creating the issue."), + }, + } + return &mcp.GetPromptResult{ + Messages: messages, + }, nil + } +} pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy