Skip to content

Commit 8f07d33

Browse files
authored
feat(agent/agentssh): use tcp for X11 forwarding (#14560)
Fixes #14198
1 parent e6d8f67 commit 8f07d33

File tree

3 files changed

+129
-67
lines changed

3 files changed

+129
-67
lines changed

agent/agentssh/agentssh.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ type Config struct {
7979
// where users will land when they connect via SSH. Default is the home
8080
// directory of the user.
8181
WorkingDirectory func() string
82-
// X11SocketDir is the directory where X11 sockets are created. Default is
83-
// /tmp/.X11-unix.
84-
X11SocketDir string
82+
// X11DisplayOffset is the offset to add to the X11 display number.
83+
// Default is 10.
84+
X11DisplayOffset *int
8585
// BlockFileTransfer restricts use of file transfer applications.
8686
BlockFileTransfer bool
8787
}
@@ -124,8 +124,9 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
124124
if config == nil {
125125
config = &Config{}
126126
}
127-
if config.X11SocketDir == "" {
128-
config.X11SocketDir = filepath.Join(os.TempDir(), ".X11-unix")
127+
if config.X11DisplayOffset == nil {
128+
offset := X11DefaultDisplayOffset
129+
config.X11DisplayOffset = &offset
129130
}
130131
if config.UpdateEnv == nil {
131132
config.UpdateEnv = func(current []string) ([]string, error) { return current, nil }
@@ -273,13 +274,13 @@ func (s *Server) sessionHandler(session ssh.Session) {
273274
extraEnv := make([]string, 0)
274275
x11, hasX11 := session.X11()
275276
if hasX11 {
276-
handled := s.x11Handler(session.Context(), x11)
277+
display, handled := s.x11Handler(session.Context(), x11)
277278
if !handled {
278279
_ = session.Exit(1)
279280
logger.Error(ctx, "x11 handler failed")
280281
return
281282
}
282-
extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=:%d.0", x11.ScreenNumber))
283+
extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber))
283284
}
284285

285286
if s.fileTransferBlocked(session) {

agent/agentssh/x11.go

Lines changed: 88 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"errors"
88
"fmt"
99
"io"
10+
"math"
1011
"net"
1112
"os"
1213
"path/filepath"
@@ -22,102 +23,136 @@ import (
2223
"cdr.dev/slog"
2324
)
2425

26+
const (
27+
// X11StartPort is the starting port for X11 forwarding, this is the
28+
// port used for "DISPLAY=localhost:0".
29+
X11StartPort = 6000
30+
// X11DefaultDisplayOffset is the default offset for X11 forwarding.
31+
X11DefaultDisplayOffset = 10
32+
)
33+
2534
// x11Callback is called when the client requests X11 forwarding.
26-
// It adds an Xauthority entry to the Xauthority file.
27-
func (s *Server) x11Callback(ctx ssh.Context, x11 ssh.X11) bool {
35+
func (*Server) x11Callback(_ ssh.Context, _ ssh.X11) bool {
36+
// Always allow.
37+
return true
38+
}
39+
40+
// x11Handler is called when a session has requested X11 forwarding.
41+
// It listens for X11 connections and forwards them to the client.
42+
func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) (displayNumber int, handled bool) {
43+
serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
44+
if !valid {
45+
s.logger.Warn(ctx, "failed to get server connection")
46+
return -1, false
47+
}
48+
2849
hostname, err := os.Hostname()
2950
if err != nil {
3051
s.logger.Warn(ctx, "failed to get hostname", slog.Error(err))
3152
s.metrics.x11HandlerErrors.WithLabelValues("hostname").Add(1)
32-
return false
53+
return -1, false
3354
}
3455

35-
err = s.fs.MkdirAll(s.config.X11SocketDir, 0o700)
56+
ln, display, err := createX11Listener(ctx, *s.config.X11DisplayOffset)
3657
if err != nil {
37-
s.logger.Warn(ctx, "failed to make the x11 socket dir", slog.F("dir", s.config.X11SocketDir), slog.Error(err))
38-
s.metrics.x11HandlerErrors.WithLabelValues("socker_dir").Add(1)
39-
return false
40-
}
58+
s.logger.Warn(ctx, "failed to create X11 listener", slog.Error(err))
59+
s.metrics.x11HandlerErrors.WithLabelValues("listen").Add(1)
60+
return -1, false
61+
}
62+
s.trackListener(ln, true)
63+
defer func() {
64+
if !handled {
65+
s.trackListener(ln, false)
66+
_ = ln.Close()
67+
}
68+
}()
4169

42-
err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(int(x11.ScreenNumber)), x11.AuthProtocol, x11.AuthCookie)
70+
err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(display), x11.AuthProtocol, x11.AuthCookie)
4371
if err != nil {
4472
s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err))
4573
s.metrics.x11HandlerErrors.WithLabelValues("xauthority").Add(1)
46-
return false
74+
return -1, false
4775
}
48-
return true
49-
}
5076

51-
// x11Handler is called when a session has requested X11 forwarding.
52-
// It listens for X11 connections and forwards them to the client.
53-
func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
54-
serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
55-
if !valid {
56-
s.logger.Warn(ctx, "failed to get server connection")
57-
return false
58-
}
59-
// We want to overwrite the socket so that subsequent connections will succeed.
60-
socketPath := filepath.Join(s.config.X11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber))
61-
err := os.Remove(socketPath)
62-
if err != nil && !errors.Is(err, os.ErrNotExist) {
63-
s.logger.Warn(ctx, "failed to remove existing X11 socket", slog.Error(err))
64-
return false
65-
}
66-
listener, err := net.Listen("unix", socketPath)
67-
if err != nil {
68-
s.logger.Warn(ctx, "failed to listen for X11", slog.Error(err))
69-
return false
70-
}
71-
s.trackListener(listener, true)
77+
go func() {
78+
// Don't leave the listener open after the session is gone.
79+
<-ctx.Done()
80+
_ = ln.Close()
81+
}()
7282

7383
go func() {
74-
defer listener.Close()
75-
defer s.trackListener(listener, false)
76-
handledFirstConnection := false
84+
defer ln.Close()
85+
defer s.trackListener(ln, false)
7786

7887
for {
79-
conn, err := listener.Accept()
88+
conn, err := ln.Accept()
8089
if err != nil {
8190
if errors.Is(err, net.ErrClosed) {
8291
return
8392
}
8493
s.logger.Warn(ctx, "failed to accept X11 connection", slog.Error(err))
8594
return
8695
}
87-
if x11.SingleConnection && handledFirstConnection {
88-
s.logger.Warn(ctx, "X11 connection rejected because single connection is enabled")
89-
_ = conn.Close()
90-
continue
96+
if x11.SingleConnection {
97+
s.logger.Debug(ctx, "single connection requested, closing X11 listener")
98+
_ = ln.Close()
9199
}
92-
handledFirstConnection = true
93100

94-
unixConn, ok := conn.(*net.UnixConn)
101+
tcpConn, ok := conn.(*net.TCPConn)
95102
if !ok {
96-
s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to UnixConn. got: %T", conn))
97-
return
103+
s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn))
104+
_ = conn.Close()
105+
continue
98106
}
99-
unixAddr, ok := unixConn.LocalAddr().(*net.UnixAddr)
107+
tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr)
100108
if !ok {
101-
s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to UnixAddr. got: %T", unixConn.LocalAddr()))
102-
return
109+
s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr()))
110+
_ = conn.Close()
111+
continue
103112
}
104113

105114
channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
106115
OriginatorAddress string
107116
OriginatorPort uint32
108117
}{
109-
OriginatorAddress: unixAddr.Name,
110-
OriginatorPort: 0,
118+
OriginatorAddress: tcpAddr.IP.String(),
119+
OriginatorPort: uint32(tcpAddr.Port),
111120
}))
112121
if err != nil {
113122
s.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
114-
return
123+
_ = conn.Close()
124+
continue
115125
}
116126
go gossh.DiscardRequests(reqs)
117-
go Bicopy(ctx, conn, channel)
127+
128+
if !s.trackConn(ln, conn, true) {
129+
s.logger.Warn(ctx, "failed to track X11 connection")
130+
_ = conn.Close()
131+
continue
132+
}
133+
go func() {
134+
defer s.trackConn(ln, conn, false)
135+
Bicopy(ctx, conn, channel)
136+
}()
118137
}
119138
}()
120-
return true
139+
140+
return display, true
141+
}
142+
143+
// createX11Listener creates a listener for X11 forwarding, it will use
144+
// the next available port starting from X11StartPort and displayOffset.
145+
func createX11Listener(ctx context.Context, displayOffset int) (ln net.Listener, display int, err error) {
146+
var lc net.ListenConfig
147+
// Look for an open port to listen on.
148+
for port := X11StartPort + displayOffset; port < math.MaxUint16; port++ {
149+
ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port))
150+
if err == nil {
151+
display = port - X11StartPort
152+
return ln, display, nil
153+
}
154+
}
155+
return nil, -1, xerrors.Errorf("failed to find open port for X11 listener: %w", err)
121156
}
122157

123158
// addXauthEntry adds an Xauthority entry to the Xauthority file.

agent/agentssh/x11_test.go

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
package agentssh_test
22

33
import (
4+
"bufio"
5+
"bytes"
46
"context"
57
"encoding/hex"
8+
"fmt"
69
"net"
710
"os"
811
"path/filepath"
912
"runtime"
13+
"strconv"
14+
"strings"
1015
"testing"
1116

1217
"github.com/gliderlabs/ssh"
@@ -31,10 +36,7 @@ func TestServer_X11(t *testing.T) {
3136
ctx := context.Background()
3237
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
3338
fs := afero.NewOsFs()
34-
dir := t.TempDir()
35-
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{
36-
X11SocketDir: dir,
37-
})
39+
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{})
3840
require.NoError(t, err)
3941
defer s.Close()
4042

@@ -53,21 +55,45 @@ func TestServer_X11(t *testing.T) {
5355
sess, err := c.NewSession()
5456
require.NoError(t, err)
5557

58+
wantScreenNumber := 1
5659
reply, err := sess.SendRequest("x11-req", true, gossh.Marshal(ssh.X11{
5760
AuthProtocol: "MIT-MAGIC-COOKIE-1",
5861
AuthCookie: hex.EncodeToString([]byte("cookie")),
59-
ScreenNumber: 0,
62+
ScreenNumber: uint32(wantScreenNumber),
6063
}))
6164
require.NoError(t, err)
6265
assert.True(t, reply)
6366

64-
err = sess.Shell()
67+
// Want: ~DISPLAY=localhost:10.1
68+
out, err := sess.Output("echo DISPLAY=$DISPLAY")
6569
require.NoError(t, err)
6670

71+
sc := bufio.NewScanner(bytes.NewReader(out))
72+
displayNumber := -1
73+
for sc.Scan() {
74+
line := strings.TrimSpace(sc.Text())
75+
t.Log(line)
76+
if strings.HasPrefix(line, "DISPLAY=") {
77+
parts := strings.SplitN(line, "=", 2)
78+
display := parts[1]
79+
parts = strings.SplitN(display, ":", 2)
80+
parts = strings.SplitN(parts[1], ".", 2)
81+
displayNumber, err = strconv.Atoi(parts[0])
82+
require.NoError(t, err)
83+
assert.GreaterOrEqual(t, displayNumber, 10, "display number should be >= 10")
84+
gotScreenNumber, err := strconv.Atoi(parts[1])
85+
require.NoError(t, err)
86+
assert.Equal(t, wantScreenNumber, gotScreenNumber, "screen number should match")
87+
break
88+
}
89+
}
90+
require.NoError(t, sc.Err())
91+
require.NotEqual(t, -1, displayNumber)
92+
6793
x11Chans := c.HandleChannelOpen("x11")
6894
payload := "hello world"
6995
require.Eventually(t, func() bool {
70-
conn, err := net.Dial("unix", filepath.Join(dir, "X0"))
96+
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber))
7197
if err == nil {
7298
_, err = conn.Write([]byte(payload))
7399
assert.NoError(t, err)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy