Skip to content

Commit 1235550

Browse files
johnstcnkylecarbs
andauthored
feat(codersdk): add toolsdk and replace existing mcp server tool impl (#17343)
- Refactors existing `mcp` package to use `kylecarbs/aisdk-go` and moves to `codersdk/toolsdk` package. - Updates existing MCP server implementation to use `codersdk/toolsdk` Co-authored-by: Kyle Carberry <kyle@coder.com>
1 parent 2c573dc commit 1235550

File tree

9 files changed

+1774
-1096
lines changed

9 files changed

+1774
-1096
lines changed

cli/exp_mcp.go

Lines changed: 65 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,19 @@ import (
66
"errors"
77
"os"
88
"path/filepath"
9+
"slices"
910
"strings"
1011

12+
"github.com/mark3labs/mcp-go/mcp"
1113
"github.com/mark3labs/mcp-go/server"
1214
"github.com/spf13/afero"
1315
"golang.org/x/xerrors"
1416

15-
"cdr.dev/slog"
16-
"cdr.dev/slog/sloggers/sloghuman"
1717
"github.com/coder/coder/v2/buildinfo"
1818
"github.com/coder/coder/v2/cli/cliui"
1919
"github.com/coder/coder/v2/codersdk"
2020
"github.com/coder/coder/v2/codersdk/agentsdk"
21-
codermcp "github.com/coder/coder/v2/mcp"
21+
"github.com/coder/coder/v2/codersdk/toolsdk"
2222
"github.com/coder/serpent"
2323
)
2424

@@ -365,6 +365,8 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct
365365
ctx, cancel := context.WithCancel(inv.Context())
366366
defer cancel()
367367

368+
fs := afero.NewOsFs()
369+
368370
me, err := client.User(ctx, codersdk.Me)
369371
if err != nil {
370372
cliui.Errorf(inv.Stderr, "Failed to log in to the Coder deployment.")
@@ -397,40 +399,36 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct
397399
server.WithInstructions(instructions),
398400
)
399401

400-
// Create a separate logger for the tools.
401-
toolLogger := slog.Make(sloghuman.Sink(invStderr))
402-
403-
toolDeps := codermcp.ToolDeps{
404-
Client: client,
405-
Logger: &toolLogger,
406-
AppStatusSlug: appStatusSlug,
407-
AgentClient: agentsdk.New(client.URL),
408-
}
409-
402+
// Create a new context for the tools with all relevant information.
403+
clientCtx := toolsdk.WithClient(ctx, client)
410404
// Get the workspace agent token from the environment.
411-
agentToken, ok := os.LookupEnv("CODER_AGENT_TOKEN")
412-
if ok && agentToken != "" {
413-
toolDeps.AgentClient.SetSessionToken(agentToken)
405+
if agentToken, err := getAgentToken(fs); err == nil && agentToken != "" {
406+
agentClient := agentsdk.New(client.URL)
407+
agentClient.SetSessionToken(agentToken)
408+
clientCtx = toolsdk.WithAgentClient(clientCtx, agentClient)
414409
} else {
415410
cliui.Warnf(inv.Stderr, "CODER_AGENT_TOKEN is not set, task reporting will not be available")
416411
}
417-
if appStatusSlug == "" {
412+
if appStatusSlug != "" {
418413
cliui.Warnf(inv.Stderr, "CODER_MCP_APP_STATUS_SLUG is not set, task reporting will not be available.")
414+
} else {
415+
clientCtx = toolsdk.WithWorkspaceAppStatusSlug(clientCtx, appStatusSlug)
419416
}
420417

421418
// Register tools based on the allowlist (if specified)
422-
reg := codermcp.AllTools()
423-
if len(allowedTools) > 0 {
424-
reg = reg.WithOnlyAllowed(allowedTools...)
419+
for _, tool := range toolsdk.All {
420+
if len(allowedTools) == 0 || slices.ContainsFunc(allowedTools, func(t string) bool {
421+
return t == tool.Tool.Name
422+
}) {
423+
mcpSrv.AddTools(mcpFromSDK(tool))
424+
}
425425
}
426426

427-
reg.Register(mcpSrv, toolDeps)
428-
429427
srv := server.NewStdioServer(mcpSrv)
430428
done := make(chan error)
431429
go func() {
432430
defer close(done)
433-
srvErr := srv.Listen(ctx, invStdin, invStdout)
431+
srvErr := srv.Listen(clientCtx, invStdin, invStdout)
434432
done <- srvErr
435433
}()
436434

@@ -527,8 +525,8 @@ func configureClaude(fs afero.Fs, cfg ClaudeConfig) error {
527525
if !ok {
528526
mcpServers = make(map[string]any)
529527
}
530-
for name, mcp := range cfg.MCPServers {
531-
mcpServers[name] = mcp
528+
for name, cfgmcp := range cfg.MCPServers {
529+
mcpServers[name] = cfgmcp
532530
}
533531
project["mcpServers"] = mcpServers
534532
// Prevents Claude from asking the user to complete the project onboarding.
@@ -674,7 +672,7 @@ func indexOf(s, substr string) int {
674672

675673
func getAgentToken(fs afero.Fs) (string, error) {
676674
token, ok := os.LookupEnv("CODER_AGENT_TOKEN")
677-
if ok {
675+
if ok && token != "" {
678676
return token, nil
679677
}
680678
tokenFile, ok := os.LookupEnv("CODER_AGENT_TOKEN_FILE")
@@ -687,3 +685,44 @@ func getAgentToken(fs afero.Fs) (string, error) {
687685
}
688686
return string(bs), nil
689687
}
688+
689+
// mcpFromSDK adapts a toolsdk.Tool to go-mcp's server.ServerTool.
690+
// It assumes that the tool responds with a valid JSON object.
691+
func mcpFromSDK(sdkTool toolsdk.Tool[any]) server.ServerTool {
692+
return server.ServerTool{
693+
Tool: mcp.Tool{
694+
Name: sdkTool.Tool.Name,
695+
Description: sdkTool.Description,
696+
InputSchema: mcp.ToolInputSchema{
697+
Type: "object", // Default of mcp.NewTool()
698+
Properties: sdkTool.Schema.Properties,
699+
Required: sdkTool.Schema.Required,
700+
},
701+
},
702+
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
703+
result, err := sdkTool.Handler(ctx, request.Params.Arguments)
704+
if err != nil {
705+
return nil, err
706+
}
707+
var sb strings.Builder
708+
if err := json.NewEncoder(&sb).Encode(result); err == nil {
709+
return &mcp.CallToolResult{
710+
Content: []mcp.Content{
711+
mcp.NewTextContent(sb.String()),
712+
},
713+
}, nil
714+
}
715+
// If the result is not JSON, return it as a string.
716+
// This is a fallback for tools that return non-JSON data.
717+
resultStr, ok := result.(string)
718+
if !ok {
719+
return nil, xerrors.Errorf("tool call result is neither valid JSON or a string, got: %T", result)
720+
}
721+
return &mcp.CallToolResult{
722+
Content: []mcp.Content{
723+
mcp.NewTextContent(resultStr),
724+
},
725+
}, nil
726+
},
727+
}
728+
}

cli/exp_mcp_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,13 @@ func TestExpMcpServer(t *testing.T) {
3939
_ = coderdtest.CreateFirstUser(t, client)
4040

4141
// Given: we run the exp mcp command with allowed tools set
42-
inv, root := clitest.New(t, "exp", "mcp", "server", "--allowed-tools=coder_whoami,coder_list_templates")
42+
inv, root := clitest.New(t, "exp", "mcp", "server", "--allowed-tools=coder_get_authenticated_user")
4343
inv = inv.WithContext(cancelCtx)
4444

4545
pty := ptytest.New(t)
4646
inv.Stdin = pty.Input()
4747
inv.Stdout = pty.Output()
48+
// nolint: gocritic // not the focus of this test
4849
clitest.SetupConfig(t, client, root)
4950

5051
cmdDone := make(chan struct{})
@@ -73,13 +74,13 @@ func TestExpMcpServer(t *testing.T) {
7374
}
7475
err := json.Unmarshal([]byte(output), &toolsResponse)
7576
require.NoError(t, err)
76-
require.Len(t, toolsResponse.Result.Tools, 2, "should have exactly 2 tools")
77+
require.Len(t, toolsResponse.Result.Tools, 1, "should have exactly 1 tool")
7778
foundTools := make([]string, 0, 2)
7879
for _, tool := range toolsResponse.Result.Tools {
7980
foundTools = append(foundTools, tool.Name)
8081
}
8182
slices.Sort(foundTools)
82-
require.Equal(t, []string{"coder_list_templates", "coder_whoami"}, foundTools)
83+
require.Equal(t, []string{"coder_get_authenticated_user"}, foundTools)
8384
})
8485

8586
t.Run("OK", func(t *testing.T) {

coderd/database/dbfake/dbfake.go

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -287,23 +287,25 @@ type TemplateVersionResponse struct {
287287
}
288288

289289
type TemplateVersionBuilder struct {
290-
t testing.TB
291-
db database.Store
292-
seed database.TemplateVersion
293-
fileID uuid.UUID
294-
ps pubsub.Pubsub
295-
resources []*sdkproto.Resource
296-
params []database.TemplateVersionParameter
297-
promote bool
290+
t testing.TB
291+
db database.Store
292+
seed database.TemplateVersion
293+
fileID uuid.UUID
294+
ps pubsub.Pubsub
295+
resources []*sdkproto.Resource
296+
params []database.TemplateVersionParameter
297+
promote bool
298+
autoCreateTemplate bool
298299
}
299300

300301
// TemplateVersion generates a template version and optionally a parent
301302
// template if no template ID is set on the seed.
302303
func TemplateVersion(t testing.TB, db database.Store) TemplateVersionBuilder {
303304
return TemplateVersionBuilder{
304-
t: t,
305-
db: db,
306-
promote: true,
305+
t: t,
306+
db: db,
307+
promote: true,
308+
autoCreateTemplate: true,
307309
}
308310
}
309311

@@ -337,6 +339,13 @@ func (t TemplateVersionBuilder) Params(ps ...database.TemplateVersionParameter)
337339
return t
338340
}
339341

342+
func (t TemplateVersionBuilder) SkipCreateTemplate() TemplateVersionBuilder {
343+
// nolint: revive // returns modified struct
344+
t.autoCreateTemplate = false
345+
t.promote = false
346+
return t
347+
}
348+
340349
func (t TemplateVersionBuilder) Do() TemplateVersionResponse {
341350
t.t.Helper()
342351

@@ -347,7 +356,7 @@ func (t TemplateVersionBuilder) Do() TemplateVersionResponse {
347356
t.fileID = takeFirst(t.fileID, uuid.New())
348357

349358
var resp TemplateVersionResponse
350-
if t.seed.TemplateID.UUID == uuid.Nil {
359+
if t.seed.TemplateID.UUID == uuid.Nil && t.autoCreateTemplate {
351360
resp.Template = dbgen.Template(t.t, t.db, database.Template{
352361
ActiveVersionID: t.seed.ID,
353362
OrganizationID: t.seed.OrganizationID,
@@ -360,16 +369,14 @@ func (t TemplateVersionBuilder) Do() TemplateVersionResponse {
360369
}
361370

362371
version := dbgen.TemplateVersion(t.t, t.db, t.seed)
363-
364-
// Always make this version the active version. We can easily
365-
// add a conditional to the builder to opt out of this when
366-
// necessary.
367-
err := t.db.UpdateTemplateActiveVersionByID(ownerCtx, database.UpdateTemplateActiveVersionByIDParams{
368-
ID: t.seed.TemplateID.UUID,
369-
ActiveVersionID: t.seed.ID,
370-
UpdatedAt: dbtime.Now(),
371-
})
372-
require.NoError(t.t, err)
372+
if t.promote {
373+
err := t.db.UpdateTemplateActiveVersionByID(ownerCtx, database.UpdateTemplateActiveVersionByIDParams{
374+
ID: t.seed.TemplateID.UUID,
375+
ActiveVersionID: t.seed.ID,
376+
UpdatedAt: dbtime.Now(),
377+
})
378+
require.NoError(t.t, err)
379+
}
373380

374381
payload, err := json.Marshal(provisionerdserver.TemplateVersionImportJob{
375382
TemplateVersionID: t.seed.ID,

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy