Skip to content

Commit 9edd5f7

Browse files
committed
fix(codersdk/toolsdk): address type incompatibility issues
1 parent c1057d9 commit 9edd5f7

File tree

4 files changed

+162
-119
lines changed

4 files changed

+162
-119
lines changed

cli/exp_mcp.go

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cli
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/json"
67
"errors"
@@ -697,7 +698,7 @@ func getAgentToken(fs afero.Fs) (string, error) {
697698

698699
// mcpFromSDK adapts a toolsdk.Tool to go-mcp's server.ServerTool.
699700
// It assumes that the tool responds with a valid JSON object.
700-
func mcpFromSDK(sdkTool toolsdk.Tool[any, any], tb toolsdk.Deps) server.ServerTool {
701+
func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool {
701702
// NOTE: some clients will silently refuse to use tools if there is an issue
702703
// with the tool's schema or configuration.
703704
if sdkTool.Schema.Properties == nil {
@@ -714,27 +715,17 @@ func mcpFromSDK(sdkTool toolsdk.Tool[any, any], tb toolsdk.Deps) server.ServerTo
714715
},
715716
},
716717
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
717-
result, err := sdkTool.Handler(ctx, tb, request.Params.Arguments)
718+
var buf bytes.Buffer
719+
if err := json.NewEncoder(&buf).Encode(request.Params.Arguments); err != nil {
720+
return nil, xerrors.Errorf("failed to encode request arguments: %w", err)
721+
}
722+
result, err := sdkTool.Handler(ctx, tb, buf.Bytes())
718723
if err != nil {
719724
return nil, err
720725
}
721-
var sb strings.Builder
722-
if err := json.NewEncoder(&sb).Encode(result); err == nil {
723-
return &mcp.CallToolResult{
724-
Content: []mcp.Content{
725-
mcp.NewTextContent(sb.String()),
726-
},
727-
}, nil
728-
}
729-
// If the result is not JSON, return it as a string.
730-
// This is a fallback for tools that return non-JSON data.
731-
resultStr, ok := result.(string)
732-
if !ok {
733-
return nil, xerrors.Errorf("tool call result is neither valid JSON or a string, got: %T", result)
734-
}
735726
return &mcp.CallToolResult{
736727
Content: []mcp.Content{
737-
mcp.NewTextContent(resultStr),
728+
mcp.NewTextContent(string(result)),
738729
},
739730
}, nil
740731
},

cli/exp_mcp_test.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,16 @@ func TestExpMcpServer(t *testing.T) {
3131
t.Parallel()
3232

3333
ctx := testutil.Context(t, testutil.WaitShort)
34+
cmdDone := make(chan struct{})
3435
cancelCtx, cancel := context.WithCancel(ctx)
35-
t.Cleanup(cancel)
36+
t.Cleanup(func() {
37+
cancel()
38+
<-cmdDone
39+
})
3640

3741
// Given: a running coder deployment
3842
client := coderdtest.New(t, nil)
39-
_ = coderdtest.CreateFirstUser(t, client)
43+
owner := coderdtest.CreateFirstUser(t, client)
4044

4145
// Given: we run the exp mcp command with allowed tools set
4246
inv, root := clitest.New(t, "exp", "mcp", "server", "--allowed-tools=coder_get_authenticated_user")
@@ -48,7 +52,6 @@ func TestExpMcpServer(t *testing.T) {
4852
// nolint: gocritic // not the focus of this test
4953
clitest.SetupConfig(t, client, root)
5054

51-
cmdDone := make(chan struct{})
5255
go func() {
5356
defer close(cmdDone)
5457
err := inv.Run()
@@ -61,9 +64,6 @@ func TestExpMcpServer(t *testing.T) {
6164
_ = pty.ReadLine(ctx) // ignore echoed output
6265
output := pty.ReadLine(ctx)
6366

64-
cancel()
65-
<-cmdDone
66-
6767
// Then: we should only see the allowed tools in the response
6868
var toolsResponse struct {
6969
Result struct {
@@ -81,6 +81,18 @@ func TestExpMcpServer(t *testing.T) {
8181
}
8282
slices.Sort(foundTools)
8383
require.Equal(t, []string{"coder_get_authenticated_user"}, foundTools)
84+
85+
// Call the tool and ensure it works.
86+
toolPayload := `{"jsonrpc":"2.0","id":3,"method":"tools/call", "params": {"name": "coder_get_authenticated_user", "arguments": {}}}`
87+
pty.WriteLine(toolPayload)
88+
_ = pty.ReadLine(ctx) // ignore echoed output
89+
output = pty.ReadLine(ctx)
90+
require.NotEmpty(t, output, "should have received a response from the tool")
91+
// Ensure it's valid JSON
92+
_, err = json.Marshal(output)
93+
require.NoError(t, err, "should have received a valid JSON response from the tool")
94+
// Ensure the tool returns the expected user
95+
require.Contains(t, output, owner.UserID.String(), "should have received the expected user ID")
8496
})
8597

8698
t.Run("OK", func(t *testing.T) {

codersdk/toolsdk/toolsdk.go

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package toolsdk
22

33
import (
44
"archive/tar"
5+
"bytes"
56
"context"
7+
"encoding/json"
68
"io"
79

810
"github.com/google/uuid"
@@ -20,28 +22,49 @@ type Deps struct {
2022
AppStatusSlug string
2123
}
2224

23-
// HandlerFunc is a function that handles a tool call.
25+
// HandlerFunc is a typed function that handles a tool call.
2426
type HandlerFunc[Arg, Ret any] func(context.Context, Deps, Arg) (Ret, error)
2527

28+
// Tool consists of an aisdk.Tool and a corresponding typed handler function.
2629
type Tool[Arg, Ret any] struct {
2730
aisdk.Tool
2831
Handler HandlerFunc[Arg, Ret]
2932
}
3033

31-
// Generic returns a type-erased version of the Tool.
32-
func (t Tool[Arg, Ret]) Generic() Tool[any, any] {
33-
return Tool[any, any]{
34+
// Generic returns a type-erased version of a TypedTool where the arguments and
35+
// return values are converted to/from json.RawMessage.
36+
// This allows the tool to be referenced without knowing the concrete arguments
37+
// or return values. The original TypedHandlerFunc is wrapped to handle type
38+
// conversion.
39+
func (t Tool[Arg, Ret]) Generic() GenericTool {
40+
return GenericTool{
3441
Tool: t.Tool,
35-
Handler: func(ctx context.Context, tb Deps, args any) (any, error) {
36-
typedArg, ok := args.(Arg)
37-
if !ok {
38-
return nil, xerrors.Errorf("developer error: invalid argument type for tool %s", t.Tool.Name)
42+
Handler: wrap(func(ctx context.Context, tb Deps, args json.RawMessage) (json.RawMessage, error) {
43+
var typedArgs Arg
44+
if err := json.Unmarshal(args, &typedArgs); err != nil {
45+
return nil, xerrors.Errorf("failed to unmarshal args: %w", err)
3946
}
40-
return t.Handler(ctx, tb, typedArg)
41-
},
47+
ret, err := t.Handler(ctx, tb, typedArgs)
48+
var buf bytes.Buffer
49+
if err := json.NewEncoder(&buf).Encode(ret); err != nil {
50+
return json.RawMessage{}, err
51+
}
52+
return buf.Bytes(), err
53+
}, WithCleanContext, WithRecover),
4254
}
4355
}
4456

57+
// GenericTool is a type-erased wrapper for GenericTool.
58+
// This allows referencing the tool without knowing the concrete argument or
59+
// return type. The Handler function allows calling the tool with known types.
60+
type GenericTool struct {
61+
aisdk.Tool
62+
Handler GenericHandlerFunc
63+
}
64+
65+
// GenericHandlerFunc is a function that handles a tool call.
66+
type GenericHandlerFunc func(context.Context, Deps, json.RawMessage) (json.RawMessage, error)
67+
4568
type NoArgs struct{}
4669

4770
type ReportTaskArgs struct {
@@ -114,8 +137,8 @@ type UploadTarFileArgs struct {
114137
}
115138

116139
// WithRecover wraps a HandlerFunc to recover from panics and return an error.
117-
func WithRecover[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] {
118-
return func(ctx context.Context, tb Deps, args Arg) (ret Ret, err error) {
140+
func WithRecover(h GenericHandlerFunc) GenericHandlerFunc {
141+
return func(ctx context.Context, tb Deps, args json.RawMessage) (ret json.RawMessage, err error) {
119142
defer func() {
120143
if r := recover(); r != nil {
121144
err = xerrors.Errorf("tool handler panic: %v", r)
@@ -129,8 +152,8 @@ func WithRecover[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] {
129152
// This ensures that no data is passed using context.Value.
130153
// If a deadline is set on the parent context, it will be passed to the child
131154
// context.
132-
func WithCleanContext[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] {
133-
return func(parent context.Context, tb Deps, args Arg) (ret Ret, err error) {
155+
func WithCleanContext(h GenericHandlerFunc) GenericHandlerFunc {
156+
return func(parent context.Context, tb Deps, args json.RawMessage) (ret json.RawMessage, err error) {
134157
child, childCancel := context.WithCancel(context.Background())
135158
defer childCancel()
136159
// Ensure that cancellation propagates from the parent context to the child context.
@@ -153,19 +176,18 @@ func WithCleanContext[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Re
153176
}
154177
}
155178

156-
// wrapAll wraps all provided tools with the given middleware function.
157-
func wrapAll(mw func(HandlerFunc[any, any]) HandlerFunc[any, any], tools ...Tool[any, any]) []Tool[any, any] {
158-
for i, t := range tools {
159-
t.Handler = mw(t.Handler)
160-
tools[i] = t
179+
// wrap wraps the provided GenericHandlerFunc with the provided middleware functions.
180+
func wrap(hf GenericHandlerFunc, mw ...func(GenericHandlerFunc) GenericHandlerFunc) GenericHandlerFunc {
181+
for _, m := range mw {
182+
hf = m(hf)
161183
}
162-
return tools
184+
return hf
163185
}
164186

165187
var (
166188
// All is a list of all tools that can be used in the Coder CLI.
167189
// When you add a new tool, be sure to include it here!
168-
All = wrapAll(WithCleanContext, wrapAll(WithRecover,
190+
All = []GenericTool{
169191
CreateTemplate.Generic(),
170192
CreateTemplateVersion.Generic(),
171193
CreateWorkspace.Generic(),
@@ -182,9 +204,9 @@ var (
182204
ReportTask.Generic(),
183205
UploadTarFile.Generic(),
184206
UpdateTemplateActiveVersion.Generic(),
185-
)...)
207+
}
186208

187-
ReportTask = Tool[ReportTaskArgs, string]{
209+
ReportTask = Tool[ReportTaskArgs, codersdk.Response]{
188210
Tool: aisdk.Tool{
189211
Name: "coder_report_task",
190212
Description: "Report progress on a user task in Coder.",
@@ -211,22 +233,24 @@ var (
211233
Required: []string{"summary", "link", "state"},
212234
},
213235
},
214-
Handler: func(ctx context.Context, tb Deps, args ReportTaskArgs) (string, error) {
236+
Handler: func(ctx context.Context, tb Deps, args ReportTaskArgs) (codersdk.Response, error) {
215237
if tb.AgentClient == nil {
216-
return "", xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set")
238+
return codersdk.Response{}, xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set")
217239
}
218240
if tb.AppStatusSlug == "" {
219-
return "", xerrors.New("workspace app status slug not found in toolbox")
241+
return codersdk.Response{}, xerrors.New("workspace app status slug not found in toolbox")
220242
}
221243
if err := tb.AgentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
222244
AppSlug: tb.AppStatusSlug,
223245
Message: args.Summary,
224246
URI: args.Link,
225247
State: codersdk.WorkspaceAppStatusState(args.State),
226248
}); err != nil {
227-
return "", err
249+
return codersdk.Response{}, err
228250
}
229-
return "Thanks for reporting!", nil
251+
return codersdk.Response{
252+
Message: "Thanks for reporting!",
253+
}, nil
230254
},
231255
}
232256

@@ -934,9 +958,13 @@ The file_id provided is a reference to a tar file you have uploaded containing t
934958
if err != nil {
935959
return codersdk.TemplateVersion{}, xerrors.Errorf("file_id must be a valid UUID: %w", err)
936960
}
937-
templateID, err := uuid.Parse(args.TemplateID)
938-
if err != nil {
939-
return codersdk.TemplateVersion{}, xerrors.Errorf("template_id must be a valid UUID: %w", err)
961+
var templateID uuid.UUID
962+
if args.TemplateID != "" {
963+
tid, err := uuid.Parse(args.TemplateID)
964+
if err != nil {
965+
return codersdk.TemplateVersion{}, xerrors.Errorf("template_id must be a valid UUID: %w", err)
966+
}
967+
templateID = tid
940968
}
941969
templateVersion, err := tb.CoderClient.CreateTemplateVersion(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{
942970
Message: "Created by AI",
@@ -1183,7 +1211,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t
11831211
},
11841212
}
11851213

1186-
DeleteTemplate = Tool[DeleteTemplateArgs, string]{
1214+
DeleteTemplate = Tool[DeleteTemplateArgs, codersdk.Response]{
11871215
Tool: aisdk.Tool{
11881216
Name: "coder_delete_template",
11891217
Description: "Delete a template. This is irreversible.",
@@ -1195,16 +1223,18 @@ The file_id provided is a reference to a tar file you have uploaded containing t
11951223
},
11961224
},
11971225
},
1198-
Handler: func(ctx context.Context, tb Deps, args DeleteTemplateArgs) (string, error) {
1226+
Handler: func(ctx context.Context, tb Deps, args DeleteTemplateArgs) (codersdk.Response, error) {
11991227
templateID, err := uuid.Parse(args.TemplateID)
12001228
if err != nil {
1201-
return "", xerrors.Errorf("template_id must be a valid UUID: %w", err)
1229+
return codersdk.Response{}, xerrors.Errorf("template_id must be a valid UUID: %w", err)
12021230
}
12031231
err = tb.CoderClient.DeleteTemplate(ctx, templateID)
12041232
if err != nil {
1205-
return "", err
1233+
return codersdk.Response{}, err
12061234
}
1207-
return "Successfully deleted template!", nil
1235+
return codersdk.Response{
1236+
Message: "Template deleted successfully.",
1237+
}, nil
12081238
},
12091239
}
12101240
)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy