diff --git a/codersdk/toolsdk/bash.go b/codersdk/toolsdk/bash.go index e45ca6a49e29a..5fb15843f1bf1 100644 --- a/codersdk/toolsdk/bash.go +++ b/codersdk/toolsdk/bash.go @@ -1,11 +1,14 @@ package toolsdk import ( + "bytes" "context" "errors" "fmt" "io" "strings" + "sync" + "time" gossh "golang.org/x/crypto/ssh" "golang.org/x/xerrors" @@ -20,6 +23,7 @@ import ( type WorkspaceBashArgs struct { Workspace string `json:"workspace"` Command string `json:"command"` + TimeoutMs int `json:"timeout_ms,omitempty"` } type WorkspaceBashResult struct { @@ -43,9 +47,12 @@ The workspace parameter supports various formats: - workspace.agent (specific agent) - owner/workspace.agent +The timeout_ms parameter specifies the command timeout in milliseconds (defaults to 60000ms, maximum of 300000ms). +If the command times out, all output captured up to that point is returned with a cancellation message. + Examples: - workspace: "my-workspace", command: "ls -la" -- workspace: "john/dev-env", command: "git status" +- workspace: "john/dev-env", command: "git status", timeout_ms: 30000 - workspace: "my-workspace.main", command: "docker ps"`, Schema: aisdk.Schema{ Properties: map[string]any{ @@ -57,11 +64,17 @@ Examples: "type": "string", "description": "The bash command to execute in the workspace.", }, + "timeout_ms": map[string]any{ + "type": "integer", + "description": "Command timeout in milliseconds. Defaults to 60000ms (60 seconds) if not specified.", + "default": 60000, + "minimum": 1, + }, }, Required: []string{"workspace", "command"}, }, }, - Handler: func(ctx context.Context, deps Deps, args WorkspaceBashArgs) (WorkspaceBashResult, error) { + Handler: func(ctx context.Context, deps Deps, args WorkspaceBashArgs) (res WorkspaceBashResult, err error) { if args.Workspace == "" { return WorkspaceBashResult{}, xerrors.New("workspace name cannot be empty") } @@ -69,6 +82,9 @@ Examples: return WorkspaceBashResult{}, xerrors.New("command cannot be empty") } + ctx, cancel := context.WithTimeoutCause(ctx, 5*time.Minute, xerrors.New("MCP handler timeout after 5 min")) + defer cancel() + // Normalize workspace input to handle various formats workspaceName := NormalizeWorkspaceInput(args.Workspace) @@ -119,23 +135,42 @@ Examples: } defer session.Close() - // Execute command and capture output - output, err := session.CombinedOutput(args.Command) + // Set default timeout if not specified (60 seconds) + timeoutMs := args.TimeoutMs + if timeoutMs <= 0 { + timeoutMs = 60000 + } + + // Create context with timeout + ctx, cancel = context.WithTimeout(ctx, time.Duration(timeoutMs)*time.Millisecond) + defer cancel() + + // Execute command with timeout handling + output, err := executeCommandWithTimeout(ctx, session, args.Command) outputStr := strings.TrimSpace(string(output)) + // Handle command execution results if err != nil { - // Check if it's an SSH exit error to get the exit code - var exitErr *gossh.ExitError - if errors.As(err, &exitErr) { + // Check if the command timed out + if errors.Is(context.Cause(ctx), context.DeadlineExceeded) { + outputStr += "\nCommand canceled due to timeout" return WorkspaceBashResult{ Output: outputStr, - ExitCode: exitErr.ExitStatus(), + ExitCode: 124, }, nil } - // For other errors, return exit code 1 + + // Extract exit code from SSH error if available + exitCode := 1 + var exitErr *gossh.ExitError + if errors.As(err, &exitErr) { + exitCode = exitErr.ExitStatus() + } + + // For other errors, use standard timeout or generic error code return WorkspaceBashResult{ Output: outputStr, - ExitCode: 1, + ExitCode: exitCode, }, nil } @@ -292,3 +327,99 @@ func NormalizeWorkspaceInput(input string) string { return normalized } + +// executeCommandWithTimeout executes a command with timeout support +func executeCommandWithTimeout(ctx context.Context, session *gossh.Session, command string) ([]byte, error) { + // Set up pipes to capture output + stdoutPipe, err := session.StdoutPipe() + if err != nil { + return nil, xerrors.Errorf("failed to create stdout pipe: %w", err) + } + + stderrPipe, err := session.StderrPipe() + if err != nil { + return nil, xerrors.Errorf("failed to create stderr pipe: %w", err) + } + + // Start the command + if err := session.Start(command); err != nil { + return nil, xerrors.Errorf("failed to start command: %w", err) + } + + // Create a thread-safe buffer for combined output + var output bytes.Buffer + var mu sync.Mutex + safeWriter := &syncWriter{w: &output, mu: &mu} + + // Use io.MultiWriter to combine stdout and stderr + multiWriter := io.MultiWriter(safeWriter) + + // Channel to signal when command completes + done := make(chan error, 1) + + // Start goroutine to copy output and wait for completion + go func() { + // Copy stdout and stderr concurrently + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + _, _ = io.Copy(multiWriter, stdoutPipe) + }() + + go func() { + defer wg.Done() + _, _ = io.Copy(multiWriter, stderrPipe) + }() + + // Wait for all output to be copied + wg.Wait() + + // Wait for the command to complete + done <- session.Wait() + }() + + // Wait for either completion or context cancellation + select { + case err := <-done: + // Command completed normally + return safeWriter.Bytes(), err + case <-ctx.Done(): + // Context was canceled (timeout or other cancellation) + // Close the session to stop the command + _ = session.Close() + + // Give a brief moment to collect any remaining output + timer := time.NewTimer(50 * time.Millisecond) + defer timer.Stop() + + select { + case <-timer.C: + // Timer expired, return what we have + case err := <-done: + // Command finished during grace period + return safeWriter.Bytes(), err + } + + return safeWriter.Bytes(), context.Cause(ctx) + } +} + +// syncWriter is a thread-safe writer +type syncWriter struct { + w *bytes.Buffer + mu *sync.Mutex +} + +func (sw *syncWriter) Write(p []byte) (n int, err error) { + sw.mu.Lock() + defer sw.mu.Unlock() + return sw.w.Write(p) +} + +func (sw *syncWriter) Bytes() []byte { + sw.mu.Lock() + defer sw.mu.Unlock() + return sw.w.Bytes() +} diff --git a/codersdk/toolsdk/bash_test.go b/codersdk/toolsdk/bash_test.go index 474071fc45acb..53ac480039278 100644 --- a/codersdk/toolsdk/bash_test.go +++ b/codersdk/toolsdk/bash_test.go @@ -6,6 +6,8 @@ import ( "github.com/stretchr/testify/require" + "github.com/coder/coder/v2/agent/agenttest" + "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/codersdk/toolsdk" ) @@ -40,7 +42,7 @@ func TestWorkspaceBash(t *testing.T) { t.Run("ErrorScenarios", func(t *testing.T) { t.Parallel() - deps := toolsdk.Deps{} // Empty deps will cause client access to fail + deps := toolsdk.Deps{} ctx := context.Background() // Test input validation errors (these should fail before client access) @@ -159,3 +161,180 @@ func TestAllToolsIncludesBash(t *testing.T) { } require.True(t, found, "WorkspaceBash tool should be included in toolsdk.All") } + +// Note: Unit testing ExecuteCommandWithTimeout is challenging because it expects +// a concrete SSH session type. The integration tests above demonstrate the +// timeout functionality with a real SSH connection and mock clock. + +func TestWorkspaceBashTimeout(t *testing.T) { + t.Parallel() + + t.Run("TimeoutDefaultValue", func(t *testing.T) { + t.Parallel() + + // Test that the TimeoutMs field can be set and read correctly + args := toolsdk.WorkspaceBashArgs{ + Workspace: "test-workspace", + Command: "echo test", + TimeoutMs: 0, // Should default to 60000 in handler + } + + // Verify that the TimeoutMs field exists and can be set + require.Equal(t, 0, args.TimeoutMs) + + // Test setting a positive value + args.TimeoutMs = 5000 + require.Equal(t, 5000, args.TimeoutMs) + }) + + t.Run("TimeoutNegativeValue", func(t *testing.T) { + t.Parallel() + + // Test that negative values can be set and will be handled by the default logic + args := toolsdk.WorkspaceBashArgs{ + Workspace: "test-workspace", + Command: "echo test", + TimeoutMs: -100, + } + + require.Equal(t, -100, args.TimeoutMs) + + // The actual defaulting to 60000 happens inside the handler + // We can't test it without a full integration test setup + }) + + t.Run("TimeoutSchemaValidation", func(t *testing.T) { + t.Parallel() + + tool := toolsdk.WorkspaceBash + + // Check that timeout_ms is in the schema + require.Contains(t, tool.Schema.Properties, "timeout_ms") + + timeoutProperty := tool.Schema.Properties["timeout_ms"].(map[string]any) + require.Equal(t, "integer", timeoutProperty["type"]) + require.Equal(t, 60000, timeoutProperty["default"]) + require.Equal(t, 1, timeoutProperty["minimum"]) + require.Contains(t, timeoutProperty["description"], "timeout in milliseconds") + }) + + t.Run("TimeoutDescriptionUpdated", func(t *testing.T) { + t.Parallel() + + tool := toolsdk.WorkspaceBash + + // Check that description mentions timeout functionality + require.Contains(t, tool.Description, "timeout_ms parameter") + require.Contains(t, tool.Description, "defaults to 60000ms") + require.Contains(t, tool.Description, "timeout_ms: 30000") + }) + + t.Run("TimeoutCommandScenario", func(t *testing.T) { + t.Parallel() + + // Scenario: echo "123"; sleep 60; echo "456" with 5ms timeout + // In this scenario, we'd expect to see "123" in the output and a cancellation message + args := toolsdk.WorkspaceBashArgs{ + Workspace: "test-workspace", + Command: `echo "123"; sleep 60; echo "456"`, // This command would take 60+ seconds + TimeoutMs: 5, // 5ms timeout - should timeout after first echo + } + + // Verify the args are structured correctly for the intended test scenario + require.Equal(t, "test-workspace", args.Workspace) + require.Contains(t, args.Command, `echo "123"`) + require.Contains(t, args.Command, "sleep 60") + require.Contains(t, args.Command, `echo "456"`) + require.Equal(t, 5, args.TimeoutMs) + + // Note: The actual timeout behavior would need to be tested with a real workspace + // This test just verifies the structure is correct for the timeout scenario + }) +} + +func TestWorkspaceBashTimeoutIntegration(t *testing.T) { + t.Parallel() + + t.Run("ActualTimeoutBehavior", func(t *testing.T) { + t.Parallel() + + // Scenario: echo "123"; sleep 60; echo "456" with 5s timeout + // In this scenario, we'd expect to see "123" in the output and a cancellation message + + client, workspace, agentToken := setupWorkspaceForAgent(t) + + // Start the agent and wait for it to be fully ready + _ = agenttest.New(t, client.URL, agentToken) + + // Wait for workspace agents to be ready like other SSH tests do + coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait() + + // Use real clock for integration test + deps, err := toolsdk.NewDeps(client) + require.NoError(t, err) + + args := toolsdk.WorkspaceBashArgs{ + Workspace: workspace.Name, + Command: `echo "123" && sleep 60 && echo "456"`, // This command would take 60+ seconds + TimeoutMs: 2000, // 2 seconds timeout - should timeout after first echo + } + + result, err := toolsdk.WorkspaceBash.Handler(t.Context(), deps, args) + + // Should not error (timeout is handled gracefully) + require.NoError(t, err) + + t.Logf("Test results: exitCode=%d, output=%q, error=%v", result.ExitCode, result.Output, err) + + // Should have a non-zero exit code (timeout or error) + require.NotEqual(t, 0, result.ExitCode, "Expected non-zero exit code for timeout") + + t.Logf("result.Output: %s", result.Output) + + // Should contain the first echo output + require.Contains(t, result.Output, "123") + + // Should NOT contain the second echo (it never executed due to timeout) + require.NotContains(t, result.Output, "456", "Should not contain output after sleep") + }) + + t.Run("NormalCommandExecution", func(t *testing.T) { + t.Parallel() + + // Test that normal commands still work with timeout functionality present + + client, workspace, agentToken := setupWorkspaceForAgent(t) + + // Start the agent and wait for it to be fully ready + _ = agenttest.New(t, client.URL, agentToken) + + // Wait for workspace agents to be ready + coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait() + + deps, err := toolsdk.NewDeps(client) + require.NoError(t, err) + ctx := context.Background() + + args := toolsdk.WorkspaceBashArgs{ + Workspace: workspace.Name, + Command: `echo "normal command"`, // Quick command that should complete normally + TimeoutMs: 5000, // 5 second timeout - plenty of time + } + + result, err := toolsdk.WorkspaceBash.Handler(ctx, deps, args) + + // Should not error + require.NoError(t, err) + + t.Logf("result.Output: %s", result.Output) + + // Should have exit code 0 (success) + require.Equal(t, 0, result.ExitCode) + + // Should contain the expected output + require.Equal(t, "normal command", result.Output) + + // Should NOT contain timeout message + require.NotContains(t, result.Output, "Command canceled due to timeout") + }) +} pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy