Skip to content

Commit 4edd77b

Browse files
authored
chore(agent/agentssh): extract CreateCommandDeps (#16603)
Extracts environment-level dependencies of `agentssh.Server.CreateCommand()` to an interface to allow alternative implementations to be passed in.
1 parent 52cc0ce commit 4edd77b

File tree

5 files changed

+91
-11
lines changed

5 files changed

+91
-11
lines changed

agent/agent.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentM
340340
// if it can guarantee the clocks are synchronized.
341341
CollectedAt: now,
342342
}
343-
cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
343+
cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil, nil)
344344
if err != nil {
345345
result.Error = fmt.Sprintf("create cmd: %+v", err)
346346
return result

agent/agentscripts/agentscripts.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript,
283283
cmdCtx, ctxCancel = context.WithTimeout(ctx, script.Timeout)
284284
defer ctxCancel()
285285
}
286-
cmdPty, err := r.SSHServer.CreateCommand(cmdCtx, script.Script, nil)
286+
cmdPty, err := r.SSHServer.CreateCommand(cmdCtx, script.Script, nil, nil)
287287
if err != nil {
288288
return xerrors.Errorf("%s script: create command: %w", logPath, err)
289289
}

agent/agentssh/agentssh.go

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv
409409
magicTypeLabel := magicTypeMetricLabel(magicType)
410410
sshPty, windowSize, isPty := session.Pty()
411411

412-
cmd, err := s.CreateCommand(ctx, session.RawCommand(), env)
412+
cmd, err := s.CreateCommand(ctx, session.RawCommand(), env, nil)
413413
if err != nil {
414414
ptyLabel := "no"
415415
if isPty {
@@ -670,17 +670,63 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
670670
_ = session.Exit(1)
671671
}
672672

673+
// EnvInfoer encapsulates external information required by CreateCommand.
674+
type EnvInfoer interface {
675+
// CurrentUser returns the current user.
676+
CurrentUser() (*user.User, error)
677+
// Environ returns the environment variables of the current process.
678+
Environ() []string
679+
// UserHomeDir returns the home directory of the current user.
680+
UserHomeDir() (string, error)
681+
// UserShell returns the shell of the given user.
682+
UserShell(username string) (string, error)
683+
}
684+
685+
type systemEnvInfoer struct{}
686+
687+
var defaultEnvInfoer EnvInfoer = &systemEnvInfoer{}
688+
689+
// DefaultEnvInfoer returns a default implementation of
690+
// EnvInfoer. This reads information using the default Go
691+
// implementations.
692+
func DefaultEnvInfoer() EnvInfoer {
693+
return defaultEnvInfoer
694+
}
695+
696+
func (systemEnvInfoer) CurrentUser() (*user.User, error) {
697+
return user.Current()
698+
}
699+
700+
func (systemEnvInfoer) Environ() []string {
701+
return os.Environ()
702+
}
703+
704+
func (systemEnvInfoer) UserHomeDir() (string, error) {
705+
return userHomeDir()
706+
}
707+
708+
func (systemEnvInfoer) UserShell(username string) (string, error) {
709+
return usershell.Get(username)
710+
}
711+
673712
// CreateCommand processes raw command input with OpenSSH-like behavior.
674713
// If the script provided is empty, it will default to the users shell.
675714
// This injects environment variables specified by the user at launch too.
676-
func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*pty.Cmd, error) {
677-
currentUser, err := user.Current()
715+
// The final argument is an interface that allows the caller to provide
716+
// alternative implementations for the dependencies of CreateCommand.
717+
// This is useful when creating a command to be run in a separate environment
718+
// (for example, a Docker container). Pass in nil to use the default.
719+
func (s *Server) CreateCommand(ctx context.Context, script string, env []string, deps EnvInfoer) (*pty.Cmd, error) {
720+
if deps == nil {
721+
deps = DefaultEnvInfoer()
722+
}
723+
currentUser, err := deps.CurrentUser()
678724
if err != nil {
679725
return nil, xerrors.Errorf("get current user: %w", err)
680726
}
681727
username := currentUser.Username
682728

683-
shell, err := usershell.Get(username)
729+
shell, err := deps.UserShell(username)
684730
if err != nil {
685731
return nil, xerrors.Errorf("get user shell: %w", err)
686732
}
@@ -736,13 +782,13 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
736782
_, err = os.Stat(cmd.Dir)
737783
if cmd.Dir == "" || err != nil {
738784
// Default to user home if a directory is not set.
739-
homedir, err := userHomeDir()
785+
homedir, err := deps.UserHomeDir()
740786
if err != nil {
741787
return nil, xerrors.Errorf("get home dir: %w", err)
742788
}
743789
cmd.Dir = homedir
744790
}
745-
cmd.Env = append(os.Environ(), env...)
791+
cmd.Env = append(deps.Environ(), env...)
746792
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username))
747793

748794
// Set SSH connection environment variables (these are also set by OpenSSH

agent/agentssh/agentssh_test.go

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"context"
99
"fmt"
1010
"net"
11+
"os/user"
1112
"runtime"
1213
"strings"
1314
"sync"
@@ -87,7 +88,7 @@ func TestNewServer_ExecuteShebang(t *testing.T) {
8788
t.Run("Basic", func(t *testing.T) {
8889
t.Parallel()
8990
cmd, err := s.CreateCommand(ctx, `#!/bin/bash
90-
echo test`, nil)
91+
echo test`, nil, nil)
9192
require.NoError(t, err)
9293
output, err := cmd.AsExec().CombinedOutput()
9394
require.NoError(t, err)
@@ -96,12 +97,45 @@ func TestNewServer_ExecuteShebang(t *testing.T) {
9697
t.Run("Args", func(t *testing.T) {
9798
t.Parallel()
9899
cmd, err := s.CreateCommand(ctx, `#!/usr/bin/env bash
99-
echo test`, nil)
100+
echo test`, nil, nil)
100101
require.NoError(t, err)
101102
output, err := cmd.AsExec().CombinedOutput()
102103
require.NoError(t, err)
103104
require.Equal(t, "test\n", string(output))
104105
})
106+
t.Run("CustomEnvInfoer", func(t *testing.T) {
107+
t.Parallel()
108+
ei := &fakeEnvInfoer{
109+
CurrentUserFn: func() (u *user.User, err error) {
110+
return nil, assert.AnError
111+
},
112+
}
113+
_, err := s.CreateCommand(ctx, `whatever`, nil, ei)
114+
require.ErrorIs(t, err, assert.AnError)
115+
})
116+
}
117+
118+
type fakeEnvInfoer struct {
119+
CurrentUserFn func() (*user.User, error)
120+
EnvironFn func() []string
121+
UserHomeDirFn func() (string, error)
122+
UserShellFn func(string) (string, error)
123+
}
124+
125+
func (f *fakeEnvInfoer) CurrentUser() (u *user.User, err error) {
126+
return f.CurrentUserFn()
127+
}
128+
129+
func (f *fakeEnvInfoer) Environ() []string {
130+
return f.EnvironFn()
131+
}
132+
133+
func (f *fakeEnvInfoer) UserHomeDir() (string, error) {
134+
return f.UserHomeDirFn()
135+
}
136+
137+
func (f *fakeEnvInfoer) UserShell(u string) (string, error) {
138+
return f.UserShellFn(u)
105139
}
106140

107141
func TestNewServer_CloseActiveConnections(t *testing.T) {

agent/reconnectingpty/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Co
159159
}()
160160

161161
// Empty command will default to the users shell!
162-
cmd, err := s.commandCreator.CreateCommand(ctx, msg.Command, nil)
162+
cmd, err := s.commandCreator.CreateCommand(ctx, msg.Command, nil, nil)
163163
if err != nil {
164164
s.errorsTotal.WithLabelValues("create_command").Add(1)
165165
return xerrors.Errorf("create command: %w", err)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy