Skip to content

Commit b60589d

Browse files
committed
chore: refactor TestServer_X11 to use inproc networking
1 parent 1e438a6 commit b60589d

File tree

3 files changed

+72
-28
lines changed

3 files changed

+72
-28
lines changed

agent/agentssh/agentssh.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ type Config struct {
116116
// Experimental: allow connecting to running containers if
117117
// CODER_AGENT_DEVCONTAINERS_ENABLE=true.
118118
ExperimentalDevContainersEnabled bool
119+
// X11Net allows overriding the networking implementation used for X11
120+
// forwarding listeners. When nil, a default implementation backed by the
121+
// standard library networking package is used.
122+
X11Net X11Network
119123
}
120124

121125
type Server struct {
@@ -195,6 +199,12 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
195199
displayOffset: *config.X11DisplayOffset,
196200
sessions: make(map[*x11Session]struct{}),
197201
connections: make(map[net.Conn]struct{}),
202+
network: func() X11Network {
203+
if config.X11Net != nil {
204+
return config.X11Net
205+
}
206+
return osNet{}
207+
}(),
198208
},
199209
}
200210

agent/agentssh/x11.go

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"errors"
88
"fmt"
99
"io"
10+
"math"
1011
"net"
1112
"os"
1213
"path/filepath"
@@ -37,12 +38,30 @@ const (
3738
X11MaxPort = X11StartPort + X11MaxDisplays
3839
)
3940

41+
// X11Network abstracts the creation of network listeners for X11 forwarding.
42+
// It is intended mainly for testing; production code uses the default
43+
// implementation backed by the operating system networking stack.
44+
type X11Network interface {
45+
Listen(network, address string) (net.Listener, error)
46+
}
47+
48+
// osNet is the default X11Network implementation that uses the standard
49+
// library network stack.
50+
type osNet struct{}
51+
52+
func (osNet) Listen(network, address string) (net.Listener, error) {
53+
return net.Listen(network, address)
54+
}
55+
4056
type x11Forwarder struct {
4157
logger slog.Logger
4258
x11HandlerErrors *prometheus.CounterVec
4359
fs afero.Fs
4460
displayOffset int
4561

62+
// network creates X11 listener sockets. Defaults to osNet{}.
63+
network X11Network
64+
4665
mu sync.Mutex
4766
sessions map[*x11Session]struct{}
4867
connections map[net.Conn]struct{}
@@ -145,26 +164,35 @@ func (x *x11Forwarder) listenForConnections(ctx context.Context, session *x11Ses
145164
x.cleanSession(session)
146165
}
147166

148-
tcpConn, ok := conn.(*net.TCPConn)
149-
if !ok {
150-
x.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn))
151-
_ = conn.Close()
152-
continue
167+
var originAddr string
168+
var originPort uint32
169+
170+
if tcpConn, ok := conn.(*net.TCPConn); ok {
171+
if tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr); ok {
172+
originAddr = tcpAddr.IP.String()
173+
// #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535)
174+
originPort = uint32(tcpAddr.Port)
175+
}
153176
}
154-
tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr)
155-
if !ok {
156-
x.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr()))
157-
_ = conn.Close()
158-
continue
177+
// Fallback values for in-memory or non-TCP connections.
178+
if originAddr == "" {
179+
originAddr = "127.0.0.1"
180+
}
181+
if originPort == 0 {
182+
p := X11StartPort + session.display
183+
if p > math.MaxUint32 {
184+
panic("overflow")
185+
}
186+
// #nosec G115 - Safe conversion as port number is within uint32 range
187+
originPort = uint32(p)
159188
}
160189

161190
channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
162191
OriginatorAddress string
163192
OriginatorPort uint32
164193
}{
165-
OriginatorAddress: tcpAddr.IP.String(),
166-
// #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535)
167-
OriginatorPort: uint32(tcpAddr.Port),
194+
OriginatorAddress: originAddr,
195+
OriginatorPort: originPort,
168196
}))
169197
if err != nil {
170198
x.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
@@ -281,13 +309,13 @@ func (x *x11Forwarder) evictLeastRecentlyUsedSession() {
281309
// createX11Listener creates a listener for X11 forwarding, it will use
282310
// the next available port starting from X11StartPort and displayOffset.
283311
func (x *x11Forwarder) createX11Listener(ctx context.Context) (ln net.Listener, display int, err error) {
284-
var lc net.ListenConfig
285312
// Look for an open port to listen on.
286313
for port := X11StartPort + x.displayOffset; port <= X11MaxPort; port++ {
287314
if ctx.Err() != nil {
288315
return nil, -1, ctx.Err()
289316
}
290-
ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port))
317+
318+
ln, err = x.network.Listen("tcp", fmt.Sprintf("localhost:%d", port))
291319
if err == nil {
292320
display = port - X11StartPort
293321
return ln, display, nil

agent/agentssh/x11_test.go

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package agentssh_test
33
import (
44
"bufio"
55
"bytes"
6-
"context"
76
"encoding/hex"
87
"fmt"
98
"net"
@@ -32,10 +31,19 @@ func TestServer_X11(t *testing.T) {
3231
t.Skip("X11 forwarding is only supported on Linux")
3332
}
3433

35-
ctx := context.Background()
34+
ctx := testutil.Context(t, testutil.WaitShort)
3635
logger := testutil.Logger(t)
3736
fs := afero.NewMemMapFs()
38-
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, &agentssh.Config{})
37+
38+
// Use in-process networking for X11 forwarding.
39+
inproc := testutil.NewInProcNet()
40+
41+
// Create server config with custom X11 listener.
42+
cfg := &agentssh.Config{
43+
X11Net: inproc,
44+
}
45+
46+
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, cfg)
3947
require.NoError(t, err)
4048
defer s.Close()
4149
err = s.UpdateHostSigner(42)
@@ -93,17 +101,15 @@ func TestServer_X11(t *testing.T) {
93101

94102
x11Chans := c.HandleChannelOpen("x11")
95103
payload := "hello world"
96-
require.Eventually(t, func() bool {
97-
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber))
98-
if err == nil {
99-
_, err = conn.Write([]byte(payload))
100-
assert.NoError(t, err)
101-
_ = conn.Close()
102-
}
103-
return err == nil
104-
}, testutil.WaitShort, testutil.IntervalFast)
104+
go func() {
105+
conn, err := inproc.Dial(ctx, testutil.NewAddr("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber)))
106+
assert.NoError(t, err)
107+
_, err = conn.Write([]byte(payload))
108+
assert.NoError(t, err)
109+
_ = conn.Close()
110+
}()
105111

106-
x11 := <-x11Chans
112+
x11 := testutil.RequireReceive(ctx, t, x11Chans)
107113
ch, reqs, err := x11.Accept()
108114
require.NoError(t, err)
109115
go gossh.DiscardRequests(reqs)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy