diff --git a/Makefile b/Makefile index 9040a891700e1..a5341ee79f753 100644 --- a/Makefile +++ b/Makefile @@ -636,7 +636,8 @@ GEN_FILES := \ coderd/database/pubsub/psmock/psmock.go \ agent/agentcontainers/acmock/acmock.go \ agent/agentcontainers/dcspec/dcspec_gen.go \ - coderd/httpmw/loggermw/loggermock/loggermock.go + coderd/httpmw/loggermw/loggermock/loggermock.go \ + codersdk/workspacesdk/agentconnmock/agentconnmock.go # all gen targets should be added here and to gen/mark-fresh gen: gen/db gen/golden-files $(GEN_FILES) @@ -686,6 +687,7 @@ gen/mark-fresh: agent/agentcontainers/acmock/acmock.go \ agent/agentcontainers/dcspec/dcspec_gen.go \ coderd/httpmw/loggermw/loggermock/loggermock.go \ + codersdk/workspacesdk/agentconnmock/agentconnmock.go \ " for file in $$files; do @@ -729,6 +731,10 @@ coderd/httpmw/loggermw/loggermock/loggermock.go: coderd/httpmw/loggermw/logger.g go generate ./coderd/httpmw/loggermw/loggermock/ touch "$@" +codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agentconn.go + go generate ./codersdk/workspacesdk/agentconnmock/ + touch "$@" + agent/agentcontainers/dcspec/dcspec_gen.go: \ node_modules/.installed \ agent/agentcontainers/dcspec/devContainer.base.schema.json \ diff --git a/agent/agent_test.go b/agent/agent_test.go index 52d8cfc09d573..2425fd81a0ead 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -2750,9 +2750,9 @@ func TestAgent_Dial(t *testing.T) { switch l.Addr().Network() { case "tcp": - conn, err = agentConn.Conn.DialContextTCP(ctx, ipp) + conn, err = agentConn.TailnetConn().DialContextTCP(ctx, ipp) case "udp": - conn, err = agentConn.Conn.DialContextUDP(ctx, ipp) + conn, err = agentConn.TailnetConn().DialContextUDP(ctx, ipp) default: t.Fatalf("unknown network: %s", l.Addr().Network()) } @@ -2811,7 +2811,7 @@ func TestAgent_UpdatedDERP(t *testing.T) { }) // Setup a client connection. - newClientConn := func(derpMap *tailcfg.DERPMap, name string) *workspacesdk.AgentConn { + newClientConn := func(derpMap *tailcfg.DERPMap, name string) workspacesdk.AgentConn { conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{tailnet.TailscaleServicePrefix.RandomPrefix()}, DERPMap: derpMap, @@ -2891,13 +2891,13 @@ func TestAgent_UpdatedDERP(t *testing.T) { // Connect from a second client and make sure it uses the new DERP map. conn2 := newClientConn(newDerpMap, "client2") - require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs()) + require.Equal(t, []int{2}, conn2.TailnetConn().DERPMap().RegionIDs()) t.Log("conn2 got the new DERPMap") // If the first client gets a DERP map update, it should be able to // reconnect just fine. - conn1.SetDERPMap(newDerpMap) - require.Equal(t, []int{2}, conn1.DERPMap().RegionIDs()) + conn1.TailnetConn().SetDERPMap(newDerpMap) + require.Equal(t, []int{2}, conn1.TailnetConn().DERPMap().RegionIDs()) t.Log("set the new DERPMap on conn1") ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() @@ -3264,7 +3264,7 @@ func setupSSHSessionOnPort( } func setupAgent(t testing.TB, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) ( - *workspacesdk.AgentConn, + workspacesdk.AgentConn, *agenttest.Client, <-chan *proto.Stats, afero.Fs, diff --git a/cli/ping.go b/cli/ping.go index 0b9fde5c62eb8..29af06affeaee 100644 --- a/cli/ping.go +++ b/cli/ping.go @@ -147,7 +147,7 @@ func (r *RootCmd) ping() *serpent.Command { } defer conn.Close() - derpMap := conn.DERPMap() + derpMap := conn.TailnetConn().DERPMap() diagCtx, diagCancel := context.WithTimeout(inv.Context(), 30*time.Second) defer diagCancel() @@ -156,7 +156,7 @@ func (r *RootCmd) ping() *serpent.Command { // Silent ping to determine whether we should show diags _, didP2p, _, _ := conn.Ping(ctx) - ni := conn.GetNetInfo() + ni := conn.TailnetConn().GetNetInfo() connDiags := cliui.ConnDiags{ DisableDirect: r.disableDirect, LocalNetInfo: ni, diff --git a/cli/portforward.go b/cli/portforward.go index 59c1f5827b06f..1b055d9e4362e 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -221,7 +221,7 @@ func (r *RootCmd) portForward() *serpent.Command { func listenAndPortForward( ctx context.Context, inv *serpent.Invocation, - conn *workspacesdk.AgentConn, + conn workspacesdk.AgentConn, wg *sync.WaitGroup, spec portForwardSpec, logger slog.Logger, diff --git a/cli/speedtest.go b/cli/speedtest.go index 3827b45125842..86d0e6a9ee63c 100644 --- a/cli/speedtest.go +++ b/cli/speedtest.go @@ -139,7 +139,7 @@ func (r *RootCmd) speedtest() *serpent.Command { if err != nil { continue } - status := conn.Status() + status := conn.TailnetConn().Status() if len(status.Peers()) != 1 { continue } @@ -189,7 +189,7 @@ func (r *RootCmd) speedtest() *serpent.Command { outputResult.Intervals[i] = interval } } - conn.Conn.SendSpeedtestTelemetry(outputResult.Overall.ThroughputMbits) + conn.TailnetConn().SendSpeedtestTelemetry(outputResult.Overall.ThroughputMbits) out, err := formatter.Format(inv.Context(), outputResult) if err != nil { return err diff --git a/cli/ssh.go b/cli/ssh.go index bc2bb24235ad2..a2f0db7327bef 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -590,7 +590,7 @@ func (r *RootCmd) ssh() *serpent.Command { } err = sshSession.Wait() - conn.SendDisconnectedTelemetry() + conn.TailnetConn().SendDisconnectedTelemetry() if err != nil { if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) { // Clear the error since it's not useful beyond @@ -1364,7 +1364,7 @@ func getUsageAppName(usageApp string) codersdk.UsageAppName { func setStatsCallback( ctx context.Context, - agentConn *workspacesdk.AgentConn, + agentConn workspacesdk.AgentConn, logger slog.Logger, networkInfoDir string, networkInfoInterval time.Duration, @@ -1437,7 +1437,7 @@ func setStatsCallback( now := time.Now() cb(now, now.Add(time.Nanosecond), map[netlogtype.Connection]netlogtype.Counts{}, map[netlogtype.Connection]netlogtype.Counts{}) - agentConn.SetConnStatsCallback(networkInfoInterval, 2048, cb) + agentConn.TailnetConn().SetConnStatsCallback(networkInfoInterval, 2048, cb) return errCh, nil } @@ -1451,13 +1451,13 @@ type sshNetworkStats struct { UsingCoderConnect bool `json:"using_coder_connect"` } -func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) { +func collectNetworkStats(ctx context.Context, agentConn workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) { latency, p2p, pingResult, err := agentConn.Ping(ctx) if err != nil { return nil, err } - node := agentConn.Node() - derpMap := agentConn.DERPMap() + node := agentConn.TailnetConn().Node() + derpMap := agentConn.TailnetConn().DERPMap() totalRx := uint64(0) totalTx := uint64(0) diff --git a/coderd/coderd.go b/coderd/coderd.go index a934536c0aef0..8ab204f8a31ef 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -325,6 +325,9 @@ func New(options *Options) *API { }) } + if options.PrometheusRegistry == nil { + options.PrometheusRegistry = prometheus.NewRegistry() + } if options.Authorizer == nil { options.Authorizer = rbac.NewCachingAuthorizer(options.PrometheusRegistry) if buildinfo.IsDev() { @@ -381,9 +384,6 @@ func New(options *Options) *API { if options.FilesRateLimit == 0 { options.FilesRateLimit = 12 } - if options.PrometheusRegistry == nil { - options.PrometheusRegistry = prometheus.NewRegistry() - } if options.Clock == nil { options.Clock = quartz.NewReal() } diff --git a/coderd/tailnet.go b/coderd/tailnet.go index 172edea95a586..cdcf657fe732d 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -277,9 +277,9 @@ func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) ( }, nil } -func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*workspacesdk.AgentConn, func(), error) { +func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { var ( - conn *workspacesdk.AgentConn + conn workspacesdk.AgentConn ret func() ) diff --git a/coderd/workspaceagents_internal_test.go b/coderd/workspaceagents_internal_test.go new file mode 100644 index 0000000000000..c7520f05ab503 --- /dev/null +++ b/coderd/workspaceagents_internal_test.go @@ -0,0 +1,186 @@ +package coderd + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/workspaceapps/appurl" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock" + "github.com/coder/coder/v2/codersdk/wsjson" + "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/tailnettest" + "github.com/coder/coder/v2/testutil" + "github.com/coder/websocket" +) + +type fakeAgentProvider struct { + agentConn func(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) +} + +func (fakeAgentProvider) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID, app appurl.ApplicationURL, wildcardHost string) *httputil.ReverseProxy { + panic("unimplemented") +} + +func (f fakeAgentProvider) AgentConn(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) { + if f.agentConn != nil { + return f.agentConn(ctx, agentID) + } + + panic("unimplemented") +} + +func (fakeAgentProvider) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) { + panic("unimplemented") +} + +func (fakeAgentProvider) Close() error { + return nil +} + +func TestWatchAgentContainers(t *testing.T) { + t.Parallel() + + t.Run("WebSocketClosesProperly", func(t *testing.T) { + t.Parallel() + + // This test ensures that the agent containers `/watch` websocket can gracefully + // handle the underlying websocket unexpectedly closing. This test was created in + // response to this issue: https://github.com/coder/coder/issues/19372 + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug).Named("coderd") + + mCtrl = gomock.NewController(t) + mDB = dbmock.NewMockStore(mCtrl) + mCoordinator = tailnettest.NewMockCoordinator(mCtrl) + mAgentConn = agentconnmock.NewMockAgentConn(mCtrl) + + fAgentProvider = fakeAgentProvider{ + agentConn: func(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) { + return mAgentConn, func() {}, nil + }, + } + + workspaceID = uuid.New() + agentID = uuid.New() + resourceID = uuid.New() + jobID = uuid.New() + buildID = uuid.New() + + containersCh = make(chan codersdk.WorkspaceAgentListContainersResponse) + + r = chi.NewMux() + + api = API{ + ctx: ctx, + Options: &Options{ + AgentInactiveDisconnectTimeout: testutil.WaitShort, + Database: mDB, + Logger: logger, + DeploymentValues: &codersdk.DeploymentValues{}, + TailnetCoordinator: tailnettest.NewFakeCoordinator(), + }, + } + ) + + var tailnetCoordinator tailnet.Coordinator = mCoordinator + api.TailnetCoordinator.Store(&tailnetCoordinator) + api.agentProvider = fAgentProvider + + // Setup: Allow `ExtractWorkspaceAgentParams` to complete. + mDB.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(database.WorkspaceAgent{ + ID: agentID, + ResourceID: resourceID, + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + FirstConnectedAt: sql.NullTime{Valid: true, Time: dbtime.Now()}, + LastConnectedAt: sql.NullTime{Valid: true, Time: dbtime.Now()}, + }, nil) + mDB.EXPECT().GetWorkspaceResourceByID(gomock.Any(), resourceID).Return(database.WorkspaceResource{ + ID: resourceID, + JobID: jobID, + }, nil) + mDB.EXPECT().GetProvisionerJobByID(gomock.Any(), jobID).Return(database.ProvisionerJob{ + ID: jobID, + Type: database.ProvisionerJobTypeWorkspaceBuild, + }, nil) + mDB.EXPECT().GetWorkspaceBuildByJobID(gomock.Any(), jobID).Return(database.WorkspaceBuild{ + WorkspaceID: workspaceID, + ID: buildID, + }, nil) + + // And: Allow `db2dsk.WorkspaceAgent` to complete. + mCoordinator.EXPECT().Node(gomock.Any()).Return(nil) + + // And: Allow `WatchContainers` to be called, returing our `containersCh` channel. + mAgentConn.EXPECT().WatchContainers(gomock.Any(), gomock.Any()). + Return(containersCh, io.NopCloser(&bytes.Buffer{}), nil) + + // And: We mount the HTTP Handler + r.With(httpmw.ExtractWorkspaceAgentParam(mDB)). + Get("/workspaceagents/{workspaceagent}/containers/watch", api.watchWorkspaceAgentContainers) + + // Given: We create the HTTP server + srv := httptest.NewServer(r) + defer srv.Close() + + // And: Dial the WebSocket + wsURL := strings.Replace(srv.URL, "http://", "ws://", 1) + conn, resp, err := websocket.Dial(ctx, fmt.Sprintf("%s/workspaceagents/%s/containers/watch", wsURL, agentID), nil) + require.NoError(t, err) + if resp.Body != nil { + defer resp.Body.Close() + } + + // And: Create a streaming decoder + decoder := wsjson.NewDecoder[codersdk.WorkspaceAgentListContainersResponse](conn, websocket.MessageText, logger) + defer decoder.Close() + decodeCh := decoder.Chan() + + // And: We can successfully send through the channel. + testutil.RequireSend(ctx, t, containersCh, codersdk.WorkspaceAgentListContainersResponse{ + Containers: []codersdk.WorkspaceAgentContainer{{ + ID: "test-container-id", + }}, + }) + + // And: Receive the data. + containerResp := testutil.RequireReceive(ctx, t, decodeCh) + require.Len(t, containerResp.Containers, 1) + require.Equal(t, "test-container-id", containerResp.Containers[0].ID) + + // When: We close the `containersCh` + close(containersCh) + + // Then: We expect `decodeCh` to be closed. + select { + case <-ctx.Done(): + t.Fail() + + case _, ok := <-decodeCh: + require.False(t, ok, "channel is expected to be closed") + } + }) +} diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 1855ed8a7e8fc..ac58df1b772ad 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -593,7 +593,7 @@ func TestWorkspaceAgentTailnet(t *testing.T) { _ = agenttest.New(t, client.URL, r.AgentToken) resources := coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID) - conn, err := func() (*workspacesdk.AgentConn, error) { + conn, err := func() (workspacesdk.AgentConn, error) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() // Connection should remain open even if the dial context is canceled. @@ -1574,82 +1574,6 @@ func TestWatchWorkspaceAgentDevcontainers(t *testing.T) { } } }) - - t.Run("PayloadTooLarge", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitSuperLong) - logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - mClock = quartz.NewMock(t) - updaterTickerTrap = mClock.Trap().TickerFunc("updaterLoop") - mCtrl = gomock.NewController(t) - mCCLI = acmock.NewMockContainerCLI(mCtrl) - - client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{Logger: &logger}) - user = coderdtest.CreateFirstUser(t, client) - r = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ - OrganizationID: user.OrganizationID, - OwnerID: user.UserID, - }).WithAgent(func(agents []*proto.Agent) []*proto.Agent { - return agents - }).Do() - ) - - // WebSocket limit is 4MiB, so we want to ensure we create _more_ than 4MiB worth of payload. - // Creating 20,000 fake containers creates a payload of roughly 7MiB. - var fakeContainers []codersdk.WorkspaceAgentContainer - for range 20_000 { - fakeContainers = append(fakeContainers, codersdk.WorkspaceAgentContainer{ - CreatedAt: time.Now(), - ID: uuid.NewString(), - FriendlyName: uuid.NewString(), - Image: "busybox:latest", - Labels: map[string]string{ - agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project", - agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project/.devcontainer/devcontainer.json", - }, - Running: false, - Ports: []codersdk.WorkspaceAgentContainerPort{}, - Status: string(codersdk.WorkspaceAgentDevcontainerStatusRunning), - Volumes: map[string]string{}, - }) - } - - mCCLI.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{Containers: fakeContainers}, nil) - mCCLI.EXPECT().DetectArchitecture(gomock.Any(), gomock.Any()).Return("", nil).AnyTimes() - - _ = agenttest.New(t, client.URL, r.AgentToken, func(o *agent.Options) { - o.Logger = logger.Named("agent") - o.Devcontainers = true - o.DevcontainerAPIOptions = []agentcontainers.Option{ - agentcontainers.WithClock(mClock), - agentcontainers.WithContainerCLI(mCCLI), - agentcontainers.WithWatcher(watcher.NewNoop()), - } - }) - - resources := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).Wait() - require.Len(t, resources, 1, "expected one resource") - require.Len(t, resources[0].Agents, 1, "expected one agent") - agentID := resources[0].Agents[0].ID - - updaterTickerTrap.MustWait(ctx).MustRelease(ctx) - defer updaterTickerTrap.Close() - - containers, closer, err := client.WatchWorkspaceAgentContainers(ctx, agentID) - require.NoError(t, err) - defer func() { - closer.Close() - }() - - select { - case <-ctx.Done(): - t.Fail() - case _, ok := <-containers: - require.False(t, ok) - } - }) } func TestWorkspaceAgentRecreateDevcontainer(t *testing.T) { @@ -2497,7 +2421,7 @@ func TestWorkspaceAgent_UpdatedDERP(t *testing.T) { agentID := resources[0].Agents[0].ID // Connect from a client. - conn1, err := func() (*workspacesdk.AgentConn, error) { + conn1, err := func() (workspacesdk.AgentConn, error) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() // Connection should remain open even if the dial context is canceled. @@ -2538,7 +2462,7 @@ func TestWorkspaceAgent_UpdatedDERP(t *testing.T) { // Wait for the DERP map to be updated on the existing client. require.Eventually(t, func() bool { - regionIDs := conn1.Conn.DERPMap().RegionIDs() + regionIDs := conn1.TailnetConn().DERPMap().RegionIDs() return len(regionIDs) == 1 && regionIDs[0] == 2 }, testutil.WaitLong, testutil.IntervalFast) @@ -2555,7 +2479,7 @@ func TestWorkspaceAgent_UpdatedDERP(t *testing.T) { defer conn2.Close() ok = conn2.AwaitReachable(ctx) require.True(t, ok) - require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs()) + require.Equal(t, []int{2}, conn2.TailnetConn().DERPMap().RegionIDs()) } func TestWorkspaceAgentExternalAuthListen(t *testing.T) { diff --git a/coderd/workspaceapps/proxy.go b/coderd/workspaceapps/proxy.go index 2f1294558f67a..002bb1ea05aae 100644 --- a/coderd/workspaceapps/proxy.go +++ b/coderd/workspaceapps/proxy.go @@ -74,7 +74,7 @@ type AgentProvider interface { ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID, app appurl.ApplicationURL, wildcardHost string) *httputil.ReverseProxy // AgentConn returns a new connection to the specified agent. - AgentConn(ctx context.Context, agentID uuid.UUID) (_ *workspacesdk.AgentConn, release func(), _ error) + AgentConn(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index 9c65b7ee9a1e1..bb929c9ba2a04 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -34,8 +34,8 @@ import ( // to the WorkspaceAgentConn, or it may be shared in the case of coderd. If the // conn is shared and closing it is undesirable, you may return ErrNoClose from // opts.CloseFunc. This will ensure the underlying conn is not closed. -func NewAgentConn(conn *tailnet.Conn, opts AgentConnOptions) *AgentConn { - return &AgentConn{ +func NewAgentConn(conn *tailnet.Conn, opts AgentConnOptions) AgentConn { + return &agentConn{ Conn: conn, opts: opts, } @@ -43,23 +43,54 @@ func NewAgentConn(conn *tailnet.Conn, opts AgentConnOptions) *AgentConn { // AgentConn represents a connection to a workspace agent. // @typescript-ignore AgentConn -type AgentConn struct { +type AgentConn interface { + TailnetConn() *tailnet.Conn + + AwaitReachable(ctx context.Context) bool + Close() error + DebugLogs(ctx context.Context) ([]byte, error) + DebugMagicsock(ctx context.Context) ([]byte, error) + DebugManifest(ctx context.Context) ([]byte, error) + DialContext(ctx context.Context, network string, addr string) (net.Conn, error) + GetPeerDiagnostics() tailnet.PeerDiagnostics + ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) + ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error) + Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) + Ping(ctx context.Context) (time.Duration, bool, *ipnstate.PingResult, error) + PrometheusMetrics(ctx context.Context) ([]byte, error) + ReconnectingPTY(ctx context.Context, id uuid.UUID, height uint16, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error) + RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) + SSH(ctx context.Context) (*gonet.TCPConn, error) + SSHClient(ctx context.Context) (*ssh.Client, error) + SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) + SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) + Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) + WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) +} + +// AgentConn represents a connection to a workspace agent. +// @typescript-ignore AgentConn +type agentConn struct { *tailnet.Conn opts AgentConnOptions } +func (c *agentConn) TailnetConn() *tailnet.Conn { + return c.Conn +} + // @typescript-ignore AgentConnOptions type AgentConnOptions struct { AgentID uuid.UUID CloseFunc func() error } -func (c *AgentConn) agentAddress() netip.Addr { +func (c *agentConn) agentAddress() netip.Addr { return tailnet.TailscaleServicePrefix.AddrFromUUID(c.opts.AgentID) } // AwaitReachable waits for the agent to be reachable. -func (c *AgentConn) AwaitReachable(ctx context.Context) bool { +func (c *agentConn) AwaitReachable(ctx context.Context) bool { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -68,7 +99,7 @@ func (c *AgentConn) AwaitReachable(ctx context.Context) bool { // Ping pings the agent and returns the round-trip time. // The bool returns true if the ping was made P2P. -func (c *AgentConn) Ping(ctx context.Context) (time.Duration, bool, *ipnstate.PingResult, error) { +func (c *agentConn) Ping(ctx context.Context) (time.Duration, bool, *ipnstate.PingResult, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -76,7 +107,7 @@ func (c *AgentConn) Ping(ctx context.Context) (time.Duration, bool, *ipnstate.Pi } // Close ends the connection to the workspace agent. -func (c *AgentConn) Close() error { +func (c *agentConn) Close() error { var cerr error if c.opts.CloseFunc != nil { cerr = c.opts.CloseFunc() @@ -131,7 +162,7 @@ type ReconnectingPTYRequest struct { // ReconnectingPTY spawns a new reconnecting terminal session. // `ReconnectingPTYRequest` should be JSON marshaled and written to the returned net.Conn. // Raw terminal output will be read from the returned net.Conn. -func (c *AgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error) { +func (c *agentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -171,13 +202,13 @@ func (c *AgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, w // SSH pipes the SSH protocol over the returned net.Conn. // This connects to the built-in SSH server in the workspace agent. -func (c *AgentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) { +func (c *agentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) { return c.SSHOnPort(ctx, AgentSSHPort) } // SSHOnPort pipes the SSH protocol over the returned net.Conn. // This connects to the built-in SSH server in the workspace agent on the specified port. -func (c *AgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) { +func (c *agentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -190,12 +221,12 @@ func (c *AgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, } // SSHClient calls SSH to create a client -func (c *AgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { +func (c *agentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { return c.SSHClientOnPort(ctx, AgentSSHPort) } // SSHClientOnPort calls SSH to create a client on a specific port -func (c *AgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) { +func (c *agentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -218,7 +249,7 @@ func (c *AgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Clie } // Speedtest runs a speedtest against the workspace agent. -func (c *AgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { +func (c *agentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -242,7 +273,7 @@ func (c *AgentConn) Speedtest(ctx context.Context, direction speedtest.Direction // DialContext dials the address provided in the workspace agent. // The network must be "tcp" or "udp". -func (c *AgentConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { +func (c *agentConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -265,7 +296,7 @@ func (c *AgentConn) DialContext(ctx context.Context, network string, addr string } // ListeningPorts lists the ports that are currently in use by the workspace. -func (c *AgentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error) { +func (c *agentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/api/v0/listening-ports", nil) @@ -282,7 +313,7 @@ func (c *AgentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgent } // Netcheck returns a network check report from the workspace agent. -func (c *AgentConn) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) { +func (c *agentConn) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/api/v0/netcheck", nil) @@ -299,7 +330,7 @@ func (c *AgentConn) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport } // DebugMagicsock makes a request to the workspace agent's magicsock debug endpoint. -func (c *AgentConn) DebugMagicsock(ctx context.Context) ([]byte, error) { +func (c *agentConn) DebugMagicsock(ctx context.Context) ([]byte, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/debug/magicsock", nil) @@ -319,7 +350,7 @@ func (c *AgentConn) DebugMagicsock(ctx context.Context) ([]byte, error) { // DebugManifest returns the agent's in-memory manifest. Unfortunately this must // be returns as a []byte to avoid an import cycle. -func (c *AgentConn) DebugManifest(ctx context.Context) ([]byte, error) { +func (c *agentConn) DebugManifest(ctx context.Context) ([]byte, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/debug/manifest", nil) @@ -338,7 +369,7 @@ func (c *AgentConn) DebugManifest(ctx context.Context) ([]byte, error) { } // DebugLogs returns up to the last 10MB of `/tmp/coder-agent.log` -func (c *AgentConn) DebugLogs(ctx context.Context) ([]byte, error) { +func (c *agentConn) DebugLogs(ctx context.Context) ([]byte, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/debug/logs", nil) @@ -357,7 +388,7 @@ func (c *AgentConn) DebugLogs(ctx context.Context) ([]byte, error) { } // PrometheusMetrics returns a response from the agent's prometheus metrics endpoint -func (c *AgentConn) PrometheusMetrics(ctx context.Context) ([]byte, error) { +func (c *agentConn) PrometheusMetrics(ctx context.Context) ([]byte, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/debug/prometheus", nil) @@ -376,7 +407,7 @@ func (c *AgentConn) PrometheusMetrics(ctx context.Context) ([]byte, error) { } // ListContainers returns a response from the agent's containers endpoint -func (c *AgentConn) ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) { +func (c *agentConn) ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/api/v0/containers", nil) @@ -391,7 +422,7 @@ func (c *AgentConn) ListContainers(ctx context.Context) (codersdk.WorkspaceAgent return resp, json.NewDecoder(res.Body).Decode(&resp) } -func (c *AgentConn) WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) { +func (c *agentConn) WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -427,7 +458,7 @@ func (c *AgentConn) WatchContainers(ctx context.Context, logger slog.Logger) (<- // RecreateDevcontainer recreates a devcontainer with the given container. // This is a blocking call and will wait for the container to be recreated. -func (c *AgentConn) RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) { +func (c *agentConn) RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodPost, "/api/v0/containers/devcontainers/"+devcontainerID+"/recreate", nil) @@ -446,7 +477,7 @@ func (c *AgentConn) RecreateDevcontainer(ctx context.Context, devcontainerID str } // apiRequest makes a request to the workspace agent's HTTP API server. -func (c *AgentConn) apiRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { +func (c *agentConn) apiRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -463,7 +494,7 @@ func (c *AgentConn) apiRequest(ctx context.Context, method, path string, body io // apiClient returns an HTTP client that can be used to make // requests to the workspace agent's HTTP API server. -func (c *AgentConn) apiClient() *http.Client { +func (c *agentConn) apiClient() *http.Client { return &http.Client{ Transport: &http.Transport{ // Disable keep alives as we're usually only making a single @@ -504,6 +535,6 @@ func (c *AgentConn) apiClient() *http.Client { } } -func (c *AgentConn) GetPeerDiagnostics() tailnet.PeerDiagnostics { +func (c *agentConn) GetPeerDiagnostics() tailnet.PeerDiagnostics { return c.Conn.GetPeerDiagnostics(c.opts.AgentID) } diff --git a/codersdk/workspacesdk/agentconnmock/agentconnmock.go b/codersdk/workspacesdk/agentconnmock/agentconnmock.go new file mode 100644 index 0000000000000..eb55bb27938c0 --- /dev/null +++ b/codersdk/workspacesdk/agentconnmock/agentconnmock.go @@ -0,0 +1,373 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: .. (interfaces: AgentConn) +// +// Generated by this command: +// +// mockgen -destination ./agentconnmock.go -package agentconnmock .. AgentConn +// + +// Package agentconnmock is a generated GoMock package. +package agentconnmock + +import ( + context "context" + io "io" + net "net" + reflect "reflect" + time "time" + + slog "cdr.dev/slog" + codersdk "github.com/coder/coder/v2/codersdk" + healthsdk "github.com/coder/coder/v2/codersdk/healthsdk" + workspacesdk "github.com/coder/coder/v2/codersdk/workspacesdk" + tailnet "github.com/coder/coder/v2/tailnet" + uuid "github.com/google/uuid" + gomock "go.uber.org/mock/gomock" + ssh "golang.org/x/crypto/ssh" + gonet "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + ipnstate "tailscale.com/ipn/ipnstate" + speedtest "tailscale.com/net/speedtest" +) + +// MockAgentConn is a mock of AgentConn interface. +type MockAgentConn struct { + ctrl *gomock.Controller + recorder *MockAgentConnMockRecorder + isgomock struct{} +} + +// MockAgentConnMockRecorder is the mock recorder for MockAgentConn. +type MockAgentConnMockRecorder struct { + mock *MockAgentConn +} + +// NewMockAgentConn creates a new mock instance. +func NewMockAgentConn(ctrl *gomock.Controller) *MockAgentConn { + mock := &MockAgentConn{ctrl: ctrl} + mock.recorder = &MockAgentConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAgentConn) EXPECT() *MockAgentConnMockRecorder { + return m.recorder +} + +// AwaitReachable mocks base method. +func (m *MockAgentConn) AwaitReachable(ctx context.Context) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AwaitReachable", ctx) + ret0, _ := ret[0].(bool) + return ret0 +} + +// AwaitReachable indicates an expected call of AwaitReachable. +func (mr *MockAgentConnMockRecorder) AwaitReachable(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AwaitReachable", reflect.TypeOf((*MockAgentConn)(nil).AwaitReachable), ctx) +} + +// Close mocks base method. +func (m *MockAgentConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockAgentConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAgentConn)(nil).Close)) +} + +// DebugLogs mocks base method. +func (m *MockAgentConn) DebugLogs(ctx context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DebugLogs", ctx) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DebugLogs indicates an expected call of DebugLogs. +func (mr *MockAgentConnMockRecorder) DebugLogs(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DebugLogs", reflect.TypeOf((*MockAgentConn)(nil).DebugLogs), ctx) +} + +// DebugMagicsock mocks base method. +func (m *MockAgentConn) DebugMagicsock(ctx context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DebugMagicsock", ctx) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DebugMagicsock indicates an expected call of DebugMagicsock. +func (mr *MockAgentConnMockRecorder) DebugMagicsock(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DebugMagicsock", reflect.TypeOf((*MockAgentConn)(nil).DebugMagicsock), ctx) +} + +// DebugManifest mocks base method. +func (m *MockAgentConn) DebugManifest(ctx context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DebugManifest", ctx) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DebugManifest indicates an expected call of DebugManifest. +func (mr *MockAgentConnMockRecorder) DebugManifest(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DebugManifest", reflect.TypeOf((*MockAgentConn)(nil).DebugManifest), ctx) +} + +// DialContext mocks base method. +func (m *MockAgentConn) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DialContext", ctx, network, addr) + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DialContext indicates an expected call of DialContext. +func (mr *MockAgentConnMockRecorder) DialContext(ctx, network, addr any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialContext", reflect.TypeOf((*MockAgentConn)(nil).DialContext), ctx, network, addr) +} + +// GetPeerDiagnostics mocks base method. +func (m *MockAgentConn) GetPeerDiagnostics() tailnet.PeerDiagnostics { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPeerDiagnostics") + ret0, _ := ret[0].(tailnet.PeerDiagnostics) + return ret0 +} + +// GetPeerDiagnostics indicates an expected call of GetPeerDiagnostics. +func (mr *MockAgentConnMockRecorder) GetPeerDiagnostics() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerDiagnostics", reflect.TypeOf((*MockAgentConn)(nil).GetPeerDiagnostics)) +} + +// ListContainers mocks base method. +func (m *MockAgentConn) ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListContainers", ctx) + ret0, _ := ret[0].(codersdk.WorkspaceAgentListContainersResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListContainers indicates an expected call of ListContainers. +func (mr *MockAgentConnMockRecorder) ListContainers(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListContainers", reflect.TypeOf((*MockAgentConn)(nil).ListContainers), ctx) +} + +// ListeningPorts mocks base method. +func (m *MockAgentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListeningPorts", ctx) + ret0, _ := ret[0].(codersdk.WorkspaceAgentListeningPortsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListeningPorts indicates an expected call of ListeningPorts. +func (mr *MockAgentConnMockRecorder) ListeningPorts(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListeningPorts", reflect.TypeOf((*MockAgentConn)(nil).ListeningPorts), ctx) +} + +// Netcheck mocks base method. +func (m *MockAgentConn) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Netcheck", ctx) + ret0, _ := ret[0].(healthsdk.AgentNetcheckReport) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Netcheck indicates an expected call of Netcheck. +func (mr *MockAgentConnMockRecorder) Netcheck(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Netcheck", reflect.TypeOf((*MockAgentConn)(nil).Netcheck), ctx) +} + +// Ping mocks base method. +func (m *MockAgentConn) Ping(ctx context.Context) (time.Duration, bool, *ipnstate.PingResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Ping", ctx) + ret0, _ := ret[0].(time.Duration) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(*ipnstate.PingResult) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +// Ping indicates an expected call of Ping. +func (mr *MockAgentConnMockRecorder) Ping(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockAgentConn)(nil).Ping), ctx) +} + +// PrometheusMetrics mocks base method. +func (m *MockAgentConn) PrometheusMetrics(ctx context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PrometheusMetrics", ctx) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PrometheusMetrics indicates an expected call of PrometheusMetrics. +func (mr *MockAgentConnMockRecorder) PrometheusMetrics(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrometheusMetrics", reflect.TypeOf((*MockAgentConn)(nil).PrometheusMetrics), ctx) +} + +// ReconnectingPTY mocks base method. +func (m *MockAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string, initOpts ...workspacesdk.AgentReconnectingPTYInitOption) (net.Conn, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, id, height, width, command} + for _, a := range initOpts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ReconnectingPTY", varargs...) + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReconnectingPTY indicates an expected call of ReconnectingPTY. +func (mr *MockAgentConnMockRecorder) ReconnectingPTY(ctx, id, height, width, command any, initOpts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, id, height, width, command}, initOpts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReconnectingPTY", reflect.TypeOf((*MockAgentConn)(nil).ReconnectingPTY), varargs...) +} + +// RecreateDevcontainer mocks base method. +func (m *MockAgentConn) RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RecreateDevcontainer", ctx, devcontainerID) + ret0, _ := ret[0].(codersdk.Response) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RecreateDevcontainer indicates an expected call of RecreateDevcontainer. +func (mr *MockAgentConnMockRecorder) RecreateDevcontainer(ctx, devcontainerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecreateDevcontainer", reflect.TypeOf((*MockAgentConn)(nil).RecreateDevcontainer), ctx, devcontainerID) +} + +// SSH mocks base method. +func (m *MockAgentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SSH", ctx) + ret0, _ := ret[0].(*gonet.TCPConn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SSH indicates an expected call of SSH. +func (mr *MockAgentConnMockRecorder) SSH(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SSH", reflect.TypeOf((*MockAgentConn)(nil).SSH), ctx) +} + +// SSHClient mocks base method. +func (m *MockAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SSHClient", ctx) + ret0, _ := ret[0].(*ssh.Client) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SSHClient indicates an expected call of SSHClient. +func (mr *MockAgentConnMockRecorder) SSHClient(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SSHClient", reflect.TypeOf((*MockAgentConn)(nil).SSHClient), ctx) +} + +// SSHClientOnPort mocks base method. +func (m *MockAgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SSHClientOnPort", ctx, port) + ret0, _ := ret[0].(*ssh.Client) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SSHClientOnPort indicates an expected call of SSHClientOnPort. +func (mr *MockAgentConnMockRecorder) SSHClientOnPort(ctx, port any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SSHClientOnPort", reflect.TypeOf((*MockAgentConn)(nil).SSHClientOnPort), ctx, port) +} + +// SSHOnPort mocks base method. +func (m *MockAgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SSHOnPort", ctx, port) + ret0, _ := ret[0].(*gonet.TCPConn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SSHOnPort indicates an expected call of SSHOnPort. +func (mr *MockAgentConnMockRecorder) SSHOnPort(ctx, port any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SSHOnPort", reflect.TypeOf((*MockAgentConn)(nil).SSHOnPort), ctx, port) +} + +// Speedtest mocks base method. +func (m *MockAgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Speedtest", ctx, direction, duration) + ret0, _ := ret[0].([]speedtest.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Speedtest indicates an expected call of Speedtest. +func (mr *MockAgentConnMockRecorder) Speedtest(ctx, direction, duration any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Speedtest", reflect.TypeOf((*MockAgentConn)(nil).Speedtest), ctx, direction, duration) +} + +// TailnetConn mocks base method. +func (m *MockAgentConn) TailnetConn() *tailnet.Conn { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TailnetConn") + ret0, _ := ret[0].(*tailnet.Conn) + return ret0 +} + +// TailnetConn indicates an expected call of TailnetConn. +func (mr *MockAgentConnMockRecorder) TailnetConn() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TailnetConn", reflect.TypeOf((*MockAgentConn)(nil).TailnetConn)) +} + +// WatchContainers mocks base method. +func (m *MockAgentConn) WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WatchContainers", ctx, logger) + ret0, _ := ret[0].(<-chan codersdk.WorkspaceAgentListContainersResponse) + ret1, _ := ret[1].(io.Closer) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// WatchContainers indicates an expected call of WatchContainers. +func (mr *MockAgentConnMockRecorder) WatchContainers(ctx, logger any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WatchContainers", reflect.TypeOf((*MockAgentConn)(nil).WatchContainers), ctx, logger) +} diff --git a/codersdk/workspacesdk/agentconnmock/doc.go b/codersdk/workspacesdk/agentconnmock/doc.go new file mode 100644 index 0000000000000..a795b21a4a89d --- /dev/null +++ b/codersdk/workspacesdk/agentconnmock/doc.go @@ -0,0 +1,4 @@ +// Package agentconnmock contains a mock implementation of workspacesdk.AgentConn for use in tests. +package agentconnmock + +//go:generate mockgen -destination ./agentconnmock.go -package agentconnmock .. AgentConn diff --git a/codersdk/workspacesdk/workspacesdk.go b/codersdk/workspacesdk/workspacesdk.go index 9f587cf5267a8..ddaec06388238 100644 --- a/codersdk/workspacesdk/workspacesdk.go +++ b/codersdk/workspacesdk/workspacesdk.go @@ -202,7 +202,7 @@ func (c *Client) RewriteDERPMap(derpMap *tailcfg.DERPMap) { tailnet.RewriteDERPMapDefaultRelay(context.Background(), c.client.Logger(), derpMap, c.client.URL) } -func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *DialAgentOptions) (agentConn *AgentConn, err error) { +func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *DialAgentOptions) (agentConn AgentConn, err error) { if options == nil { options = &DialAgentOptions{} } diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go index b0051551a0f3d..72f5a4291c40e 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go @@ -75,7 +75,7 @@ func (c *Client) RequestIgnoreRedirects(ctx context.Context, method, path string // DialWorkspaceAgent calls the underlying codersdk.Client's DialWorkspaceAgent // method. -func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *workspacesdk.DialAgentOptions) (agentConn *workspacesdk.AgentConn, err error) { +func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *workspacesdk.DialAgentOptions) (agentConn workspacesdk.AgentConn, err error) { return workspacesdk.New(c.SDKClient).DialAgent(ctx, agentID, options) } diff --git a/scaletest/agentconn/run.go b/scaletest/agentconn/run.go index dba21cc24e3a0..b0990d9cb11a6 100644 --- a/scaletest/agentconn/run.go +++ b/scaletest/agentconn/run.go @@ -89,7 +89,7 @@ func (r *Runner) Run(ctx context.Context, _ string, w io.Writer) error { // Ensure DERP for completeness. if r.cfg.ConnectionMode == ConnectionModeDerp { - status := conn.Status() + status := conn.TailnetConn().Status() if len(status.Peers()) != 1 { return xerrors.Errorf("check connection mode: expected 1 peer, got %d", len(status.Peers())) } @@ -133,7 +133,7 @@ func (r *Runner) Run(ctx context.Context, _ string, w io.Writer) error { return nil } -func waitForDisco(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn) error { +func waitForDisco(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn) error { const pingAttempts = 10 const pingDelay = 1 * time.Second @@ -165,7 +165,7 @@ func waitForDisco(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentC return nil } -func waitForDirectConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn) error { +func waitForDirectConnection(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn) error { const directConnectionAttempts = 30 const directConnectionDelay = 1 * time.Second @@ -174,7 +174,7 @@ func waitForDirectConnection(ctx context.Context, logs io.Writer, conn *workspac for i := 0; i < directConnectionAttempts; i++ { _, _ = fmt.Fprintf(logs, "\tDirect connection check %d/%d...\n", i+1, directConnectionAttempts) - status := conn.Status() + status := conn.TailnetConn().Status() var err error if len(status.Peers()) != 1 { @@ -207,7 +207,7 @@ func waitForDirectConnection(ctx context.Context, logs io.Writer, conn *workspac return nil } -func verifyConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn) error { +func verifyConnection(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn) error { const verifyConnectionAttempts = 30 const verifyConnectionDelay = 1 * time.Second @@ -249,7 +249,7 @@ func verifyConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.Ag return nil } -func performInitialConnections(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn, specs []Connection) error { +func performInitialConnections(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn, specs []Connection) error { if len(specs) == 0 { return nil } @@ -287,7 +287,7 @@ func performInitialConnections(ctx context.Context, logs io.Writer, conn *worksp return nil } -func holdConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn, holdDur time.Duration, specs []Connection) error { +func holdConnection(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn, holdDur time.Duration, specs []Connection) error { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -364,7 +364,7 @@ func holdConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.Agen return nil } -func agentHTTPClient(conn *workspacesdk.AgentConn) *http.Client { +func agentHTTPClient(conn workspacesdk.AgentConn) *http.Client { return &http.Client{ Transport: &http.Transport{ DisableKeepAlives: true, diff --git a/support/support.go b/support/support.go index 2fa41ce7eca8c..31080faaf023b 100644 --- a/support/support.go +++ b/support/support.go @@ -390,7 +390,7 @@ func connectedAgentInfo(ctx context.Context, client *codersdk.Client, log slog.L if err := conn.Close(); err != nil { log.Error(ctx, "failed to close agent connection", slog.Error(err)) } - <-conn.Closed() + <-conn.TailnetConn().Closed() } eg.Go(func() error { @@ -399,7 +399,7 @@ func connectedAgentInfo(ctx context.Context, client *codersdk.Client, log slog.L return xerrors.Errorf("create request: %w", err) } rr := httptest.NewRecorder() - conn.MagicsockServeHTTPDebug(rr, req) + conn.TailnetConn().MagicsockServeHTTPDebug(rr, req) a.ClientMagicsockHTML = rr.Body.Bytes() return nil }) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy