Skip to content

Commit c1057d9

Browse files
committed
add WithCleanContext middleware func
1 parent 5647b8b commit c1057d9

File tree

3 files changed

+201
-82
lines changed

3 files changed

+201
-82
lines changed

cli/exp_mcp.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -713,8 +713,8 @@ func mcpFromSDK(sdkTool toolsdk.Tool[any, any], tb toolsdk.Deps) server.ServerTo
713713
Required: sdkTool.Schema.Required,
714714
},
715715
},
716-
Handler: func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
717-
result, err := sdkTool.Handler(tb, request.Params.Arguments)
716+
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
717+
result, err := sdkTool.Handler(ctx, tb, request.Params.Arguments)
718718
if err != nil {
719719
return nil, err
720720
}

codersdk/toolsdk/toolsdk.go

Lines changed: 69 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ type Deps struct {
2121
}
2222

2323
// HandlerFunc is a function that handles a tool call.
24-
type HandlerFunc[Arg, Ret any] func(tb Deps, args Arg) (Ret, error)
24+
type HandlerFunc[Arg, Ret any] func(context.Context, Deps, Arg) (Ret, error)
2525

2626
type Tool[Arg, Ret any] struct {
2727
aisdk.Tool
@@ -32,12 +32,12 @@ type Tool[Arg, Ret any] struct {
3232
func (t Tool[Arg, Ret]) Generic() Tool[any, any] {
3333
return Tool[any, any]{
3434
Tool: t.Tool,
35-
Handler: func(tb Deps, args any) (any, error) {
35+
Handler: func(ctx context.Context, tb Deps, args any) (any, error) {
3636
typedArg, ok := args.(Arg)
3737
if !ok {
3838
return nil, xerrors.Errorf("developer error: invalid argument type for tool %s", t.Tool.Name)
3939
}
40-
return t.Handler(tb, typedArg)
40+
return t.Handler(ctx, tb, typedArg)
4141
},
4242
}
4343
}
@@ -115,13 +115,41 @@ type UploadTarFileArgs struct {
115115

116116
// WithRecover wraps a HandlerFunc to recover from panics and return an error.
117117
func WithRecover[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] {
118-
return func(tb Deps, args Arg) (ret Ret, err error) {
118+
return func(ctx context.Context, tb Deps, args Arg) (ret Ret, err error) {
119119
defer func() {
120120
if r := recover(); r != nil {
121121
err = xerrors.Errorf("tool handler panic: %v", r)
122122
}
123123
}()
124-
return h(tb, args)
124+
return h(ctx, tb, args)
125+
}
126+
}
127+
128+
// WithCleanContext wraps a HandlerFunc to provide it with a new context.
129+
// This ensures that no data is passed using context.Value.
130+
// If a deadline is set on the parent context, it will be passed to the child
131+
// context.
132+
func WithCleanContext[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] {
133+
return func(parent context.Context, tb Deps, args Arg) (ret Ret, err error) {
134+
child, childCancel := context.WithCancel(context.Background())
135+
defer childCancel()
136+
// Ensure that cancellation propagates from the parent context to the child context.
137+
go func() {
138+
select {
139+
case <-child.Done():
140+
return
141+
case <-parent.Done():
142+
childCancel()
143+
}
144+
}()
145+
// Also ensure that the child context has the same deadline as the parent
146+
// context.
147+
if deadline, ok := parent.Deadline(); ok {
148+
deadlineCtx, deadlineCancel := context.WithDeadline(child, deadline)
149+
defer deadlineCancel()
150+
child = deadlineCtx
151+
}
152+
return h(child, tb, args)
125153
}
126154
}
127155

@@ -137,7 +165,7 @@ func wrapAll(mw func(HandlerFunc[any, any]) HandlerFunc[any, any], tools ...Tool
137165
var (
138166
// All is a list of all tools that can be used in the Coder CLI.
139167
// When you add a new tool, be sure to include it here!
140-
All = wrapAll(WithRecover,
168+
All = wrapAll(WithCleanContext, wrapAll(WithRecover,
141169
CreateTemplate.Generic(),
142170
CreateTemplateVersion.Generic(),
143171
CreateWorkspace.Generic(),
@@ -154,7 +182,7 @@ var (
154182
ReportTask.Generic(),
155183
UploadTarFile.Generic(),
156184
UpdateTemplateActiveVersion.Generic(),
157-
)
185+
)...)
158186

159187
ReportTask = Tool[ReportTaskArgs, string]{
160188
Tool: aisdk.Tool{
@@ -183,14 +211,14 @@ var (
183211
Required: []string{"summary", "link", "state"},
184212
},
185213
},
186-
Handler: func(tb Deps, args ReportTaskArgs) (string, error) {
214+
Handler: func(ctx context.Context, tb Deps, args ReportTaskArgs) (string, error) {
187215
if tb.AgentClient == nil {
188216
return "", xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set")
189217
}
190218
if tb.AppStatusSlug == "" {
191219
return "", xerrors.New("workspace app status slug not found in toolbox")
192220
}
193-
if err := tb.AgentClient.PatchAppStatus(context.TODO(), agentsdk.PatchAppStatus{
221+
if err := tb.AgentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
194222
AppSlug: tb.AppStatusSlug,
195223
Message: args.Summary,
196224
URI: args.Link,
@@ -217,12 +245,12 @@ This returns more data than list_workspaces to reduce token usage.`,
217245
Required: []string{"workspace_id"},
218246
},
219247
},
220-
Handler: func(tb Deps, args GetWorkspaceArgs) (codersdk.Workspace, error) {
248+
Handler: func(ctx context.Context, tb Deps, args GetWorkspaceArgs) (codersdk.Workspace, error) {
221249
wsID, err := uuid.Parse(args.WorkspaceID)
222250
if err != nil {
223251
return codersdk.Workspace{}, xerrors.New("workspace_id must be a valid UUID")
224252
}
225-
return tb.CoderClient.Workspace(context.TODO(), wsID)
253+
return tb.CoderClient.Workspace(ctx, wsID)
226254
},
227255
}
228256

@@ -257,7 +285,7 @@ is provisioned correctly and the agent can connect to the control plane.
257285
Required: []string{"user", "template_version_id", "name", "rich_parameters"},
258286
},
259287
},
260-
Handler: func(tb Deps, args CreateWorkspaceArgs) (codersdk.Workspace, error) {
288+
Handler: func(ctx context.Context, tb Deps, args CreateWorkspaceArgs) (codersdk.Workspace, error) {
261289
tvID, err := uuid.Parse(args.TemplateVersionID)
262290
if err != nil {
263291
return codersdk.Workspace{}, xerrors.New("template_version_id must be a valid UUID")
@@ -272,7 +300,7 @@ is provisioned correctly and the agent can connect to the control plane.
272300
Value: v,
273301
})
274302
}
275-
workspace, err := tb.CoderClient.CreateUserWorkspace(context.TODO(), args.User, codersdk.CreateWorkspaceRequest{
303+
workspace, err := tb.CoderClient.CreateUserWorkspace(ctx, args.User, codersdk.CreateWorkspaceRequest{
276304
TemplateVersionID: tvID,
277305
Name: args.Name,
278306
RichParameterValues: buildParams,
@@ -297,12 +325,12 @@ is provisioned correctly and the agent can connect to the control plane.
297325
},
298326
},
299327
},
300-
Handler: func(tb Deps, args ListWorkspacesArgs) ([]MinimalWorkspace, error) {
328+
Handler: func(ctx context.Context, tb Deps, args ListWorkspacesArgs) ([]MinimalWorkspace, error) {
301329
owner := args.Owner
302330
if owner == "" {
303331
owner = codersdk.Me
304332
}
305-
workspaces, err := tb.CoderClient.Workspaces(context.TODO(), codersdk.WorkspaceFilter{
333+
workspaces, err := tb.CoderClient.Workspaces(ctx, codersdk.WorkspaceFilter{
306334
Owner: owner,
307335
})
308336
if err != nil {
@@ -334,8 +362,8 @@ is provisioned correctly and the agent can connect to the control plane.
334362
Required: []string{},
335363
},
336364
},
337-
Handler: func(tb Deps, _ NoArgs) ([]MinimalTemplate, error) {
338-
templates, err := tb.CoderClient.Templates(context.TODO(), codersdk.TemplateFilter{})
365+
Handler: func(ctx context.Context, tb Deps, _ NoArgs) ([]MinimalTemplate, error) {
366+
templates, err := tb.CoderClient.Templates(ctx, codersdk.TemplateFilter{})
339367
if err != nil {
340368
return nil, err
341369
}
@@ -367,12 +395,12 @@ is provisioned correctly and the agent can connect to the control plane.
367395
Required: []string{"template_version_id"},
368396
},
369397
},
370-
Handler: func(tb Deps, args ListTemplateVersionParametersArgs) ([]codersdk.TemplateVersionParameter, error) {
398+
Handler: func(ctx context.Context, tb Deps, args ListTemplateVersionParametersArgs) ([]codersdk.TemplateVersionParameter, error) {
371399
templateVersionID, err := uuid.Parse(args.TemplateVersionID)
372400
if err != nil {
373401
return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err)
374402
}
375-
parameters, err := tb.CoderClient.TemplateVersionRichParameters(context.TODO(), templateVersionID)
403+
parameters, err := tb.CoderClient.TemplateVersionRichParameters(ctx, templateVersionID)
376404
if err != nil {
377405
return nil, err
378406
}
@@ -389,8 +417,8 @@ is provisioned correctly and the agent can connect to the control plane.
389417
Required: []string{},
390418
},
391419
},
392-
Handler: func(tb Deps, _ NoArgs) (codersdk.User, error) {
393-
return tb.CoderClient.User(context.TODO(), "me")
420+
Handler: func(ctx context.Context, tb Deps, _ NoArgs) (codersdk.User, error) {
421+
return tb.CoderClient.User(ctx, "me")
394422
},
395423
}
396424

@@ -416,7 +444,7 @@ is provisioned correctly and the agent can connect to the control plane.
416444
Required: []string{"workspace_id", "transition"},
417445
},
418446
},
419-
Handler: func(tb Deps, args CreateWorkspaceBuildArgs) (codersdk.WorkspaceBuild, error) {
447+
Handler: func(ctx context.Context, tb Deps, args CreateWorkspaceBuildArgs) (codersdk.WorkspaceBuild, error) {
420448
workspaceID, err := uuid.Parse(args.WorkspaceID)
421449
if err != nil {
422450
return codersdk.WorkspaceBuild{}, xerrors.Errorf("workspace_id must be a valid UUID: %w", err)
@@ -435,7 +463,7 @@ is provisioned correctly and the agent can connect to the control plane.
435463
if templateVersionID != uuid.Nil {
436464
cbr.TemplateVersionID = templateVersionID
437465
}
438-
return tb.CoderClient.CreateWorkspaceBuild(context.TODO(), workspaceID, cbr)
466+
return tb.CoderClient.CreateWorkspaceBuild(ctx, workspaceID, cbr)
439467
},
440468
}
441469

@@ -897,8 +925,8 @@ The file_id provided is a reference to a tar file you have uploaded containing t
897925
Required: []string{"file_id"},
898926
},
899927
},
900-
Handler: func(tb Deps, args CreateTemplateVersionArgs) (codersdk.TemplateVersion, error) {
901-
me, err := tb.CoderClient.User(context.TODO(), "me")
928+
Handler: func(ctx context.Context, tb Deps, args CreateTemplateVersionArgs) (codersdk.TemplateVersion, error) {
929+
me, err := tb.CoderClient.User(ctx, "me")
902930
if err != nil {
903931
return codersdk.TemplateVersion{}, err
904932
}
@@ -910,7 +938,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t
910938
if err != nil {
911939
return codersdk.TemplateVersion{}, xerrors.Errorf("template_id must be a valid UUID: %w", err)
912940
}
913-
templateVersion, err := tb.CoderClient.CreateTemplateVersion(context.TODO(), me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{
941+
templateVersion, err := tb.CoderClient.CreateTemplateVersion(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{
914942
Message: "Created by AI",
915943
StorageMethod: codersdk.ProvisionerStorageMethodFile,
916944
FileID: fileID,
@@ -939,12 +967,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t
939967
Required: []string{"workspace_agent_id"},
940968
},
941969
},
942-
Handler: func(tb Deps, args GetWorkspaceAgentLogsArgs) ([]string, error) {
970+
Handler: func(ctx context.Context, tb Deps, args GetWorkspaceAgentLogsArgs) ([]string, error) {
943971
workspaceAgentID, err := uuid.Parse(args.WorkspaceAgentID)
944972
if err != nil {
945973
return nil, xerrors.Errorf("workspace_agent_id must be a valid UUID: %w", err)
946974
}
947-
logs, closer, err := tb.CoderClient.WorkspaceAgentLogsAfter(context.TODO(), workspaceAgentID, 0, false)
975+
logs, closer, err := tb.CoderClient.WorkspaceAgentLogsAfter(ctx, workspaceAgentID, 0, false)
948976
if err != nil {
949977
return nil, err
950978
}
@@ -974,12 +1002,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t
9741002
Required: []string{"workspace_build_id"},
9751003
},
9761004
},
977-
Handler: func(tb Deps, args GetWorkspaceBuildLogsArgs) ([]string, error) {
1005+
Handler: func(ctx context.Context, tb Deps, args GetWorkspaceBuildLogsArgs) ([]string, error) {
9781006
workspaceBuildID, err := uuid.Parse(args.WorkspaceBuildID)
9791007
if err != nil {
9801008
return nil, xerrors.Errorf("workspace_build_id must be a valid UUID: %w", err)
9811009
}
982-
logs, closer, err := tb.CoderClient.WorkspaceBuildLogsAfter(context.TODO(), workspaceBuildID, 0)
1010+
logs, closer, err := tb.CoderClient.WorkspaceBuildLogsAfter(ctx, workspaceBuildID, 0)
9831011
if err != nil {
9841012
return nil, err
9851013
}
@@ -1005,13 +1033,13 @@ The file_id provided is a reference to a tar file you have uploaded containing t
10051033
Required: []string{"template_version_id"},
10061034
},
10071035
},
1008-
Handler: func(tb Deps, args GetTemplateVersionLogsArgs) ([]string, error) {
1036+
Handler: func(ctx context.Context, tb Deps, args GetTemplateVersionLogsArgs) ([]string, error) {
10091037
templateVersionID, err := uuid.Parse(args.TemplateVersionID)
10101038
if err != nil {
10111039
return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err)
10121040
}
10131041

1014-
logs, closer, err := tb.CoderClient.TemplateVersionLogsAfter(context.TODO(), templateVersionID, 0)
1042+
logs, closer, err := tb.CoderClient.TemplateVersionLogsAfter(ctx, templateVersionID, 0)
10151043
if err != nil {
10161044
return nil, err
10171045
}
@@ -1040,7 +1068,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t
10401068
Required: []string{"template_id", "template_version_id"},
10411069
},
10421070
},
1043-
Handler: func(tb Deps, args UpdateTemplateActiveVersionArgs) (string, error) {
1071+
Handler: func(ctx context.Context, tb Deps, args UpdateTemplateActiveVersionArgs) (string, error) {
10441072
templateID, err := uuid.Parse(args.TemplateID)
10451073
if err != nil {
10461074
return "", xerrors.Errorf("template_id must be a valid UUID: %w", err)
@@ -1049,7 +1077,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t
10491077
if err != nil {
10501078
return "", xerrors.Errorf("template_version_id must be a valid UUID: %w", err)
10511079
}
1052-
err = tb.CoderClient.UpdateActiveTemplateVersion(context.TODO(), templateID, codersdk.UpdateActiveTemplateVersion{
1080+
err = tb.CoderClient.UpdateActiveTemplateVersion(ctx, templateID, codersdk.UpdateActiveTemplateVersion{
10531081
ID: templateVersionID,
10541082
})
10551083
if err != nil {
@@ -1073,7 +1101,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t
10731101
Required: []string{"mime_type", "files"},
10741102
},
10751103
},
1076-
Handler: func(tb Deps, args UploadTarFileArgs) (codersdk.UploadResponse, error) {
1104+
Handler: func(ctx context.Context, tb Deps, args UploadTarFileArgs) (codersdk.UploadResponse, error) {
10771105
pipeReader, pipeWriter := io.Pipe()
10781106
go func() {
10791107
defer pipeWriter.Close()
@@ -1098,7 +1126,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t
10981126
}
10991127
}()
11001128

1101-
resp, err := tb.CoderClient.Upload(context.TODO(), codersdk.ContentTypeTar, pipeReader)
1129+
resp, err := tb.CoderClient.Upload(ctx, codersdk.ContentTypeTar, pipeReader)
11021130
if err != nil {
11031131
return codersdk.UploadResponse{}, err
11041132
}
@@ -1133,16 +1161,16 @@ The file_id provided is a reference to a tar file you have uploaded containing t
11331161
Required: []string{"name", "display_name", "description", "version_id"},
11341162
},
11351163
},
1136-
Handler: func(tb Deps, args CreateTemplateArgs) (codersdk.Template, error) {
1137-
me, err := tb.CoderClient.User(context.TODO(), "me")
1164+
Handler: func(ctx context.Context, tb Deps, args CreateTemplateArgs) (codersdk.Template, error) {
1165+
me, err := tb.CoderClient.User(ctx, "me")
11381166
if err != nil {
11391167
return codersdk.Template{}, err
11401168
}
11411169
versionID, err := uuid.Parse(args.VersionID)
11421170
if err != nil {
11431171
return codersdk.Template{}, xerrors.Errorf("version_id must be a valid UUID: %w", err)
11441172
}
1145-
template, err := tb.CoderClient.CreateTemplate(context.TODO(), me.OrganizationIDs[0], codersdk.CreateTemplateRequest{
1173+
template, err := tb.CoderClient.CreateTemplate(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateRequest{
11461174
Name: args.Name,
11471175
DisplayName: args.DisplayName,
11481176
Description: args.Description,
@@ -1167,12 +1195,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t
11671195
},
11681196
},
11691197
},
1170-
Handler: func(tb Deps, args DeleteTemplateArgs) (string, error) {
1198+
Handler: func(ctx context.Context, tb Deps, args DeleteTemplateArgs) (string, error) {
11711199
templateID, err := uuid.Parse(args.TemplateID)
11721200
if err != nil {
11731201
return "", xerrors.Errorf("template_id must be a valid UUID: %w", err)
11741202
}
1175-
err = tb.CoderClient.DeleteTemplate(context.TODO(), templateID)
1203+
err = tb.CoderClient.DeleteTemplate(ctx, templateID)
11761204
if err != nil {
11771205
return "", err
11781206
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy