diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 8a7a8af4a..732f20ab1 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -8,7 +8,6 @@ import ( "io" "net/http" "net/url" - "strconv" "strings" ghErrors "github.com/github/github-mcp-server/pkg/errors" @@ -495,33 +494,18 @@ func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t t return mcp.NewToolResultError(err.Error()), nil } - rawOpts := &raw.ContentOpts{} - - if strings.HasPrefix(ref, "refs/pull/") { - prNumber := strings.TrimSuffix(strings.TrimPrefix(ref, "refs/pull/"), "/head") - if len(prNumber) > 0 { - // fetch the PR from the API to get the latest commit and use SHA - githubClient, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - prNum, err := strconv.Atoi(prNumber) - if err != nil { - return nil, fmt.Errorf("invalid pull request number: %w", err) - } - pr, _, err := githubClient.PullRequests.Get(ctx, owner, repo, prNum) - if err != nil { - return nil, fmt.Errorf("failed to get pull request: %w", err) - } - sha = pr.GetHead().GetSHA() - ref = "" - } + client, err := getClient(ctx) + if err != nil { + return mcp.NewToolResultError("failed to get GitHub client"), nil } - rawOpts.SHA = sha - rawOpts.Ref = ref + rawOpts, err := resolveGitReference(ctx, client, owner, repo, ref, sha) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to resolve git reference: %s", err)), nil + } - // If the path is (most likely) not to be a directory, we will first try to get the raw content from the GitHub raw content API. + // If the path is (most likely) not to be a directory, we will + // first try to get the raw content from the GitHub raw content API. if path != "" && !strings.HasSuffix(path, "/") { rawClient, err := getRawClient(ctx) @@ -580,36 +564,51 @@ func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t t } } - client, err := getClient(ctx) - if err != nil { - return mcp.NewToolResultError("failed to get GitHub client"), nil - } - - if sha != "" { - ref = sha + if rawOpts.SHA != "" { + ref = rawOpts.SHA } if strings.HasSuffix(path, "/") { opts := &github.RepositoryContentGetOptions{Ref: ref} _, dirContent, resp, err := client.Repositories.GetContents(ctx, owner, repo, path, opts) - if err != nil { - return mcp.NewToolResultError("failed to get file contents"), nil - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != 200 { - body, err := io.ReadAll(resp.Body) + if err == nil && resp.StatusCode == http.StatusOK { + defer func() { _ = resp.Body.Close() }() + r, err := json.Marshal(dirContent) if err != nil { - return mcp.NewToolResultError("failed to read response body"), nil + return mcp.NewToolResultError("failed to marshal response"), nil } - return mcp.NewToolResultError(fmt.Sprintf("failed to get file contents: %s", string(body))), nil + return mcp.NewToolResultText(string(r)), nil } + } + + // The path does not point to a file or directory. + // Instead let's try to find it in the Git Tree by matching the end of the path. + + // Step 1: Get Git Tree recursively + tree, resp, err := client.Git.GetTree(ctx, owner, repo, ref, true) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get git tree", + resp, + err, + ), nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(dirContent) + // Step 2: Filter tree for matching paths + const maxMatchingFiles = 3 + matchingFiles := filterPaths(tree.Entries, path, maxMatchingFiles) + if len(matchingFiles) > 0 { + matchingFilesJSON, err := json.Marshal(matchingFiles) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal matching files: %s", err)), nil + } + resolvedRefs, err := json.Marshal(rawOpts) if err != nil { - return mcp.NewToolResultError("failed to marshal response"), nil + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal resolved refs: %s", err)), nil } - return mcp.NewToolResultText(string(r)), nil + return mcp.NewToolResultText(fmt.Sprintf("Path did not point to a file or directory, but resolved git ref to %s with possible path matches: %s", resolvedRefs, matchingFilesJSON)), nil } + return mcp.NewToolResultError("Failed to get file contents. The path does not point to a file or directory, or the file does not exist in the repository."), nil } } @@ -1293,3 +1292,74 @@ func GetTag(getClient GetClientFn, t translations.TranslationHelperFunc) (tool m return mcp.NewToolResultText(string(r)), nil } } + +// filterPaths filters the entries in a GitHub tree to find paths that +// match the given suffix. +// maxResults limits the number of results returned to first maxResults entries, +// a maxResults of -1 means no limit. +// It returns a slice of strings containing the matching paths. +// Directories are returned with a trailing slash. +func filterPaths(entries []*github.TreeEntry, path string, maxResults int) []string { + // Remove trailing slash for matching purposes, but flag whether we + // only want directories. + dirOnly := false + if strings.HasSuffix(path, "/") { + dirOnly = true + path = strings.TrimSuffix(path, "/") + } + + matchedPaths := []string{} + for _, entry := range entries { + if len(matchedPaths) == maxResults { + break // Limit the number of results to maxResults + } + if dirOnly && entry.GetType() != "tree" { + continue // Skip non-directory entries if dirOnly is true + } + entryPath := entry.GetPath() + if entryPath == "" { + continue // Skip empty paths + } + if strings.HasSuffix(entryPath, path) { + if entry.GetType() == "tree" { + entryPath += "/" // Return directories with a trailing slash + } + matchedPaths = append(matchedPaths, entryPath) + } + } + return matchedPaths +} + +// resolveGitReference resolves git references with the following logic: +// 1. If SHA is provided, it takes precedence +// 2. If neither is provided, use the default branch as ref +// 3. Get commit SHA from the ref +// Refs can look like `refs/tags/{tag}`, `refs/heads/{branch}` or `refs/pull/{pr_number}/head` +// The function returns the resolved ref, commit SHA and any error. +func resolveGitReference(ctx context.Context, githubClient *github.Client, owner, repo, ref, sha string) (*raw.ContentOpts, error) { + // 1. If SHA is provided, use it directly + if sha != "" { + return &raw.ContentOpts{Ref: "", SHA: sha}, nil + } + + // 2. If neither provided, use the default branch as ref + if ref == "" { + repoInfo, resp, err := githubClient.Repositories.Get(ctx, owner, repo) + if err != nil { + _, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get repository info", resp, err) + return nil, fmt.Errorf("failed to get repository info: %w", err) + } + ref = fmt.Sprintf("refs/heads/%s", repoInfo.GetDefaultBranch()) + } + + // 3. Get the SHA from the ref + reference, resp, err := githubClient.Git.GetRef(ctx, owner, repo, ref) + if err != nil { + _, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get reference", resp, err) + return nil, fmt.Errorf("failed to get reference: %w", err) + } + sha = reference.GetObject().GetSHA() + + // Use provided ref, or it will be empty which defaults to the default branch + return &raw.ContentOpts{Ref: ref, SHA: sha}, nil +} diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go index b621cec43..0b9c5d9f9 100644 --- a/pkg/github/repositories_test.go +++ b/pkg/github/repositories_test.go @@ -69,6 +69,13 @@ func Test_GetFileContents(t *testing.T) { { name: "successful text content fetch", mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposGitRefByOwnerByRepoByRef, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ref": "refs/heads/main", "object": {"sha": ""}}`)) + }), + ), mock.WithRequestMatchHandler( raw.GetRawReposContentsByOwnerByRepoByBranchByPath, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -93,6 +100,13 @@ func Test_GetFileContents(t *testing.T) { { name: "successful file blob content fetch", mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposGitRefByOwnerByRepoByRef, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ref": "refs/heads/main", "object": {"sha": ""}}`)) + }), + ), mock.WithRequestMatchHandler( raw.GetRawReposContentsByOwnerByRepoByBranchByPath, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -117,6 +131,20 @@ func Test_GetFileContents(t *testing.T) { { name: "successful directory content fetch", mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"name": "repo", "default_branch": "main"}`)) + }), + ), + mock.WithRequestMatchHandler( + mock.GetReposGitRefByOwnerByRepoByRef, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ref": "refs/heads/main", "object": {"sha": ""}}`)) + }), + ), mock.WithRequestMatchHandler( mock.GetReposContentsByOwnerByRepoByPath, expectQueryParams(t, map[string]string{}).andThen( @@ -143,6 +171,13 @@ func Test_GetFileContents(t *testing.T) { { name: "content fetch fails", mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposGitRefByOwnerByRepoByRef, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ref": "refs/heads/main", "object": {"sha": ""}}`)) + }), + ), mock.WithRequestMatchHandler( mock.GetReposContentsByOwnerByRepoByPath, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -203,7 +238,7 @@ func Test_GetFileContents(t *testing.T) { textContent := getTextResult(t, result) var returnedContents []*github.RepositoryContent err = json.Unmarshal([]byte(textContent.Text), &returnedContents) - require.NoError(t, err) + require.NoError(t, err, "Failed to unmarshal directory content result: %v", textContent.Text) assert.Len(t, returnedContents, len(expected)) for i, content := range returnedContents { assert.Equal(t, *expected[i].Name, *content.Name) @@ -2049,3 +2084,170 @@ func Test_GetTag(t *testing.T) { }) } } + +func Test_filterPaths(t *testing.T) { + tests := []struct { + name string + tree []*github.TreeEntry + path string + maxResults int + expected []string + }{ + { + name: "file name", + tree: []*github.TreeEntry{ + {Path: github.Ptr("folder/foo.txt"), Type: github.Ptr("blob")}, + {Path: github.Ptr("bar.txt"), Type: github.Ptr("blob")}, + {Path: github.Ptr("nested/folder/foo.txt"), Type: github.Ptr("blob")}, + {Path: github.Ptr("nested/folder/baz.txt"), Type: github.Ptr("blob")}, + }, + path: "foo.txt", + maxResults: -1, + expected: []string{"folder/foo.txt", "nested/folder/foo.txt"}, + }, + { + name: "dir name", + tree: []*github.TreeEntry{ + {Path: github.Ptr("folder"), Type: github.Ptr("tree")}, + {Path: github.Ptr("bar.txt"), Type: github.Ptr("blob")}, + {Path: github.Ptr("nested/folder"), Type: github.Ptr("tree")}, + {Path: github.Ptr("nested/folder/baz.txt"), Type: github.Ptr("blob")}, + }, + path: "folder/", + maxResults: -1, + expected: []string{"folder/", "nested/folder/"}, + }, + { + name: "dir and file match", + tree: []*github.TreeEntry{ + {Path: github.Ptr("name"), Type: github.Ptr("tree")}, + {Path: github.Ptr("name"), Type: github.Ptr("blob")}, + }, + path: "name", // No trailing slash can match both files and directories + maxResults: -1, + expected: []string{"name/", "name"}, + }, + { + name: "dir only match", + tree: []*github.TreeEntry{ + {Path: github.Ptr("name"), Type: github.Ptr("tree")}, + {Path: github.Ptr("name"), Type: github.Ptr("blob")}, + }, + path: "name/", // Trialing slash ensures only directories are matched + maxResults: -1, + expected: []string{"name/"}, + }, + { + name: "max results limit 2", + tree: []*github.TreeEntry{ + {Path: github.Ptr("folder"), Type: github.Ptr("tree")}, + {Path: github.Ptr("nested/folder"), Type: github.Ptr("tree")}, + {Path: github.Ptr("nested/nested/folder"), Type: github.Ptr("tree")}, + }, + path: "folder/", + maxResults: 2, + expected: []string{"folder/", "nested/folder/"}, + }, + { + name: "max results limit 1", + tree: []*github.TreeEntry{ + {Path: github.Ptr("folder"), Type: github.Ptr("tree")}, + {Path: github.Ptr("nested/folder"), Type: github.Ptr("tree")}, + {Path: github.Ptr("nested/nested/folder"), Type: github.Ptr("tree")}, + }, + path: "folder/", + maxResults: 1, + expected: []string{"folder/"}, + }, + { + name: "max results limit 0", + tree: []*github.TreeEntry{ + {Path: github.Ptr("folder"), Type: github.Ptr("tree")}, + {Path: github.Ptr("nested/folder"), Type: github.Ptr("tree")}, + {Path: github.Ptr("nested/nested/folder"), Type: github.Ptr("tree")}, + }, + path: "folder/", + maxResults: 0, + expected: []string{}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := filterPaths(tc.tree, tc.path, tc.maxResults) + assert.Equal(t, tc.expected, result) + }) + } +} + +func Test_resolveGitReference(t *testing.T) { + ctx := context.Background() + owner := "owner" + repo := "repo" + mockedClient := mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"name": "repo", "default_branch": "main"}`)) + }), + ), + mock.WithRequestMatchHandler( + mock.GetReposGitRefByOwnerByRepoByRef, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ref": "refs/heads/main", "object": {"sha": "123sha456"}}`)) + }), + ), + ) + + tests := []struct { + name string + ref string + sha string + expectedOutput *raw.ContentOpts + }{ + { + name: "sha takes precedence over ref", + ref: "refs/heads/main", + sha: "123sha456", + expectedOutput: &raw.ContentOpts{ + SHA: "123sha456", + }, + }, + { + name: "use default branch if ref and sha both empty", + ref: "", + sha: "", + expectedOutput: &raw.ContentOpts{ + Ref: "refs/heads/main", + SHA: "123sha456", + }, + }, + { + name: "get SHA from ref", + ref: "refs/heads/main", + sha: "", + expectedOutput: &raw.ContentOpts{ + Ref: "refs/heads/main", + SHA: "123sha456", + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(mockedClient) + opts, err := resolveGitReference(ctx, client, owner, repo, tc.ref, tc.sha) + require.NoError(t, err) + + if tc.expectedOutput.SHA != "" { + assert.Equal(t, tc.expectedOutput.SHA, opts.SHA) + } + if tc.expectedOutput.Ref != "" { + assert.Equal(t, tc.expectedOutput.Ref, opts.Ref) + } + }) + } +}
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: