Skip to content

Commit 5e84d25

Browse files
refactor: convert workspacesdk.AgentConn to an interface (#19392)
Fixes coder/internal#907 We convert `workspacesdk.AgentConn` to an interface and generate a mock for it. This allows writing `coderd` tests that rely on the agent's HTTP api to not have to set up an entire tailnet networking stack.
1 parent 23c494f commit 5e84d25

File tree

18 files changed

+667
-143
lines changed

18 files changed

+667
-143
lines changed

Makefile

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,8 @@ GEN_FILES := \
636636
coderd/database/pubsub/psmock/psmock.go \
637637
agent/agentcontainers/acmock/acmock.go \
638638
agent/agentcontainers/dcspec/dcspec_gen.go \
639-
coderd/httpmw/loggermw/loggermock/loggermock.go
639+
coderd/httpmw/loggermw/loggermock/loggermock.go \
640+
codersdk/workspacesdk/agentconnmock/agentconnmock.go
640641

641642
# all gen targets should be added here and to gen/mark-fresh
642643
gen: gen/db gen/golden-files $(GEN_FILES)
@@ -686,6 +687,7 @@ gen/mark-fresh:
686687
agent/agentcontainers/acmock/acmock.go \
687688
agent/agentcontainers/dcspec/dcspec_gen.go \
688689
coderd/httpmw/loggermw/loggermock/loggermock.go \
690+
codersdk/workspacesdk/agentconnmock/agentconnmock.go \
689691
"
690692

691693
for file in $$files; do
@@ -729,6 +731,10 @@ coderd/httpmw/loggermw/loggermock/loggermock.go: coderd/httpmw/loggermw/logger.g
729731
go generate ./coderd/httpmw/loggermw/loggermock/
730732
touch "$@"
731733

734+
codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agentconn.go
735+
go generate ./codersdk/workspacesdk/agentconnmock/
736+
touch "$@"
737+
732738
agent/agentcontainers/dcspec/dcspec_gen.go: \
733739
node_modules/.installed \
734740
agent/agentcontainers/dcspec/devContainer.base.schema.json \

agent/agent_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2750,9 +2750,9 @@ func TestAgent_Dial(t *testing.T) {
27502750

27512751
switch l.Addr().Network() {
27522752
case "tcp":
2753-
conn, err = agentConn.Conn.DialContextTCP(ctx, ipp)
2753+
conn, err = agentConn.TailnetConn().DialContextTCP(ctx, ipp)
27542754
case "udp":
2755-
conn, err = agentConn.Conn.DialContextUDP(ctx, ipp)
2755+
conn, err = agentConn.TailnetConn().DialContextUDP(ctx, ipp)
27562756
default:
27572757
t.Fatalf("unknown network: %s", l.Addr().Network())
27582758
}
@@ -2811,7 +2811,7 @@ func TestAgent_UpdatedDERP(t *testing.T) {
28112811
})
28122812

28132813
// Setup a client connection.
2814-
newClientConn := func(derpMap *tailcfg.DERPMap, name string) *workspacesdk.AgentConn {
2814+
newClientConn := func(derpMap *tailcfg.DERPMap, name string) workspacesdk.AgentConn {
28152815
conn, err := tailnet.NewConn(&tailnet.Options{
28162816
Addresses: []netip.Prefix{tailnet.TailscaleServicePrefix.RandomPrefix()},
28172817
DERPMap: derpMap,
@@ -2891,13 +2891,13 @@ func TestAgent_UpdatedDERP(t *testing.T) {
28912891

28922892
// Connect from a second client and make sure it uses the new DERP map.
28932893
conn2 := newClientConn(newDerpMap, "client2")
2894-
require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs())
2894+
require.Equal(t, []int{2}, conn2.TailnetConn().DERPMap().RegionIDs())
28952895
t.Log("conn2 got the new DERPMap")
28962896

28972897
// If the first client gets a DERP map update, it should be able to
28982898
// reconnect just fine.
2899-
conn1.SetDERPMap(newDerpMap)
2900-
require.Equal(t, []int{2}, conn1.DERPMap().RegionIDs())
2899+
conn1.TailnetConn().SetDERPMap(newDerpMap)
2900+
require.Equal(t, []int{2}, conn1.TailnetConn().DERPMap().RegionIDs())
29012901
t.Log("set the new DERPMap on conn1")
29022902
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
29032903
defer cancel()
@@ -3264,7 +3264,7 @@ func setupSSHSessionOnPort(
32643264
}
32653265

32663266
func setupAgent(t testing.TB, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) (
3267-
*workspacesdk.AgentConn,
3267+
workspacesdk.AgentConn,
32683268
*agenttest.Client,
32693269
<-chan *proto.Stats,
32703270
afero.Fs,

cli/ping.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ func (r *RootCmd) ping() *serpent.Command {
147147
}
148148
defer conn.Close()
149149

150-
derpMap := conn.DERPMap()
150+
derpMap := conn.TailnetConn().DERPMap()
151151

152152
diagCtx, diagCancel := context.WithTimeout(inv.Context(), 30*time.Second)
153153
defer diagCancel()
@@ -156,7 +156,7 @@ func (r *RootCmd) ping() *serpent.Command {
156156
// Silent ping to determine whether we should show diags
157157
_, didP2p, _, _ := conn.Ping(ctx)
158158

159-
ni := conn.GetNetInfo()
159+
ni := conn.TailnetConn().GetNetInfo()
160160
connDiags := cliui.ConnDiags{
161161
DisableDirect: r.disableDirect,
162162
LocalNetInfo: ni,

cli/portforward.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ func (r *RootCmd) portForward() *serpent.Command {
221221
func listenAndPortForward(
222222
ctx context.Context,
223223
inv *serpent.Invocation,
224-
conn *workspacesdk.AgentConn,
224+
conn workspacesdk.AgentConn,
225225
wg *sync.WaitGroup,
226226
spec portForwardSpec,
227227
logger slog.Logger,

cli/speedtest.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ func (r *RootCmd) speedtest() *serpent.Command {
139139
if err != nil {
140140
continue
141141
}
142-
status := conn.Status()
142+
status := conn.TailnetConn().Status()
143143
if len(status.Peers()) != 1 {
144144
continue
145145
}
@@ -189,7 +189,7 @@ func (r *RootCmd) speedtest() *serpent.Command {
189189
outputResult.Intervals[i] = interval
190190
}
191191
}
192-
conn.Conn.SendSpeedtestTelemetry(outputResult.Overall.ThroughputMbits)
192+
conn.TailnetConn().SendSpeedtestTelemetry(outputResult.Overall.ThroughputMbits)
193193
out, err := formatter.Format(inv.Context(), outputResult)
194194
if err != nil {
195195
return err

cli/ssh.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ func (r *RootCmd) ssh() *serpent.Command {
590590
}
591591

592592
err = sshSession.Wait()
593-
conn.SendDisconnectedTelemetry()
593+
conn.TailnetConn().SendDisconnectedTelemetry()
594594
if err != nil {
595595
if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) {
596596
// Clear the error since it's not useful beyond
@@ -1364,7 +1364,7 @@ func getUsageAppName(usageApp string) codersdk.UsageAppName {
13641364

13651365
func setStatsCallback(
13661366
ctx context.Context,
1367-
agentConn *workspacesdk.AgentConn,
1367+
agentConn workspacesdk.AgentConn,
13681368
logger slog.Logger,
13691369
networkInfoDir string,
13701370
networkInfoInterval time.Duration,
@@ -1437,7 +1437,7 @@ func setStatsCallback(
14371437

14381438
now := time.Now()
14391439
cb(now, now.Add(time.Nanosecond), map[netlogtype.Connection]netlogtype.Counts{}, map[netlogtype.Connection]netlogtype.Counts{})
1440-
agentConn.SetConnStatsCallback(networkInfoInterval, 2048, cb)
1440+
agentConn.TailnetConn().SetConnStatsCallback(networkInfoInterval, 2048, cb)
14411441
return errCh, nil
14421442
}
14431443

@@ -1451,13 +1451,13 @@ type sshNetworkStats struct {
14511451
UsingCoderConnect bool `json:"using_coder_connect"`
14521452
}
14531453

1454-
func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) {
1454+
func collectNetworkStats(ctx context.Context, agentConn workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) {
14551455
latency, p2p, pingResult, err := agentConn.Ping(ctx)
14561456
if err != nil {
14571457
return nil, err
14581458
}
1459-
node := agentConn.Node()
1460-
derpMap := agentConn.DERPMap()
1459+
node := agentConn.TailnetConn().Node()
1460+
derpMap := agentConn.TailnetConn().DERPMap()
14611461

14621462
totalRx := uint64(0)
14631463
totalTx := uint64(0)

coderd/coderd.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,9 @@ func New(options *Options) *API {
325325
})
326326
}
327327

328+
if options.PrometheusRegistry == nil {
329+
options.PrometheusRegistry = prometheus.NewRegistry()
330+
}
328331
if options.Authorizer == nil {
329332
options.Authorizer = rbac.NewCachingAuthorizer(options.PrometheusRegistry)
330333
if buildinfo.IsDev() {
@@ -381,9 +384,6 @@ func New(options *Options) *API {
381384
if options.FilesRateLimit == 0 {
382385
options.FilesRateLimit = 12
383386
}
384-
if options.PrometheusRegistry == nil {
385-
options.PrometheusRegistry = prometheus.NewRegistry()
386-
}
387387
if options.Clock == nil {
388388
options.Clock = quartz.NewReal()
389389
}

coderd/tailnet.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,9 +277,9 @@ func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (
277277
}, nil
278278
}
279279

280-
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*workspacesdk.AgentConn, func(), error) {
280+
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
281281
var (
282-
conn *workspacesdk.AgentConn
282+
conn workspacesdk.AgentConn
283283
ret func()
284284
)
285285

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
package coderd
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"database/sql"
7+
"fmt"
8+
"io"
9+
"net/http"
10+
"net/http/httptest"
11+
"net/http/httputil"
12+
"net/url"
13+
"strings"
14+
"testing"
15+
16+
"github.com/go-chi/chi/v5"
17+
"github.com/google/uuid"
18+
"github.com/stretchr/testify/require"
19+
"go.uber.org/mock/gomock"
20+
21+
"cdr.dev/slog"
22+
"cdr.dev/slog/sloggers/slogtest"
23+
"github.com/coder/coder/v2/coderd/database"
24+
"github.com/coder/coder/v2/coderd/database/dbmock"
25+
"github.com/coder/coder/v2/coderd/database/dbtime"
26+
"github.com/coder/coder/v2/coderd/httpmw"
27+
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
28+
"github.com/coder/coder/v2/codersdk"
29+
"github.com/coder/coder/v2/codersdk/workspacesdk"
30+
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
31+
"github.com/coder/coder/v2/codersdk/wsjson"
32+
"github.com/coder/coder/v2/tailnet"
33+
"github.com/coder/coder/v2/tailnet/tailnettest"
34+
"github.com/coder/coder/v2/testutil"
35+
"github.com/coder/websocket"
36+
)
37+
38+
type fakeAgentProvider struct {
39+
agentConn func(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error)
40+
}
41+
42+
func (fakeAgentProvider) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID, app appurl.ApplicationURL, wildcardHost string) *httputil.ReverseProxy {
43+
panic("unimplemented")
44+
}
45+
46+
func (f fakeAgentProvider) AgentConn(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) {
47+
if f.agentConn != nil {
48+
return f.agentConn(ctx, agentID)
49+
}
50+
51+
panic("unimplemented")
52+
}
53+
54+
func (fakeAgentProvider) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
55+
panic("unimplemented")
56+
}
57+
58+
func (fakeAgentProvider) Close() error {
59+
return nil
60+
}
61+
62+
func TestWatchAgentContainers(t *testing.T) {
63+
t.Parallel()
64+
65+
t.Run("WebSocketClosesProperly", func(t *testing.T) {
66+
t.Parallel()
67+
68+
// This test ensures that the agent containers `/watch` websocket can gracefully
69+
// handle the underlying websocket unexpectedly closing. This test was created in
70+
// response to this issue: https://github.com/coder/coder/issues/19372
71+
72+
var (
73+
ctx = testutil.Context(t, testutil.WaitShort)
74+
logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug).Named("coderd")
75+
76+
mCtrl = gomock.NewController(t)
77+
mDB = dbmock.NewMockStore(mCtrl)
78+
mCoordinator = tailnettest.NewMockCoordinator(mCtrl)
79+
mAgentConn = agentconnmock.NewMockAgentConn(mCtrl)
80+
81+
fAgentProvider = fakeAgentProvider{
82+
agentConn: func(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) {
83+
return mAgentConn, func() {}, nil
84+
},
85+
}
86+
87+
workspaceID = uuid.New()
88+
agentID = uuid.New()
89+
resourceID = uuid.New()
90+
jobID = uuid.New()
91+
buildID = uuid.New()
92+
93+
containersCh = make(chan codersdk.WorkspaceAgentListContainersResponse)
94+
95+
r = chi.NewMux()
96+
97+
api = API{
98+
ctx: ctx,
99+
Options: &Options{
100+
AgentInactiveDisconnectTimeout: testutil.WaitShort,
101+
Database: mDB,
102+
Logger: logger,
103+
DeploymentValues: &codersdk.DeploymentValues{},
104+
TailnetCoordinator: tailnettest.NewFakeCoordinator(),
105+
},
106+
}
107+
)
108+
109+
var tailnetCoordinator tailnet.Coordinator = mCoordinator
110+
api.TailnetCoordinator.Store(&tailnetCoordinator)
111+
api.agentProvider = fAgentProvider
112+
113+
// Setup: Allow `ExtractWorkspaceAgentParams` to complete.
114+
mDB.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(database.WorkspaceAgent{
115+
ID: agentID,
116+
ResourceID: resourceID,
117+
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
118+
FirstConnectedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
119+
LastConnectedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
120+
}, nil)
121+
mDB.EXPECT().GetWorkspaceResourceByID(gomock.Any(), resourceID).Return(database.WorkspaceResource{
122+
ID: resourceID,
123+
JobID: jobID,
124+
}, nil)
125+
mDB.EXPECT().GetProvisionerJobByID(gomock.Any(), jobID).Return(database.ProvisionerJob{
126+
ID: jobID,
127+
Type: database.ProvisionerJobTypeWorkspaceBuild,
128+
}, nil)
129+
mDB.EXPECT().GetWorkspaceBuildByJobID(gomock.Any(), jobID).Return(database.WorkspaceBuild{
130+
WorkspaceID: workspaceID,
131+
ID: buildID,
132+
}, nil)
133+
134+
// And: Allow `db2dsk.WorkspaceAgent` to complete.
135+
mCoordinator.EXPECT().Node(gomock.Any()).Return(nil)
136+
137+
// And: Allow `WatchContainers` to be called, returing our `containersCh` channel.
138+
mAgentConn.EXPECT().WatchContainers(gomock.Any(), gomock.Any()).
139+
Return(containersCh, io.NopCloser(&bytes.Buffer{}), nil)
140+
141+
// And: We mount the HTTP Handler
142+
r.With(httpmw.ExtractWorkspaceAgentParam(mDB)).
143+
Get("/workspaceagents/{workspaceagent}/containers/watch", api.watchWorkspaceAgentContainers)
144+
145+
// Given: We create the HTTP server
146+
srv := httptest.NewServer(r)
147+
defer srv.Close()
148+
149+
// And: Dial the WebSocket
150+
wsURL := strings.Replace(srv.URL, "http://", "ws://", 1)
151+
conn, resp, err := websocket.Dial(ctx, fmt.Sprintf("%s/workspaceagents/%s/containers/watch", wsURL, agentID), nil)
152+
require.NoError(t, err)
153+
if resp.Body != nil {
154+
defer resp.Body.Close()
155+
}
156+
157+
// And: Create a streaming decoder
158+
decoder := wsjson.NewDecoder[codersdk.WorkspaceAgentListContainersResponse](conn, websocket.MessageText, logger)
159+
defer decoder.Close()
160+
decodeCh := decoder.Chan()
161+
162+
// And: We can successfully send through the channel.
163+
testutil.RequireSend(ctx, t, containersCh, codersdk.WorkspaceAgentListContainersResponse{
164+
Containers: []codersdk.WorkspaceAgentContainer{{
165+
ID: "test-container-id",
166+
}},
167+
})
168+
169+
// And: Receive the data.
170+
containerResp := testutil.RequireReceive(ctx, t, decodeCh)
171+
require.Len(t, containerResp.Containers, 1)
172+
require.Equal(t, "test-container-id", containerResp.Containers[0].ID)
173+
174+
// When: We close the `containersCh`
175+
close(containersCh)
176+
177+
// Then: We expect `decodeCh` to be closed.
178+
select {
179+
case <-ctx.Done():
180+
t.Fail()
181+
182+
case _, ok := <-decodeCh:
183+
require.False(t, ok, "channel is expected to be closed")
184+
}
185+
})
186+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy