diff --git a/codersdk/toolsdk/bash.go b/codersdk/toolsdk/bash.go index 5fb15843f1bf1..037227337bfc9 100644 --- a/codersdk/toolsdk/bash.go +++ b/codersdk/toolsdk/bash.go @@ -21,9 +21,10 @@ import ( ) type WorkspaceBashArgs struct { - Workspace string `json:"workspace"` - Command string `json:"command"` - TimeoutMs int `json:"timeout_ms,omitempty"` + Workspace string `json:"workspace"` + Command string `json:"command"` + TimeoutMs int `json:"timeout_ms,omitempty"` + Background bool `json:"background,omitempty"` } type WorkspaceBashResult struct { @@ -50,9 +51,13 @@ The workspace parameter supports various formats: The timeout_ms parameter specifies the command timeout in milliseconds (defaults to 60000ms, maximum of 300000ms). If the command times out, all output captured up to that point is returned with a cancellation message. +For background commands (background: true), output is captured until the timeout is reached, then the command +continues running in the background. The captured output is returned as the result. + Examples: - workspace: "my-workspace", command: "ls -la" - workspace: "john/dev-env", command: "git status", timeout_ms: 30000 +- workspace: "my-workspace", command: "npm run dev", background: true, timeout_ms: 10000 - workspace: "my-workspace.main", command: "docker ps"`, Schema: aisdk.Schema{ Properties: map[string]any{ @@ -70,6 +75,10 @@ Examples: "default": 60000, "minimum": 1, }, + "background": map[string]any{ + "type": "boolean", + "description": "Whether to run the command in the background. Output is captured until timeout, then the command continues running in the background.", + }, }, Required: []string{"workspace", "command"}, }, @@ -137,23 +146,35 @@ Examples: // Set default timeout if not specified (60 seconds) timeoutMs := args.TimeoutMs + defaultTimeoutMs := 60000 if timeoutMs <= 0 { - timeoutMs = 60000 + timeoutMs = defaultTimeoutMs + } + command := args.Command + if args.Background { + // For background commands, use nohup directly to ensure they survive SSH session + // termination. This captures output normally but allows the process to continue + // running even after the SSH connection closes. + command = fmt.Sprintf("nohup %s &1", args.Command) } - // Create context with timeout - ctx, cancel = context.WithTimeout(ctx, time.Duration(timeoutMs)*time.Millisecond) - defer cancel() + // Create context with command timeout (replace the broader MCP timeout) + commandCtx, commandCancel := context.WithTimeout(ctx, time.Duration(timeoutMs)*time.Millisecond) + defer commandCancel() // Execute command with timeout handling - output, err := executeCommandWithTimeout(ctx, session, args.Command) + output, err := executeCommandWithTimeout(commandCtx, session, command) outputStr := strings.TrimSpace(string(output)) // Handle command execution results if err != nil { // Check if the command timed out - if errors.Is(context.Cause(ctx), context.DeadlineExceeded) { - outputStr += "\nCommand canceled due to timeout" + if errors.Is(context.Cause(commandCtx), context.DeadlineExceeded) { + if args.Background { + outputStr += "\nCommand continues running in background" + } else { + outputStr += "\nCommand canceled due to timeout" + } return WorkspaceBashResult{ Output: outputStr, ExitCode: 124, @@ -387,21 +408,27 @@ func executeCommandWithTimeout(ctx context.Context, session *gossh.Session, comm return safeWriter.Bytes(), err case <-ctx.Done(): // Context was canceled (timeout or other cancellation) - // Close the session to stop the command - _ = session.Close() + // Close the session to stop the command, but handle errors gracefully + closeErr := session.Close() - // Give a brief moment to collect any remaining output - timer := time.NewTimer(50 * time.Millisecond) + // Give a brief moment to collect any remaining output and for goroutines to finish + timer := time.NewTimer(100 * time.Millisecond) defer timer.Stop() select { case <-timer.C: // Timer expired, return what we have + break case err := <-done: // Command finished during grace period - return safeWriter.Bytes(), err + if closeErr == nil { + return safeWriter.Bytes(), err + } + // If session close failed, prioritize the context error + break } + // Return the collected output with the context error return safeWriter.Bytes(), context.Cause(ctx) } } @@ -421,5 +448,9 @@ func (sw *syncWriter) Write(p []byte) (n int, err error) { func (sw *syncWriter) Bytes() []byte { sw.mu.Lock() defer sw.mu.Unlock() - return sw.w.Bytes() + // Return a copy to prevent race conditions with the underlying buffer + b := sw.w.Bytes() + result := make([]byte, len(b)) + copy(result, b) + return result } diff --git a/codersdk/toolsdk/bash_test.go b/codersdk/toolsdk/bash_test.go index 53ac480039278..0656b2d8786e6 100644 --- a/codersdk/toolsdk/bash_test.go +++ b/codersdk/toolsdk/bash_test.go @@ -9,6 +9,7 @@ import ( "github.com/coder/coder/v2/agent/agenttest" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/codersdk/toolsdk" + "github.com/coder/coder/v2/testutil" ) func TestWorkspaceBash(t *testing.T) { @@ -174,8 +175,6 @@ func TestWorkspaceBashTimeout(t *testing.T) { // Test that the TimeoutMs field can be set and read correctly args := toolsdk.WorkspaceBashArgs{ - Workspace: "test-workspace", - Command: "echo test", TimeoutMs: 0, // Should default to 60000 in handler } @@ -192,8 +191,6 @@ func TestWorkspaceBashTimeout(t *testing.T) { // Test that negative values can be set and will be handled by the default logic args := toolsdk.WorkspaceBashArgs{ - Workspace: "test-workspace", - Command: "echo test", TimeoutMs: -100, } @@ -279,7 +276,7 @@ func TestWorkspaceBashTimeoutIntegration(t *testing.T) { TimeoutMs: 2000, // 2 seconds timeout - should timeout after first echo } - result, err := toolsdk.WorkspaceBash.Handler(t.Context(), deps, args) + result, err := testTool(t, toolsdk.WorkspaceBash, deps, args) // Should not error (timeout is handled gracefully) require.NoError(t, err) @@ -313,7 +310,6 @@ func TestWorkspaceBashTimeoutIntegration(t *testing.T) { deps, err := toolsdk.NewDeps(client) require.NoError(t, err) - ctx := context.Background() args := toolsdk.WorkspaceBashArgs{ Workspace: workspace.Name, @@ -321,7 +317,8 @@ func TestWorkspaceBashTimeoutIntegration(t *testing.T) { TimeoutMs: 5000, // 5 second timeout - plenty of time } - result, err := toolsdk.WorkspaceBash.Handler(ctx, deps, args) + // Use testTool to register the tool as tested and satisfy coverage validation + result, err := testTool(t, toolsdk.WorkspaceBash, deps, args) // Should not error require.NoError(t, err) @@ -338,3 +335,142 @@ func TestWorkspaceBashTimeoutIntegration(t *testing.T) { require.NotContains(t, result.Output, "Command canceled due to timeout") }) } + +func TestWorkspaceBashBackgroundIntegration(t *testing.T) { + t.Parallel() + + t.Run("BackgroundCommandCapturesOutput", func(t *testing.T) { + t.Parallel() + + client, workspace, agentToken := setupWorkspaceForAgent(t) + + // Start the agent and wait for it to be fully ready + _ = agenttest.New(t, client.URL, agentToken) + + // Wait for workspace agents to be ready + coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait() + + deps, err := toolsdk.NewDeps(client) + require.NoError(t, err) + + args := toolsdk.WorkspaceBashArgs{ + Workspace: workspace.Name, + Command: `echo "started" && sleep 60 && echo "completed"`, // Command that would take 60+ seconds + Background: true, // Run in background + TimeoutMs: 2000, // 2 second timeout + } + + result, err := testTool(t, toolsdk.WorkspaceBash, deps, args) + + // Should not error + require.NoError(t, err) + + t.Logf("Background result: exitCode=%d, output=%q", result.ExitCode, result.Output) + + // Should have exit code 124 (timeout) since command times out + require.Equal(t, 124, result.ExitCode) + + // Should capture output up to timeout point + require.Contains(t, result.Output, "started", "Should contain output captured before timeout") + + // Should NOT contain the second echo (it never executed due to timeout) + require.NotContains(t, result.Output, "completed", "Should not contain output after timeout") + + // Should contain background continuation message + require.Contains(t, result.Output, "Command continues running in background") + }) + + t.Run("BackgroundVsNormalExecution", func(t *testing.T) { + t.Parallel() + + client, workspace, agentToken := setupWorkspaceForAgent(t) + + // Start the agent and wait for it to be fully ready + _ = agenttest.New(t, client.URL, agentToken) + + // Wait for workspace agents to be ready + coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait() + + deps, err := toolsdk.NewDeps(client) + require.NoError(t, err) + + // First run the same command in normal mode + normalArgs := toolsdk.WorkspaceBashArgs{ + Workspace: workspace.Name, + Command: `echo "hello world"`, + Background: false, + } + + normalResult, err := toolsdk.WorkspaceBash.Handler(t.Context(), deps, normalArgs) + require.NoError(t, err) + + // Normal mode should return the actual output + require.Equal(t, 0, normalResult.ExitCode) + require.Equal(t, "hello world", normalResult.Output) + + // Now run the same command in background mode + backgroundArgs := toolsdk.WorkspaceBashArgs{ + Workspace: workspace.Name, + Command: `echo "hello world"`, + Background: true, + } + + backgroundResult, err := testTool(t, toolsdk.WorkspaceBash, deps, backgroundArgs) + require.NoError(t, err) + + t.Logf("Normal result: %q", normalResult.Output) + t.Logf("Background result: %q", backgroundResult.Output) + + // Background mode should also return the actual output since command completes quickly + require.Equal(t, 0, backgroundResult.ExitCode) + require.Equal(t, "hello world", backgroundResult.Output) + }) + + t.Run("BackgroundCommandContinuesAfterTimeout", func(t *testing.T) { + t.Parallel() + + client, workspace, agentToken := setupWorkspaceForAgent(t) + + // Start the agent and wait for it to be fully ready + _ = agenttest.New(t, client.URL, agentToken) + + // Wait for workspace agents to be ready + coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait() + + deps, err := toolsdk.NewDeps(client) + require.NoError(t, err) + + args := toolsdk.WorkspaceBashArgs{ + Workspace: workspace.Name, + Command: `echo "started" && sleep 4 && echo "done" > /tmp/bg-test-done`, // Command that will timeout but continue + TimeoutMs: 2000, // 2000ms timeout (shorter than command duration) + Background: true, // Run in background + } + + result, err := testTool(t, toolsdk.WorkspaceBash, deps, args) + + // Should not error but should timeout + require.NoError(t, err) + + t.Logf("Background with timeout result: exitCode=%d, output=%q", result.ExitCode, result.Output) + + // Should have timeout exit code + require.Equal(t, 124, result.ExitCode) + + // Should capture output before timeout + require.Contains(t, result.Output, "started", "Should contain output captured before timeout") + + // Should contain background continuation message + require.Contains(t, result.Output, "Command continues running in background") + + // Wait for the background command to complete (even though SSH session timed out) + require.Eventually(t, func() bool { + checkArgs := toolsdk.WorkspaceBashArgs{ + Workspace: workspace.Name, + Command: `cat /tmp/bg-test-done 2>/dev/null || echo "not found"`, + } + checkResult, err := toolsdk.WorkspaceBash.Handler(t.Context(), deps, checkArgs) + return err == nil && checkResult.Output == "done" + }, testutil.WaitMedium, testutil.IntervalMedium, "Background command should continue running and complete after timeout") + }) +} diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index c201190bd3456..13e475c80609a 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -456,7 +456,7 @@ var testedTools sync.Map // This is to mimic how we expect external callers to use the tool. func testTool[Arg, Ret any](t *testing.T, tool toolsdk.Tool[Arg, Ret], tb toolsdk.Deps, args Arg) (Ret, error) { t.Helper() - defer func() { testedTools.Store(tool.Tool.Name, true) }() + defer func() { testedTools.Store(tool.Name, true) }() toolArgs, err := json.Marshal(args) require.NoError(t, err, "failed to marshal args") result, err := tool.Generic().Handler(t.Context(), tb, toolArgs) @@ -625,23 +625,23 @@ func TestToolSchemaFields(t *testing.T) { // Test that all tools have the required Schema fields (Properties and Required) for _, tool := range toolsdk.All { - t.Run(tool.Tool.Name, func(t *testing.T) { + t.Run(tool.Name, func(t *testing.T) { t.Parallel() // Check that Properties is not nil - require.NotNil(t, tool.Tool.Schema.Properties, - "Tool %q missing Schema.Properties", tool.Tool.Name) + require.NotNil(t, tool.Schema.Properties, + "Tool %q missing Schema.Properties", tool.Name) // Check that Required is not nil - require.NotNil(t, tool.Tool.Schema.Required, - "Tool %q missing Schema.Required", tool.Tool.Name) + require.NotNil(t, tool.Schema.Required, + "Tool %q missing Schema.Required", tool.Name) // Ensure Properties has entries for all required fields - for _, requiredField := range tool.Tool.Schema.Required { - _, exists := tool.Tool.Schema.Properties[requiredField] + for _, requiredField := range tool.Schema.Required { + _, exists := tool.Schema.Properties[requiredField] require.True(t, exists, "Tool %q requires field %q but it is not defined in Properties", - tool.Tool.Name, requiredField) + tool.Name, requiredField) } }) } @@ -652,7 +652,7 @@ func TestToolSchemaFields(t *testing.T) { func TestMain(m *testing.M) { // Initialize testedTools for _, tool := range toolsdk.All { - testedTools.Store(tool.Tool.Name, false) + testedTools.Store(tool.Name, false) } code := m.Run() @@ -660,8 +660,8 @@ func TestMain(m *testing.M) { // Ensure all tools have been tested var untested []string for _, tool := range toolsdk.All { - if tested, ok := testedTools.Load(tool.Tool.Name); !ok || !tested.(bool) { - untested = append(untested, tool.Tool.Name) + if tested, ok := testedTools.Load(tool.Name); !ok || !tested.(bool) { + untested = append(untested, tool.Name) } }
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: