Skip to content

Commit 9f5ad23

Browse files
authored
refactor(agent/agentssh): move parsing of magic session and create type (#16630)
This change refactors the parsing of MagicSessionEnvs in the agentssh package and moves the logic to an earlier stage. Also intoduces enums for MagicSessionType. Refs #15139
1 parent 570e42b commit 9f5ad23

File tree

3 files changed

+92
-56
lines changed

3 files changed

+92
-56
lines changed

agent/agent_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
138138
defer sshClient.Close()
139139
session, err := sshClient.NewSession()
140140
require.NoError(t, err)
141-
session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, agentssh.MagicSessionTypeVSCode)
141+
session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, string(agentssh.MagicSessionTypeVSCode))
142142
defer session.Close()
143143

144144
command := "sh -c 'echo $" + agentssh.MagicSessionTypeEnvironmentVariable + "'"
@@ -165,7 +165,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
165165
defer sshClient.Close()
166166
session, err := sshClient.NewSession()
167167
require.NoError(t, err)
168-
session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, agentssh.MagicSessionTypeVSCode)
168+
session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, string(agentssh.MagicSessionTypeVSCode))
169169
defer session.Close()
170170
stdin, err := session.StdinPipe()
171171
require.NoError(t, err)

agent/agentssh/agentssh.go

Lines changed: 85 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/spf13/afero"
2727
"go.uber.org/atomic"
2828
gossh "golang.org/x/crypto/ssh"
29+
"golang.org/x/exp/slices"
2930
"golang.org/x/xerrors"
3031

3132
"cdr.dev/slog"
@@ -42,14 +43,6 @@ const (
4243
// unlikely to shadow other exit codes, which are typically 1, 2, 3, etc.
4344
MagicSessionErrorCode = 229
4445

45-
// MagicSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection.
46-
// This is stripped from any commands being executed, and is counted towards connection stats.
47-
MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE"
48-
// MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself.
49-
MagicSessionTypeVSCode = "vscode"
50-
// MagicSessionTypeJetBrains is set in the SSH config by the JetBrains
51-
// extension to identify itself.
52-
MagicSessionTypeJetBrains = "jetbrains"
5346
// MagicProcessCmdlineJetBrains is a string in a process's command line that
5447
// uniquely identifies it as JetBrains software.
5548
MagicProcessCmdlineJetBrains = "idea.vendor.name=JetBrains"
@@ -60,6 +53,29 @@ const (
6053
BlockedFileTransferErrorMessage = "File transfer has been disabled."
6154
)
6255

56+
// MagicSessionType is a type that represents the type of session that is being
57+
// established.
58+
type MagicSessionType string
59+
60+
const (
61+
// MagicSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection.
62+
// This is stripped from any commands being executed, and is counted towards connection stats.
63+
MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE"
64+
)
65+
66+
// MagicSessionType enums.
67+
const (
68+
// MagicSessionTypeUnknown means the session type could not be determined.
69+
MagicSessionTypeUnknown MagicSessionType = "unknown"
70+
// MagicSessionTypeSSH is the default session type.
71+
MagicSessionTypeSSH MagicSessionType = "ssh"
72+
// MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself.
73+
MagicSessionTypeVSCode MagicSessionType = "vscode"
74+
// MagicSessionTypeJetBrains is set in the SSH config by the JetBrains
75+
// extension to identify itself.
76+
MagicSessionTypeJetBrains MagicSessionType = "jetbrains"
77+
)
78+
6379
// BlockedFileTransferCommands contains a list of restricted file transfer commands.
6480
var BlockedFileTransferCommands = []string{"nc", "rsync", "scp", "sftp"}
6581

@@ -255,14 +271,42 @@ func (s *Server) ConnStats() ConnStats {
255271
}
256272
}
257273

274+
func extractMagicSessionType(env []string) (magicType MagicSessionType, rawType string, filteredEnv []string) {
275+
for _, kv := range env {
276+
if !strings.HasPrefix(kv, MagicSessionTypeEnvironmentVariable) {
277+
continue
278+
}
279+
280+
rawType = strings.TrimPrefix(kv, MagicSessionTypeEnvironmentVariable+"=")
281+
// Keep going, we'll use the last instance of the env.
282+
}
283+
284+
// Always force lowercase checking to be case-insensitive.
285+
switch MagicSessionType(strings.ToLower(rawType)) {
286+
case MagicSessionTypeVSCode:
287+
magicType = MagicSessionTypeVSCode
288+
case MagicSessionTypeJetBrains:
289+
magicType = MagicSessionTypeJetBrains
290+
case "", MagicSessionTypeSSH:
291+
magicType = MagicSessionTypeSSH
292+
default:
293+
magicType = MagicSessionTypeUnknown
294+
}
295+
296+
return magicType, rawType, slices.DeleteFunc(env, func(kv string) bool {
297+
return strings.HasPrefix(kv, MagicSessionTypeEnvironmentVariable+"=")
298+
})
299+
}
300+
258301
func (s *Server) sessionHandler(session ssh.Session) {
259302
ctx := session.Context()
303+
id := uuid.New()
260304
logger := s.logger.With(
261305
slog.F("remote_addr", session.RemoteAddr()),
262306
slog.F("local_addr", session.LocalAddr()),
263307
// Assigning a random uuid for each session is useful for tracking
264308
// logs for the same ssh session.
265-
slog.F("id", uuid.NewString()),
309+
slog.F("id", id.String()),
266310
)
267311
logger.Info(ctx, "handling ssh session")
268312

@@ -274,16 +318,21 @@ func (s *Server) sessionHandler(session ssh.Session) {
274318
}
275319
defer s.trackSession(session, false)
276320

277-
extraEnv := make([]string, 0)
278-
x11, hasX11 := session.X11()
279-
if hasX11 {
280-
display, handled := s.x11Handler(session.Context(), x11)
281-
if !handled {
282-
_ = session.Exit(1)
283-
logger.Error(ctx, "x11 handler failed")
284-
return
285-
}
286-
extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber))
321+
env := session.Environ()
322+
magicType, magicTypeRaw, env := extractMagicSessionType(env)
323+
324+
switch magicType {
325+
case MagicSessionTypeVSCode:
326+
s.connCountVSCode.Add(1)
327+
defer s.connCountVSCode.Add(-1)
328+
case MagicSessionTypeJetBrains:
329+
// Do nothing here because JetBrains launches hundreds of ssh sessions.
330+
// We instead track JetBrains in the single persistent tcp forwarding channel.
331+
case MagicSessionTypeSSH:
332+
s.connCountSSHSession.Add(1)
333+
defer s.connCountSSHSession.Add(-1)
334+
case MagicSessionTypeUnknown:
335+
logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("raw_type", magicTypeRaw))
287336
}
288337

289338
if s.fileTransferBlocked(session) {
@@ -309,7 +358,18 @@ func (s *Server) sessionHandler(session ssh.Session) {
309358
return
310359
}
311360

312-
err := s.sessionStart(logger, session, extraEnv)
361+
x11, hasX11 := session.X11()
362+
if hasX11 {
363+
display, handled := s.x11Handler(session.Context(), x11)
364+
if !handled {
365+
_ = session.Exit(1)
366+
logger.Error(ctx, "x11 handler failed")
367+
return
368+
}
369+
env = append(env, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber))
370+
}
371+
372+
err := s.sessionStart(logger, session, env, magicType)
313373
var exitError *exec.ExitError
314374
if xerrors.As(err, &exitError) {
315375
code := exitError.ExitCode()
@@ -379,32 +439,8 @@ func (s *Server) fileTransferBlocked(session ssh.Session) bool {
379439
return false
380440
}
381441

382-
func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv []string) (retErr error) {
442+
func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, env []string, magicType MagicSessionType) (retErr error) {
383443
ctx := session.Context()
384-
env := append(session.Environ(), extraEnv...)
385-
var magicType string
386-
for index, kv := range env {
387-
if !strings.HasPrefix(kv, MagicSessionTypeEnvironmentVariable) {
388-
continue
389-
}
390-
magicType = strings.ToLower(strings.TrimPrefix(kv, MagicSessionTypeEnvironmentVariable+"="))
391-
env = append(env[:index], env[index+1:]...)
392-
}
393-
394-
// Always force lowercase checking to be case-insensitive.
395-
switch magicType {
396-
case MagicSessionTypeVSCode:
397-
s.connCountVSCode.Add(1)
398-
defer s.connCountVSCode.Add(-1)
399-
case MagicSessionTypeJetBrains:
400-
// Do nothing here because JetBrains launches hundreds of ssh sessions.
401-
// We instead track JetBrains in the single persistent tcp forwarding channel.
402-
case "":
403-
s.connCountSSHSession.Add(1)
404-
defer s.connCountSSHSession.Add(-1)
405-
default:
406-
logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType))
407-
}
408444

409445
magicTypeLabel := magicTypeMetricLabel(magicType)
410446
sshPty, windowSize, isPty := session.Pty()
@@ -473,7 +509,7 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
473509
}()
474510
go func() {
475511
for sig := range sigs {
476-
s.handleSignal(logger, sig, cmd.Process, magicTypeLabel)
512+
handleSignal(logger, sig, cmd.Process, s.metrics, magicTypeLabel)
477513
}
478514
}()
479515
return cmd.Wait()
@@ -558,7 +594,7 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
558594
sigs = nil
559595
continue
560596
}
561-
s.handleSignal(logger, sig, process, magicTypeLabel)
597+
handleSignal(logger, sig, process, s.metrics, magicTypeLabel)
562598
case win, ok := <-windowSize:
563599
if !ok {
564600
windowSize = nil
@@ -612,15 +648,15 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
612648
return nil
613649
}
614650

615-
func (s *Server) handleSignal(logger slog.Logger, ssig ssh.Signal, signaler interface{ Signal(os.Signal) error }, magicTypeLabel string) {
651+
func handleSignal(logger slog.Logger, ssig ssh.Signal, signaler interface{ Signal(os.Signal) error }, metrics *sshServerMetrics, magicTypeLabel string) {
616652
ctx := context.Background()
617653
sig := osSignalFrom(ssig)
618654
logger = logger.With(slog.F("ssh_signal", ssig), slog.F("signal", sig.String()))
619655
logger.Info(ctx, "received signal from client")
620656
err := signaler.Signal(sig)
621657
if err != nil {
622658
logger.Warn(ctx, "signaling the process failed", slog.Error(err))
623-
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "signal").Add(1)
659+
metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "signal").Add(1)
624660
}
625661
}
626662

agent/agentssh/metrics.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,15 @@ func newSSHServerMetrics(registerer prometheus.Registerer) *sshServerMetrics {
7171
}
7272
}
7373

74-
func magicTypeMetricLabel(magicType string) string {
74+
func magicTypeMetricLabel(magicType MagicSessionType) string {
7575
switch magicType {
7676
case MagicSessionTypeVSCode:
7777
case MagicSessionTypeJetBrains:
78-
case "":
79-
magicType = "ssh"
78+
case MagicSessionTypeSSH:
79+
case MagicSessionTypeUnknown:
8080
default:
81-
magicType = "unknown"
81+
magicType = MagicSessionTypeUnknown
8282
}
8383
// Always be case insensitive
84-
return strings.ToLower(magicType)
84+
return strings.ToLower(string(magicType))
8585
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy