Skip to content

Commit 023275c

Browse files
committed
feat: add timeout support to workspace bash tool
Change-Id: I996cbde4a50debb54a0a95ca5a067781719fa25a Signed-off-by: Thomas Kosiewski <tk@coder.com>
1 parent 070178c commit 023275c

File tree

2 files changed

+320
-11
lines changed

2 files changed

+320
-11
lines changed

codersdk/toolsdk/bash.go

Lines changed: 140 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package toolsdk
22

33
import (
4+
"bytes"
45
"context"
56
"errors"
67
"fmt"
78
"io"
89
"strings"
10+
"sync"
11+
"time"
912

1013
gossh "golang.org/x/crypto/ssh"
1114
"golang.org/x/xerrors"
@@ -20,6 +23,7 @@ import (
2023
type WorkspaceBashArgs struct {
2124
Workspace string `json:"workspace"`
2225
Command string `json:"command"`
26+
TimeoutMs int `json:"timeout_ms,omitempty"`
2327
}
2428

2529
type WorkspaceBashResult struct {
@@ -43,9 +47,12 @@ The workspace parameter supports various formats:
4347
- workspace.agent (specific agent)
4448
- owner/workspace.agent
4549
50+
The timeout_ms parameter specifies the command timeout in milliseconds (defaults to 60000ms, maximum of 300000ms).
51+
If the command times out, all output captured up to that point is returned with a cancellation message.
52+
4653
Examples:
4754
- workspace: "my-workspace", command: "ls -la"
48-
- workspace: "john/dev-env", command: "git status"
55+
- workspace: "john/dev-env", command: "git status", timeout_ms: 30000
4956
- workspace: "my-workspace.main", command: "docker ps"`,
5057
Schema: aisdk.Schema{
5158
Properties: map[string]any{
@@ -57,18 +64,27 @@ Examples:
5764
"type": "string",
5865
"description": "The bash command to execute in the workspace.",
5966
},
67+
"timeout_ms": map[string]any{
68+
"type": "integer",
69+
"description": "Command timeout in milliseconds. Defaults to 60000ms (60 seconds) if not specified.",
70+
"default": 60000,
71+
"minimum": 1,
72+
},
6073
},
6174
Required: []string{"workspace", "command"},
6275
},
6376
},
64-
Handler: func(ctx context.Context, deps Deps, args WorkspaceBashArgs) (WorkspaceBashResult, error) {
77+
Handler: func(ctx context.Context, deps Deps, args WorkspaceBashArgs) (res WorkspaceBashResult, err error) {
6578
if args.Workspace == "" {
6679
return WorkspaceBashResult{}, xerrors.New("workspace name cannot be empty")
6780
}
6881
if args.Command == "" {
6982
return WorkspaceBashResult{}, xerrors.New("command cannot be empty")
7083
}
7184

85+
ctx, cancel := context.WithTimeoutCause(ctx, 5*time.Minute, errors.New("MCP handler timeout after 5 min"))
86+
defer cancel()
87+
7288
// Normalize workspace input to handle various formats
7389
workspaceName := NormalizeWorkspaceInput(args.Workspace)
7490

@@ -119,23 +135,41 @@ Examples:
119135
}
120136
defer session.Close()
121137

122-
// Execute command and capture output
123-
output, err := session.CombinedOutput(args.Command)
138+
// Set default timeout if not specified (60 seconds)
139+
timeoutMs := args.TimeoutMs
140+
if timeoutMs <= 0 {
141+
timeoutMs = 60000
142+
}
143+
144+
// Create context with timeout
145+
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeoutMs)*time.Millisecond)
146+
defer cancel()
147+
148+
// Execute command with timeout handling
149+
output, err := executeCommandWithTimeout(ctx, session, args.Command)
124150
outputStr := strings.TrimSpace(string(output))
125151

152+
// Handle command execution results
126153
if err != nil {
127-
// Check if it's an SSH exit error to get the exit code
128-
var exitErr *gossh.ExitError
129-
if errors.As(err, &exitErr) {
154+
// Check if the command timed out
155+
if context.Cause(ctx) == context.DeadlineExceeded {
156+
outputStr += "\nCommand cancelled due to timeout"
130157
return WorkspaceBashResult{
131158
Output: outputStr,
132-
ExitCode: exitErr.ExitStatus(),
159+
ExitCode: 124,
133160
}, nil
134161
}
135-
// For other errors, return exit code 1
162+
163+
// Extract exit code from SSH error if available
164+
exitCode := 1
165+
if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) {
166+
exitCode = exitErr.ExitStatus()
167+
}
168+
169+
// For other errors, use standard timeout or generic error code
136170
return WorkspaceBashResult{
137171
Output: outputStr,
138-
ExitCode: 1,
172+
ExitCode: exitCode,
139173
}, nil
140174
}
141175

@@ -292,3 +326,99 @@ func NormalizeWorkspaceInput(input string) string {
292326

293327
return normalized
294328
}
329+
330+
// executeCommandWithTimeout executes a command with timeout support
331+
func executeCommandWithTimeout(ctx context.Context, session *gossh.Session, command string) ([]byte, error) {
332+
// Set up pipes to capture output
333+
stdoutPipe, err := session.StdoutPipe()
334+
if err != nil {
335+
return nil, xerrors.Errorf("failed to create stdout pipe: %w", err)
336+
}
337+
338+
stderrPipe, err := session.StderrPipe()
339+
if err != nil {
340+
return nil, xerrors.Errorf("failed to create stderr pipe: %w", err)
341+
}
342+
343+
// Start the command
344+
if err := session.Start(command); err != nil {
345+
return nil, xerrors.Errorf("failed to start command: %w", err)
346+
}
347+
348+
// Create a thread-safe buffer for combined output
349+
var output bytes.Buffer
350+
var mu sync.Mutex
351+
safeWriter := &syncWriter{w: &output, mu: &mu}
352+
353+
// Use io.MultiWriter to combine stdout and stderr
354+
multiWriter := io.MultiWriter(safeWriter)
355+
356+
// Channel to signal when command completes
357+
done := make(chan error, 1)
358+
359+
// Start goroutine to copy output and wait for completion
360+
go func() {
361+
// Copy stdout and stderr concurrently
362+
var wg sync.WaitGroup
363+
wg.Add(2)
364+
365+
go func() {
366+
defer wg.Done()
367+
_, _ = io.Copy(multiWriter, stdoutPipe)
368+
}()
369+
370+
go func() {
371+
defer wg.Done()
372+
_, _ = io.Copy(multiWriter, stderrPipe)
373+
}()
374+
375+
// Wait for all output to be copied
376+
wg.Wait()
377+
378+
// Wait for the command to complete
379+
done <- session.Wait()
380+
}()
381+
382+
// Wait for either completion or context cancellation
383+
select {
384+
case err := <-done:
385+
// Command completed normally
386+
return safeWriter.Bytes(), err
387+
case <-ctx.Done():
388+
// Context was cancelled (timeout or other cancellation)
389+
// Close the session to stop the command
390+
_ = session.Close()
391+
392+
// Give a brief moment to collect any remaining output
393+
timer := time.NewTimer(50 * time.Millisecond)
394+
defer timer.Stop()
395+
396+
select {
397+
case <-timer.C:
398+
// Timer expired, return what we have
399+
case err := <-done:
400+
// Command finished during grace period
401+
return safeWriter.Bytes(), err
402+
}
403+
404+
return safeWriter.Bytes(), context.Cause(ctx)
405+
}
406+
}
407+
408+
// syncWriter is a thread-safe writer
409+
type syncWriter struct {
410+
w *bytes.Buffer
411+
mu *sync.Mutex
412+
}
413+
414+
func (sw *syncWriter) Write(p []byte) (n int, err error) {
415+
sw.mu.Lock()
416+
defer sw.mu.Unlock()
417+
return sw.w.Write(p)
418+
}
419+
420+
func (sw *syncWriter) Bytes() []byte {
421+
sw.mu.Lock()
422+
defer sw.mu.Unlock()
423+
return sw.w.Bytes()
424+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy