diff --git a/cli/agent.go b/cli/agent.go index 5465aeedd9302..073581bd950cb 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -50,6 +50,8 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { slogJSONPath string slogStackdriverPath string blockFileTransfer bool + agentHeaderCommand string + agentHeader []string ) cmd := &serpent.Command{ Use: "agent", @@ -176,6 +178,14 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { // with large payloads can take a bit. e.g. startup scripts // may take a while to insert. client.SDK.HTTPClient.Timeout = 30 * time.Second + // Attach header transport so we process --agent-header and + // --agent-header-command flags + headerTransport, err := headerTransport(ctx, r.agentURL, agentHeader, agentHeaderCommand) + if err != nil { + return xerrors.Errorf("configure header transport: %w", err) + } + headerTransport.Transport = client.SDK.HTTPClient.Transport + client.SDK.HTTPClient.Transport = headerTransport // Enable pprof handler // This prevents the pprof import from being accidentally deleted. @@ -361,6 +371,18 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { Value: serpent.StringOf(&pprofAddress), Description: "The address to serve pprof.", }, + { + Flag: "agent-header-command", + Env: "CODER_AGENT_HEADER_COMMAND", + Value: serpent.StringOf(&agentHeaderCommand), + Description: "An external command that outputs additional HTTP headers added to all requests. The command must output each header as `key=value` on its own line.", + }, + { + Flag: "agent-header", + Env: "CODER_AGENT_HEADER", + Value: serpent.StringArrayOf(&agentHeader), + Description: "Additional HTTP headers added to all requests. Provide as " + `key=value` + ". Can be specified multiple times.", + }, { Flag: "no-reap", diff --git a/cli/agent_test.go b/cli/agent_test.go index 9571bf03e1a09..f30d12b012d88 100644 --- a/cli/agent_test.go +++ b/cli/agent_test.go @@ -3,10 +3,13 @@ package cli_test import ( "context" "fmt" + "net/http" + "net/http/httptest" "os" "path/filepath" "runtime" "strings" + "sync/atomic" "testing" "github.com/google/uuid" @@ -229,6 +232,43 @@ func TestWorkspaceAgent(t *testing.T) { require.Equal(t, codersdk.AgentSubsystemEnvbox, resources[0].Agents[0].Subsystems[0]) require.Equal(t, codersdk.AgentSubsystemExectrace, resources[0].Agents[0].Subsystems[1]) }) + t.Run("Header", func(t *testing.T) { + t.Parallel() + + var url string + var called int64 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "wow", r.Header.Get("X-Testing")) + assert.Equal(t, "Ethan was Here!", r.Header.Get("Cool-Header")) + assert.Equal(t, "very-wow-"+url, r.Header.Get("X-Process-Testing")) + assert.Equal(t, "more-wow", r.Header.Get("X-Process-Testing2")) + atomic.AddInt64(&called, 1) + w.WriteHeader(http.StatusGone) + })) + defer srv.Close() + url = srv.URL + coderURLEnv := "$CODER_URL" + if runtime.GOOS == "windows" { + coderURLEnv = "%CODER_URL%" + } + + logDir := t.TempDir() + inv, _ := clitest.New(t, + "agent", + "--auth", "token", + "--agent-token", "fake-token", + "--agent-url", srv.URL, + "--log-dir", logDir, + "--agent-header", "X-Testing=wow", + "--agent-header", "Cool-Header=Ethan was Here!", + "--agent-header-command", "printf X-Process-Testing=very-wow-"+coderURLEnv+"'\\r\\n'X-Process-Testing2=more-wow", + ) + + clitest.Start(t, inv) + require.Eventually(t, func() bool { + return atomic.LoadInt64(&called) > 0 + }, testutil.WaitShort, testutil.IntervalFast) + }) } func matchAgentWithVersion(rs []codersdk.WorkspaceResource) bool { diff --git a/cli/root.go b/cli/root.go index fdebdc74bedde..da7e48f2feae4 100644 --- a/cli/root.go +++ b/cli/root.go @@ -550,44 +550,7 @@ func (r *RootCmd) InitClient(client *codersdk.Client) serpent.MiddlewareFunc { // HeaderTransport creates a new transport that executes `--header-command` // if it is set to add headers for all outbound requests. func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*codersdk.HeaderTransport, error) { - transport := &codersdk.HeaderTransport{ - Transport: http.DefaultTransport, - Header: http.Header{}, - } - headers := r.header - if r.headerCommand != "" { - shell := "sh" - caller := "-c" - if runtime.GOOS == "windows" { - shell = "cmd.exe" - caller = "/c" - } - var outBuf bytes.Buffer - // #nosec - cmd := exec.CommandContext(ctx, shell, caller, r.headerCommand) - cmd.Env = append(os.Environ(), "CODER_URL="+serverURL.String()) - cmd.Stdout = &outBuf - cmd.Stderr = io.Discard - err := cmd.Run() - if err != nil { - return nil, xerrors.Errorf("failed to run %v: %w", cmd.Args, err) - } - scanner := bufio.NewScanner(&outBuf) - for scanner.Scan() { - headers = append(headers, scanner.Text()) - } - if err := scanner.Err(); err != nil { - return nil, xerrors.Errorf("scan %v: %w", cmd.Args, err) - } - } - for _, header := range headers { - parts := strings.SplitN(header, "=", 2) - if len(parts) < 2 { - return nil, xerrors.Errorf("split header %q had less than two parts", header) - } - transport.Header.Add(parts[0], parts[1]) - } - return transport, nil + return headerTransport(ctx, serverURL, r.header, r.headerCommand) } func (r *RootCmd) configureClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL, inv *serpent.Invocation) error { @@ -1273,3 +1236,46 @@ type roundTripper func(req *http.Request) (*http.Response, error) func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return r(req) } + +// HeaderTransport creates a new transport that executes `--header-command` +// if it is set to add headers for all outbound requests. +func headerTransport(ctx context.Context, serverURL *url.URL, header []string, headerCommand string) (*codersdk.HeaderTransport, error) { + transport := &codersdk.HeaderTransport{ + Transport: http.DefaultTransport, + Header: http.Header{}, + } + headers := header + if headerCommand != "" { + shell := "sh" + caller := "-c" + if runtime.GOOS == "windows" { + shell = "cmd.exe" + caller = "/c" + } + var outBuf bytes.Buffer + // #nosec + cmd := exec.CommandContext(ctx, shell, caller, headerCommand) + cmd.Env = append(os.Environ(), "CODER_URL="+serverURL.String()) + cmd.Stdout = &outBuf + cmd.Stderr = io.Discard + err := cmd.Run() + if err != nil { + return nil, xerrors.Errorf("failed to run %v: %w", cmd.Args, err) + } + scanner := bufio.NewScanner(&outBuf) + for scanner.Scan() { + headers = append(headers, scanner.Text()) + } + if err := scanner.Err(); err != nil { + return nil, xerrors.Errorf("scan %v: %w", cmd.Args, err) + } + } + for _, header := range headers { + parts := strings.SplitN(header, "=", 2) + if len(parts) < 2 { + return nil, xerrors.Errorf("split header %q had less than two parts", header) + } + transport.Header.Add(parts[0], parts[1]) + } + return transport, nil +} diff --git a/cli/testdata/coder_agent_--help.golden b/cli/testdata/coder_agent_--help.golden index d6982fda18e7c..3394b43a9e900 100644 --- a/cli/testdata/coder_agent_--help.golden +++ b/cli/testdata/coder_agent_--help.golden @@ -15,6 +15,15 @@ OPTIONS: --log-stackdriver string, $CODER_AGENT_LOGGING_STACKDRIVER Output Stackdriver compatible logs to a given file. + --agent-header string-array, $CODER_AGENT_HEADER + Additional HTTP headers added to all requests. Provide as key=value. + Can be specified multiple times. + + --agent-header-command string, $CODER_AGENT_HEADER_COMMAND + An external command that outputs additional HTTP headers added to all + requests. The command must output each header as `key=value` on its + own line. + --auth string, $CODER_AGENT_AUTH (default: token) Specify the authentication type to use for the agent.
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: