From 2403c9459cb4358f63c3e98217be0cabc85d64bd Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 5 Sep 2022 16:57:22 -0500 Subject: [PATCH 1/3] chore: Remove WebRTC networking --- Makefile | 10 +- agent/agent.go | 123 +---- agent/agent_test.go | 210 +++------ agent/conn.go | 131 +----- cli/agent.go | 3 - cli/agent_test.go | 33 +- cli/configssh.go | 12 +- cli/configssh_test.go | 6 +- cli/gitssh_test.go | 13 +- cli/portforward.go | 15 +- cli/portforward_test.go | 92 +--- cli/server.go | 31 +- cli/ssh.go | 12 +- cli/ssh_test.go | 3 - coderd/coderd.go | 21 +- coderd/coderdtest/authtest.go | 28 +- coderd/coderdtest/coderdtest.go | 8 - coderd/templates_test.go | 1 - coderd/turnconn/turnconn.go | 203 -------- coderd/turnconn/turnconn_test.go | 107 ----- coderd/workspaceagents.go | 487 +++++-------------- coderd/workspaceagents_test.go | 29 +- coderd/workspaceapps_test.go | 1 - coderd/wsconncache/wsconncache.go | 4 +- coderd/wsconncache/wsconncache_test.go | 12 +- codersdk/workspaceagents.go | 175 +------ go.mod | 26 +- go.sum | 46 +- peer/channel.go | 317 ------------- peer/conn.go | 616 ------------------------- peer/conn_test.go | 434 ----------------- peer/netconn.go | 59 --- peerbroker/dial.go | 87 ---- peerbroker/dial_test.go | 67 --- peerbroker/listen.go | 188 -------- peerbroker/listen_test.go | 52 --- peerbroker/proto/peerbroker.pb.go | 269 ----------- peerbroker/proto/peerbroker.proto | 28 -- peerbroker/proto/peerbroker_drpc.pb.go | 146 ------ peerbroker/proxy.go | 283 ------------ peerbroker/proxy_test.go | 84 ---- 41 files changed, 292 insertions(+), 4180 deletions(-) delete mode 100644 coderd/turnconn/turnconn.go delete mode 100644 coderd/turnconn/turnconn_test.go delete mode 100644 peer/channel.go delete mode 100644 peer/conn.go delete mode 100644 peer/conn_test.go delete mode 100644 peer/netconn.go delete mode 100644 peerbroker/dial.go delete mode 100644 peerbroker/dial_test.go delete mode 100644 peerbroker/listen.go delete mode 100644 peerbroker/listen_test.go delete mode 100644 peerbroker/proto/peerbroker.pb.go delete mode 100644 peerbroker/proto/peerbroker.proto delete mode 100644 peerbroker/proto/peerbroker_drpc.pb.go delete mode 100644 peerbroker/proxy.go delete mode 100644 peerbroker/proxy_test.go diff --git a/Makefile b/Makefile index 3c1a1d74aa038..aa651b1d262f5 100644 --- a/Makefile +++ b/Makefile @@ -117,7 +117,7 @@ endif fmt: fmt/prettier fmt/terraform fmt/shfmt .PHONY: fmt -gen: coderd/database/querier.go peerbroker/proto/peerbroker.pb.go provisionersdk/proto/provisioner.pb.go provisionerd/proto/provisionerd.pb.go site/src/api/typesGenerated.ts +gen: coderd/database/querier.go provisionersdk/proto/provisioner.pb.go provisionerd/proto/provisionerd.pb.go site/src/api/typesGenerated.ts .PHONY: gen install: site/out/index.html $(shell find . -not -path './vendor/*' -type f -name '*.go') go.mod go.sum $(shell find ./examples/templates) @@ -152,14 +152,6 @@ lint/shellcheck: $(shell shfmt -f .) shellcheck --external-sources $(shell shfmt -f .) .PHONY: lint/shellcheck -peerbroker/proto/peerbroker.pb.go: peerbroker/proto/peerbroker.proto - protoc \ - --go_out=. \ - --go_opt=paths=source_relative \ - --go-drpc_out=. \ - --go-drpc_opt=paths=source_relative \ - ./peerbroker/proto/peerbroker.proto - provisionerd/proto/provisionerd.pb.go: provisionerd/proto/provisionerd.proto protoc \ --go_out=. \ diff --git a/agent/agent.go b/agent/agent.go index 981c633a82e38..d03b534a8db5d 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -33,8 +33,6 @@ import ( "cdr.dev/slog" "github.com/coder/coder/agent/usershell" - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker" "github.com/coder/coder/pty" "github.com/coder/coder/tailnet" "github.com/coder/retry" @@ -62,7 +60,6 @@ var ( type Options struct { CoordinatorDialer CoordinatorDialer - WebRTCDialer WebRTCDialer FetchMetadata FetchMetadata StatsReporter StatsReporter @@ -78,8 +75,6 @@ type Metadata struct { Directory string `json:"directory"` } -type WebRTCDialer func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) - // CoordinatorDialer is a function that constructs a new broker. // A dialer must be passed in to allow for reconnects. type CoordinatorDialer func(ctx context.Context) (net.Conn, error) @@ -93,7 +88,6 @@ func New(options Options) io.Closer { } ctx, cancelFunc := context.WithCancel(context.Background()) server := &agent{ - webrtcDialer: options.WebRTCDialer, reconnectingPTYTimeout: options.ReconnectingPTYTimeout, logger: options.Logger, closeCancel: cancelFunc, @@ -109,8 +103,7 @@ func New(options Options) io.Closer { } type agent struct { - webrtcDialer WebRTCDialer - logger slog.Logger + logger slog.Logger reconnectingPTYs sync.Map reconnectingPTYTimeout time.Duration @@ -171,9 +164,6 @@ func (a *agent) run(ctx context.Context) { } }() - if a.webrtcDialer != nil { - go a.runWebRTCNetworking(ctx) - } if metadata.DERPMap != nil { go a.runTailnet(ctx, metadata.DERPMap) } @@ -303,49 +293,6 @@ func (a *agent) runCoordinator(ctx context.Context) { } } -func (a *agent) runWebRTCNetworking(ctx context.Context) { - var peerListener *peerbroker.Listener - var err error - // An exponential back-off occurs when the connection is failing to dial. - // This is to prevent server spam in case of a coderd outage. - for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { - peerListener, err = a.webrtcDialer(ctx, a.logger) - if err != nil { - if errors.Is(err, context.Canceled) { - return - } - if a.isClosed() { - return - } - a.logger.Warn(context.Background(), "failed to dial", slog.Error(err)) - continue - } - a.logger.Info(context.Background(), "connected to webrtc broker") - break - } - select { - case <-ctx.Done(): - return - default: - } - - for { - conn, err := peerListener.Accept() - if err != nil { - if a.isClosed() { - return - } - a.logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err)) - a.runWebRTCNetworking(ctx) - return - } - a.closeMutex.Lock() - a.connCloseWait.Add(1) - a.closeMutex.Unlock() - go a.handlePeerConn(ctx, conn) - } -} - func (a *agent) runStartupScript(ctx context.Context, script string) error { if script == "" { return nil @@ -378,74 +325,6 @@ func (a *agent) runStartupScript(ctx context.Context, script string) error { return nil } -func (a *agent) handlePeerConn(ctx context.Context, peerConn *peer.Conn) { - go func() { - select { - case <-a.closed: - case <-peerConn.Closed(): - } - _ = peerConn.Close() - a.connCloseWait.Done() - }() - for { - channel, err := peerConn.Accept(ctx) - if err != nil { - if errors.Is(err, peer.ErrClosed) || a.isClosed() { - return - } - a.logger.Debug(ctx, "accept channel from peer connection", slog.Error(err)) - return - } - - conn := channel.NetConn() - - switch channel.Protocol() { - case ProtocolSSH: - go a.sshServer.HandleConn(a.stats.wrapConn(conn)) - case ProtocolReconnectingPTY: - rawID := channel.Label() - // The ID format is referenced in conn.go. - // :: - idParts := strings.SplitN(rawID, ":", 4) - if len(idParts) != 4 { - a.logger.Warn(ctx, "client sent invalid id format", slog.F("raw-id", rawID)) - continue - } - id := idParts[0] - // Enforce a consistent format for IDs. - _, err := uuid.Parse(id) - if err != nil { - a.logger.Warn(ctx, "client sent reconnection token that isn't a uuid", slog.F("id", id), slog.Error(err)) - continue - } - // Parse the initial terminal dimensions. - height, err := strconv.Atoi(idParts[1]) - if err != nil { - a.logger.Warn(ctx, "client sent invalid height", slog.F("id", id), slog.F("height", idParts[1])) - continue - } - width, err := strconv.Atoi(idParts[2]) - if err != nil { - a.logger.Warn(ctx, "client sent invalid width", slog.F("id", id), slog.F("width", idParts[2])) - continue - } - go a.handleReconnectingPTY(ctx, reconnectingPTYInit{ - ID: id, - Height: uint16(height), - Width: uint16(width), - Command: idParts[3], - }, a.stats.wrapConn(conn)) - case ProtocolDial: - go a.handleDial(ctx, channel.Label(), a.stats.wrapConn(conn)) - default: - a.logger.Warn(ctx, "unhandled protocol from channel", - slog.F("protocol", channel.Protocol()), - slog.F("label", channel.Label()), - ) - } - } -} - func (a *agent) init(ctx context.Context) { a.logger.Info(ctx, "generating host key") // Clients' should ignore the host key when connecting. diff --git a/agent/agent_test.go b/agent/agent_test.go index 49f57214ab6eb..c7113969cebf7 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -19,12 +19,10 @@ import ( "time" "golang.org/x/xerrors" - "tailscale.com/tailcfg" scp "github.com/bramvdbogaerde/go-scp" "github.com/google/uuid" "github.com/pion/udp" - "github.com/pion/webrtc/v3" "github.com/pkg/sftp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -36,10 +34,6 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker" - "github.com/coder/coder/peerbroker/proto" - "github.com/coder/coder/provisionersdk" "github.com/coder/coder/pty/ptytest" "github.com/coder/coder/tailnet" "github.com/coder/coder/tailnet/tailnettest" @@ -53,64 +47,49 @@ func TestMain(m *testing.M) { func TestAgent(t *testing.T) { t.Parallel() t.Run("Stats", func(t *testing.T) { - for _, tailscale := range []bool{true, false} { - t.Run(fmt.Sprintf("tailscale=%v", tailscale), func(t *testing.T) { - t.Parallel() + t.Parallel() - setupAgent := func(t *testing.T) (agent.Conn, <-chan *agent.Stats) { - var derpMap *tailcfg.DERPMap - if tailscale { - derpMap = tailnettest.RunDERPAndSTUN(t) - } - conn, stats := setupAgent(t, agent.Metadata{ - DERPMap: derpMap, - }, 0) - assert.Empty(t, <-stats) - return conn, stats - } + t.Run("SSH", func(t *testing.T) { + t.Parallel() + conn, stats := setupAgent(t, agent.Metadata{}, 0) + + sshClient, err := conn.SSHClient() + require.NoError(t, err) + defer sshClient.Close() + session, err := sshClient.NewSession() + require.NoError(t, err) + defer session.Close() - t.Run("SSH", func(t *testing.T) { - t.Parallel() - conn, stats := setupAgent(t) - - sshClient, err := conn.SSHClient() - require.NoError(t, err) - session, err := sshClient.NewSession() - require.NoError(t, err) - defer session.Close() - - assert.EqualValues(t, 1, (<-stats).NumConns) - assert.Greater(t, (<-stats).RxBytes, int64(0)) - assert.Greater(t, (<-stats).TxBytes, int64(0)) - }) - - t.Run("ReconnectingPTY", func(t *testing.T) { - t.Parallel() - - conn, stats := setupAgent(t) - - ptyConn, err := conn.ReconnectingPTY(uuid.NewString(), 128, 128, "/bin/bash") - require.NoError(t, err) - defer ptyConn.Close() - - data, err := json.Marshal(agent.ReconnectingPTYRequest{ - Data: "echo test\r\n", - }) - require.NoError(t, err) - _, err = ptyConn.Write(data) - require.NoError(t, err) - - var s *agent.Stats - require.Eventuallyf(t, func() bool { - var ok bool - s, ok = (<-stats) - return ok && s.NumConns > 0 && s.RxBytes > 0 && s.TxBytes > 0 - }, testutil.WaitLong, testutil.IntervalFast, - "never saw stats: %+v", s, - ) - }) + assert.EqualValues(t, 1, (<-stats).NumConns) + assert.Greater(t, (<-stats).RxBytes, int64(0)) + assert.Greater(t, (<-stats).TxBytes, int64(0)) + }) + + t.Run("ReconnectingPTY", func(t *testing.T) { + t.Parallel() + + conn, stats := setupAgent(t, agent.Metadata{}, 0) + + ptyConn, err := conn.ReconnectingPTY(uuid.NewString(), 128, 128, "/bin/bash") + require.NoError(t, err) + defer ptyConn.Close() + + data, err := json.Marshal(agent.ReconnectingPTYRequest{ + Data: "echo test\r\n", }) - } + require.NoError(t, err) + _, err = ptyConn.Write(data) + require.NoError(t, err) + + var s *agent.Stats + require.Eventuallyf(t, func() bool { + var ok bool + s, ok = (<-stats) + return ok && s.NumConns > 0 && s.RxBytes > 0 && s.TxBytes > 0 + }, testutil.WaitLong, testutil.IntervalFast, + "never saw stats: %+v", s, + ) + }) }) t.Run("SessionExec", func(t *testing.T) { @@ -234,6 +213,7 @@ func TestAgent(t *testing.T) { conn, _ := setupAgent(t, agent.Metadata{}, 0) sshClient, err := conn.SSHClient() require.NoError(t, err) + defer sshClient.Close() client, err := sftp.NewClient(sshClient) require.NoError(t, err) tempFile := filepath.Join(t.TempDir(), "sftp") @@ -251,6 +231,7 @@ func TestAgent(t *testing.T) { conn, _ := setupAgent(t, agent.Metadata{}, 0) sshClient, err := conn.SSHClient() require.NoError(t, err) + defer sshClient.Close() scpClient, err := scp.NewClientBySSH(sshClient) require.NoError(t, err) tempFile := filepath.Join(t.TempDir(), "scp") @@ -383,9 +364,7 @@ func TestAgent(t *testing.T) { t.Skip("ConPTY appears to be inconsistent on Windows.") } - conn, _ := setupAgent(t, agent.Metadata{ - DERPMap: tailnettest.RunDERPAndSTUN(t), - }, 0) + conn, _ := setupAgent(t, agent.Metadata{}, 0) id := uuid.NewString() netConn, err := conn.ReconnectingPTY(id, 100, 100, "/bin/bash") require.NoError(t, err) @@ -461,19 +440,6 @@ func TestAgent(t *testing.T) { return l }, }, - { - name: "Unix", - setup: func(t *testing.T) net.Listener { - if runtime.GOOS == "windows" { - t.Skip("Unix socket forwarding isn't supported on Windows") - } - - tmpDir := t.TempDir() - l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock")) - require.NoError(t, err, "create UDP listener") - return l - }, - }, } for _, c := range cases { @@ -495,8 +461,11 @@ func TestAgent(t *testing.T) { } }() - // Dial the listener over WebRTC twice and test out of order conn, _ := setupAgent(t, agent.Metadata{}, 0) + require.Eventually(t, func() bool { + _, err := conn.Ping() + return err == nil + }, testutil.WaitMedium, testutil.IntervalFast) conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String()) require.NoError(t, err) defer conn1.Close() @@ -505,36 +474,11 @@ func TestAgent(t *testing.T) { defer conn2.Close() testDial(t, conn2) testDial(t, conn1) + time.Sleep(150 * time.Millisecond) }) } }) - t.Run("DialError", func(t *testing.T) { - t.Parallel() - - if runtime.GOOS == "windows" { - // This test uses Unix listeners so we can very easily ensure that - // no other tests decide to listen on the same random port we - // picked. - t.Skip("this test is unsupported on Windows") - return - } - - tmpDir, err := os.MkdirTemp("", "coderd_agent_test_") - require.NoError(t, err, "create temp dir") - t.Cleanup(func() { - _ = os.RemoveAll(tmpDir) - }) - - // Try to dial the non-existent Unix socket over WebRTC - conn, _ := setupAgent(t, agent.Metadata{}, 0) - netConn, err := conn.DialContext(context.Background(), "unix", filepath.Join(tmpDir, "test.sock")) - require.Error(t, err) - require.ErrorContains(t, err, "remote dial error") - require.ErrorContains(t, err, "no such file") - require.Nil(t, netConn) - }) - t.Run("Tailnet", func(t *testing.T) { t.Parallel() derpMap := tailnettest.RunDERPAndSTUN(t) @@ -606,11 +550,12 @@ func (c closeFunc) Close() error { } func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) ( - agent.Conn, + *agent.Conn, <-chan *agent.Stats, ) { - client, server := provisionersdk.TransportPipe() - tailscale := metadata.DERPMap != nil + if metadata.DERPMap == nil { + metadata.DERPMap = tailnettest.RunDERPAndSTUN(t) + } coordinator := tailnet.NewCoordinator() agentID := uuid.New() statsCh := make(chan *agent.Stats) @@ -618,10 +563,6 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) FetchMetadata: func(ctx context.Context) (agent.Metadata, error) { return metadata, nil }, - WebRTCDialer: func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) { - listener, err := peerbroker.Listen(server, nil) - return listener, err - }, CoordinatorDialer: func(ctx context.Context) (net.Conn, error) { clientConn, serverConn := net.Pipe() t.Cleanup(func() { @@ -667,46 +608,27 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) }, }) t.Cleanup(func() { - _ = client.Close() - _ = server.Close() _ = closer.Close() }) - api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) - stream, err := api.NegotiateConnection(context.Background()) - assert.NoError(t, err) - if tailscale { - conn, err := tailnet.NewConn(&tailnet.Options{ - Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, - DERPMap: metadata.DERPMap, - Logger: slogtest.Make(t, nil).Named("tailnet"), - }) - require.NoError(t, err) - clientConn, serverConn := net.Pipe() - t.Cleanup(func() { - _ = clientConn.Close() - _ = serverConn.Close() - _ = conn.Close() - }) - go coordinator.ServeClient(serverConn, uuid.New(), agentID) - sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { - return conn.UpdateNodes(node) - }) - conn.SetNodeCallback(sendNode) - return &agent.TailnetConn{ - Conn: conn, - }, statsCh - } - conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil), + conn, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, + DERPMap: metadata.DERPMap, + Logger: slogtest.Make(t, nil).Named("tailnet"), }) require.NoError(t, err) + clientConn, serverConn := net.Pipe() t.Cleanup(func() { + _ = clientConn.Close() + _ = serverConn.Close() _ = conn.Close() }) - - return &agent.WebRTCConn{ - Negotiator: api, - Conn: conn, + go coordinator.ServeClient(serverConn, uuid.New(), agentID) + sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { + return conn.UpdateNodes(node) + }) + conn.SetNodeCallback(sendNode) + return &agent.Conn{ + Conn: conn, }, statsCh } diff --git a/agent/conn.go b/agent/conn.go index 0e95e97e21254..15d65011ac82c 100644 --- a/agent/conn.go +++ b/agent/conn.go @@ -4,13 +4,9 @@ import ( "context" "encoding/binary" "encoding/json" - "fmt" - "io" "net" "net/netip" - "net/url" "strconv" - "strings" "time" "golang.org/x/crypto/ssh" @@ -18,8 +14,6 @@ import ( "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker/proto" "github.com/coder/coder/tailnet" ) @@ -31,118 +25,12 @@ type ReconnectingPTYRequest struct { Width uint16 `json:"width"` } -// Conn is a temporary interface while we switch from WebRTC to Wireguard networking. -type Conn interface { - io.Closer - Closed() <-chan struct{} - Ping() (time.Duration, error) - CloseWithError(err error) error - ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) - SSH() (net.Conn, error) - SSHClient() (*ssh.Client, error) - DialContext(ctx context.Context, network string, addr string) (net.Conn, error) -} - -// Conn wraps a peer connection with helper functions to -// communicate with the agent. -type WebRTCConn struct { - // Negotiator is responsible for exchanging messages. - Negotiator proto.DRPCPeerBrokerClient - - *peer.Conn -} - -// ReconnectingPTY returns a connection serving a TTY that can -// be reconnected to via ID. -// -// The command is optional and defaults to start a shell. -func (c *WebRTCConn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) { - channel, err := c.CreateChannel(context.Background(), fmt.Sprintf("%s:%d:%d:%s", id, height, width, command), &peer.ChannelOptions{ - Protocol: ProtocolReconnectingPTY, - }) - if err != nil { - return nil, xerrors.Errorf("pty: %w", err) - } - return channel.NetConn(), nil -} - -// SSH dials the built-in SSH server. -func (c *WebRTCConn) SSH() (net.Conn, error) { - channel, err := c.CreateChannel(context.Background(), "ssh", &peer.ChannelOptions{ - Protocol: ProtocolSSH, - }) - if err != nil { - return nil, xerrors.Errorf("dial: %w", err) - } - return channel.NetConn(), nil -} - -// SSHClient calls SSH to create a client that uses a weak cipher -// for high throughput. -func (c *WebRTCConn) SSHClient() (*ssh.Client, error) { - netConn, err := c.SSH() - if err != nil { - return nil, xerrors.Errorf("ssh: %w", err) - } - sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{ - // SSH host validation isn't helpful, because obtaining a peer - // connection already signifies user-intent to dial a workspace. - // #nosec - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - }) - if err != nil { - return nil, xerrors.Errorf("ssh conn: %w", err) - } - return ssh.NewClient(sshConn, channels, requests), nil -} - -// DialContext dials an arbitrary protocol+address from inside the workspace and -// proxies it through the provided net.Conn. -func (c *WebRTCConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { - u := &url.URL{ - Scheme: network, - } - if strings.HasPrefix(network, "unix") { - u.Path = addr - } else { - u.Host = addr - } - - channel, err := c.CreateChannel(ctx, u.String(), &peer.ChannelOptions{ - Protocol: ProtocolDial, - Unordered: strings.HasPrefix(network, "udp"), - }) - if err != nil { - return nil, xerrors.Errorf("create datachannel: %w", err) - } - - // The first message written from the other side is a JSON payload - // containing the dial error. - dec := json.NewDecoder(channel) - var res dialResponse - err = dec.Decode(&res) - if err != nil { - return nil, xerrors.Errorf("decode agent dial response: %w", err) - } - if res.Error != "" { - _ = channel.Close() - return nil, xerrors.Errorf("remote dial error: %v", res.Error) - } - - return channel.NetConn(), nil -} - -func (c *WebRTCConn) Close() error { - _ = c.Negotiator.DRPCConn().Close() - return c.Conn.Close() -} - -type TailnetConn struct { +type Conn struct { *tailnet.Conn CloseFunc func() } -func (c *TailnetConn) Ping() (time.Duration, error) { +func (c *Conn) Ping() (time.Duration, error) { errCh := make(chan error, 1) durCh := make(chan time.Duration, 1) c.Conn.Ping(tailnetIP, tailcfg.PingICMP, func(pr *ipnstate.PingResult) { @@ -160,11 +48,11 @@ func (c *TailnetConn) Ping() (time.Duration, error) { } } -func (c *TailnetConn) CloseWithError(_ error) error { +func (c *Conn) CloseWithError(_ error) error { return c.Close() } -func (c *TailnetConn) Close() error { +func (c *Conn) Close() error { if c.CloseFunc != nil { c.CloseFunc() } @@ -178,7 +66,7 @@ type reconnectingPTYInit struct { Command string } -func (c *TailnetConn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) { +func (c *Conn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) { conn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(tailnetIP, uint16(tailnetReconnectingPTYPort))) if err != nil { return nil, err @@ -204,13 +92,13 @@ func (c *TailnetConn) ReconnectingPTY(id string, height, width uint16, command s return conn, nil } -func (c *TailnetConn) SSH() (net.Conn, error) { +func (c *Conn) SSH() (net.Conn, error) { return c.DialContextTCP(context.Background(), netip.AddrPortFrom(tailnetIP, uint16(tailnetSSHPort))) } // SSHClient calls SSH to create a client that uses a weak cipher // for high throughput. -func (c *TailnetConn) SSHClient() (*ssh.Client, error) { +func (c *Conn) SSHClient() (*ssh.Client, error) { netConn, err := c.SSH() if err != nil { return nil, xerrors.Errorf("ssh: %w", err) @@ -227,7 +115,10 @@ func (c *TailnetConn) SSHClient() (*ssh.Client, error) { return ssh.NewClient(sshConn, channels, requests), nil } -func (c *TailnetConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { +func (c *Conn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { + if network == "unix" { + return nil, xerrors.New("network must be tcp or udp") + } _, rawPort, _ := net.SplitHostPort(addr) port, _ := strconv.Atoi(rawPort) ipp := netip.AddrPortFrom(tailnetIP, uint16(port)) diff --git a/cli/agent.go b/cli/agent.go index 2c6fdef4a03ce..837d30eb37176 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -32,7 +32,6 @@ func workspaceAgent() *cobra.Command { pprofEnabled bool pprofAddress string noReap bool - wireguard bool ) cmd := &cobra.Command{ Use: "agent", @@ -184,7 +183,6 @@ func workspaceAgent() *cobra.Command { closer := agent.New(agent.Options{ FetchMetadata: client.WorkspaceAgentMetadata, - WebRTCDialer: client.ListenWorkspaceAgent, Logger: logger, EnvironmentVariables: map[string]string{ // Override the "CODER_AGENT_TOKEN" variable in all @@ -203,6 +201,5 @@ func workspaceAgent() *cobra.Command { cliflag.BoolVarP(cmd.Flags(), &pprofEnabled, "pprof-enable", "", "CODER_AGENT_PPROF_ENABLE", false, "Enable serving pprof metrics on the address defined by --pprof-address.") cliflag.BoolVarP(cmd.Flags(), &noReap, "no-reap", "", "", false, "Do not start a process reaper.") cliflag.StringVarP(cmd.Flags(), &pprofAddress, "pprof-address", "", "CODER_AGENT_PPROF_ADDRESS", "127.0.0.1:6060", "The address to serve pprof.") - cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_AGENT_WIREGUARD", true, "Whether to start the Wireguard interface.") return cmd } diff --git a/cli/agent_test.go b/cli/agent_test.go index 39662a1cde89f..6dd8849b74d79 100644 --- a/cli/agent_test.go +++ b/cli/agent_test.go @@ -7,10 +7,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "cdr.dev/slog" + "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" + "github.com/coder/coder/testutil" ) func TestWorkspaceAgent(t *testing.T) { @@ -47,7 +50,7 @@ func TestWorkspaceAgent(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - cmd, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String(), "--wireguard=false") + cmd, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String()) ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() errC := make(chan error) @@ -63,11 +66,13 @@ func TestWorkspaceAgent(t *testing.T) { if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) { assert.NotEmpty(t, resources[0].Agents[0].Version) } - dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) + dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID) require.NoError(t, err) defer dialer.Close() - _, err = dialer.Ping() - require.NoError(t, err) + require.Eventually(t, func() bool { + _, err := dialer.Ping() + return err == nil + }, testutil.WaitMedium, testutil.IntervalFast) cancelFunc() err = <-errC require.NoError(t, err) @@ -105,7 +110,7 @@ func TestWorkspaceAgent(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - cmd, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String(), "--wireguard=false") + cmd, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String()) ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() errC := make(chan error) @@ -121,11 +126,13 @@ func TestWorkspaceAgent(t *testing.T) { if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) { assert.NotEmpty(t, resources[0].Agents[0].Version) } - dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) + dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID) require.NoError(t, err) defer dialer.Close() - _, err = dialer.Ping() - require.NoError(t, err) + require.Eventually(t, func() bool { + _, err := dialer.Ping() + return err == nil + }, testutil.WaitMedium, testutil.IntervalFast) cancelFunc() err = <-errC require.NoError(t, err) @@ -163,7 +170,7 @@ func TestWorkspaceAgent(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - cmd, _ := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String(), "--wireguard=false") + cmd, _ := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String()) ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() errC := make(chan error) @@ -179,11 +186,13 @@ func TestWorkspaceAgent(t *testing.T) { if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) { assert.NotEmpty(t, resources[0].Agents[0].Version) } - dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) + dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID) require.NoError(t, err) defer dialer.Close() - _, err = dialer.Ping() - require.NoError(t, err) + require.Eventually(t, func() bool { + _, err := dialer.Ping() + return err == nil + }, testutil.WaitMedium, testutil.IntervalFast) cancelFunc() err = <-errC require.NoError(t, err) diff --git a/cli/configssh.go b/cli/configssh.go index 79eb280559697..dc2ef6bbd6d38 100644 --- a/cli/configssh.go +++ b/cli/configssh.go @@ -139,7 +139,6 @@ func configSSH() *cobra.Command { usePreviousOpts bool dryRun bool skipProxyCommand bool - wireguard bool ) cmd := &cobra.Command{ Annotations: workspaceCommand, @@ -289,15 +288,11 @@ func configSSH() *cobra.Command { "\tLogLevel ERROR", ) if !skipProxyCommand { - wgArg := "" - if wireguard { - wgArg = "--wireguard " - } configOptions = append( configOptions, fmt.Sprintf( - "\tProxyCommand %s --global-config %s ssh %s--stdio %s", - escapedCoderBinary, escapedGlobalConfig, wgArg, hostname, + "\tProxyCommand %s --global-config %s ssh --stdio %s", + escapedCoderBinary, escapedGlobalConfig, hostname, ), ) } @@ -374,9 +369,6 @@ func configSSH() *cobra.Command { cmd.Flags().BoolVarP(&skipProxyCommand, "skip-proxy-command", "", false, "Specifies whether the ProxyCommand option should be skipped. Useful for testing.") _ = cmd.Flags().MarkHidden("skip-proxy-command") cliflag.BoolVarP(cmd.Flags(), &usePreviousOpts, "use-previous-options", "", "CODER_SSH_USE_PREVIOUS_OPTIONS", false, "Specifies whether or not to keep options from previous run of config-ssh.") - cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_CONFIG_SSH_WIREGUARD", false, "Whether to use Wireguard for SSH tunneling.") - _ = cmd.Flags().MarkHidden("wireguard") - cliui.AllowSkipPrompt(cmd) return cmd diff --git a/cli/configssh_test.go b/cli/configssh_test.go index e91d46c03aaeb..6aa70678e6b26 100644 --- a/cli/configssh_test.go +++ b/cli/configssh_test.go @@ -18,6 +18,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" @@ -106,15 +107,14 @@ func TestConfigSSH(t *testing.T) { agentClient.SessionToken = authToken agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, - WebRTCDialer: agentClient.ListenWorkspaceAgent, - CoordinatorDialer: client.ListenWorkspaceAgentTailnet, + CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, Logger: slogtest.Make(t, nil).Named("agent"), }) defer func() { _ = agentCloser.Close() }() resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) - agentConn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil) + agentConn, err := client.DialWorkspaceAgentTailnet(context.Background(), slog.Logger{}, resources[0].Agents[0].ID) require.NoError(t, err) defer agentConn.Close() diff --git a/cli/gitssh_test.go b/cli/gitssh_test.go index cdbbe6bd1f4ff..0aced5c2889a5 100644 --- a/cli/gitssh_test.go +++ b/cli/gitssh_test.go @@ -12,11 +12,14 @@ import ( "github.com/stretchr/testify/require" gossh "golang.org/x/crypto/ssh" + "cdr.dev/slog" + "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" + "github.com/coder/coder/testutil" ) func TestGitSSH(t *testing.T) { @@ -59,7 +62,7 @@ func TestGitSSH(t *testing.T) { coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) // start workspace agent - cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String(), "--wireguard=false") + cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String()) agentClient := client clitest.SetupConfig(t, agentClient, root) ctx, cancelFunc := context.WithCancel(context.Background()) @@ -72,11 +75,13 @@ func TestGitSSH(t *testing.T) { coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) resources, err := client.WorkspaceResourcesByBuild(context.Background(), workspace.LatestBuild.ID) require.NoError(t, err) - dialer, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil) + dialer, err := client.DialWorkspaceAgentTailnet(context.Background(), slog.Logger{}, resources[0].Agents[0].ID) require.NoError(t, err) defer dialer.Close() - _, err = dialer.Ping() - require.NoError(t, err) + require.Eventually(t, func() bool { + _, err := dialer.Ping() + return err == nil + }, testutil.WaitMedium, testutil.IntervalFast) // start ssh server l, err := net.Listen("tcp", "localhost:0") diff --git a/cli/portforward.go b/cli/portforward.go index 3b78fa8d11a65..7e8ea110c53fc 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -12,11 +12,12 @@ import ( "sync" "syscall" - "cdr.dev/slog" "github.com/pion/udp" "github.com/spf13/cobra" "golang.org/x/xerrors" + "cdr.dev/slog" + "github.com/coder/coder/agent" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" @@ -27,7 +28,6 @@ func portForward() *cobra.Command { tcpForwards []string // : udpForwards []string // : unixForwards []string // : OR : - wireguard bool ) cmd := &cobra.Command{ Use: "port-forward ", @@ -101,12 +101,7 @@ func portForward() *cobra.Command { return xerrors.Errorf("await agent: %w", err) } - var conn agent.Conn - if !wireguard { - conn, err = client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil) - } else { - conn, err = client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID) - } + conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID) if err != nil { return err } @@ -166,12 +161,10 @@ func portForward() *cobra.Command { cmd.Flags().StringArrayVarP(&tcpForwards, "tcp", "p", []string{}, "Forward a TCP port from the workspace to the local machine") cmd.Flags().StringArrayVar(&udpForwards, "udp", []string{}, "Forward a UDP port from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols") cmd.Flags().StringArrayVar(&unixForwards, "unix", []string{}, "Forward a Unix socket in the workspace to a local Unix socket or TCP port") - cmd.Flags().BoolVarP(&wireguard, "wireguard", "", false, "Specifies whether to use wireguard networking or not.") - _ = cmd.Flags().MarkHidden("wireguard") return cmd } -func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn agent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) { +func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *agent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) { _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) var ( diff --git a/cli/portforward_test.go b/cli/portforward_test.go index d98fbec7920c0..b642faea22193 100644 --- a/cli/portforward_test.go +++ b/cli/portforward_test.go @@ -6,8 +6,6 @@ import ( "fmt" "io" "net" - "path/filepath" - "runtime" "strings" "sync" "testing" @@ -110,26 +108,6 @@ func TestPortForward(t *testing.T) { return l.Addr().String(), port }, }, - { - name: "Unix", - network: "unix", - flag: "--unix=%v:%v", - setupRemote: func(t *testing.T) net.Listener { - if runtime.GOOS == "windows" { - t.Skip("Unix socket forwarding isn't supported on Windows") - } - - tmpDir := t.TempDir() - l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock")) - require.NoError(t, err, "create UDP listener") - return l - }, - setupLocal: func(t *testing.T) (string, string) { - tmpDir := t.TempDir() - path := filepath.Join(tmpDir, "test.sock") - return path, path - }, - }, } // Setup agent once to be shared between test-cases (avoid expensive @@ -234,74 +212,16 @@ func TestPortForward(t *testing.T) { }) } - // Test doing a TCP -> Unix forward. - //nolint:paralleltest - t.Run("TCP2Unix", func(t *testing.T) { - var ( - // Find the TCP and Unix cases so we can use their setupLocal and - // setupRemote methods respectively. - tcpCase = cases[0] - unixCase = cases[2] - - // Setup remote Unix listener. - p1 = setupTestListener(t, unixCase.setupRemote(t)) - ) - - // Create a flag that forwards from local TCP to Unix listener 1. - // Notably this is a --unix flag. - localAddress, localFlag := tcpCase.setupLocal(t) - flag := fmt.Sprintf(unixCase.flag, localFlag, p1) - - // Launch port-forward in a goroutine so we can start dialing - // the "local" listener. - cmd, root := clitest.New(t, "port-forward", workspace.Name, flag) - clitest.SetupConfig(t, client, root) - buf := newThreadSafeBuffer() - cmd.SetOut(buf) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - errC := make(chan error) - go func() { - errC <- cmd.ExecuteContext(ctx) - }() - waitForPortForwardReady(t, buf) - - t.Parallel() // Port is reserved, enable parallel execution. - - // Open two connections simultaneously and test them out of - // sync. - d := net.Dialer{Timeout: testutil.WaitShort} - c1, err := d.DialContext(ctx, tcpCase.network, localAddress) - require.NoError(t, err, "open connection 1 to 'local' listener") - defer c1.Close() - c2, err := d.DialContext(ctx, tcpCase.network, localAddress) - require.NoError(t, err, "open connection 2 to 'local' listener") - defer c2.Close() - testDial(t, c2) - testDial(t, c1) - - cancel() - err = <-errC - require.ErrorIs(t, err, context.Canceled) - }) - - // Test doing TCP, UDP and Unix at the same time. + // Test doing TCP and UDP at the same time. //nolint:paralleltest t.Run("All", func(t *testing.T) { var ( - // These aren't fixed size because we exclude Unix on Windows. dials = []addr{} flags = []string{} ) // Start listeners and populate arrays with the cases. for _, c := range cases { - if strings.HasPrefix(c.network, "unix") && runtime.GOOS == "windows" { - // Unix isn't supported on Windows, but we can still - // test other protocols together. - continue - } - p := setupTestListener(t, c.setupRemote(t)) localAddress, localFlag := c.setupLocal(t) @@ -391,7 +311,7 @@ func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) ([]coders coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) // Start workspace agent in a goroutine - cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String(), "--wireguard=false") + cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String()) clitest.SetupConfig(t, client, root) errC := make(chan error) agentCtx, agentCancel := context.WithCancel(ctx) @@ -444,11 +364,9 @@ func setupTestListener(t *testing.T, l net.Listener) string { }() addr := l.Addr().String() - if !strings.HasPrefix(l.Addr().Network(), "unix") { - _, port, err := net.SplitHostPort(addr) - require.NoErrorf(t, err, "split non-Unix listen path %q", addr) - addr = port - } + _, port, err := net.SplitHostPort(addr) + require.NoErrorf(t, err, "split non-Unix listen path %q", addr) + addr = port return addr } diff --git a/cli/server.go b/cli/server.go index f1c654f9ac629..2adc2f557a91a 100644 --- a/cli/server.go +++ b/cli/server.go @@ -28,8 +28,6 @@ import ( embeddedpostgres "github.com/fergusstrange/embedded-postgres" "github.com/google/go-github/v43/github" "github.com/google/uuid" - "github.com/pion/turn/v2" - "github.com/pion/webrtc/v3" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/spf13/afero" @@ -58,7 +56,6 @@ import ( "github.com/coder/coder/coderd/prometheusmetrics" "github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/coderd/tracing" - "github.com/coder/coder/coderd/turnconn" "github.com/coder/coder/codersdk" "github.com/coder/coder/cryptorand" "github.com/coder/coder/provisioner/echo" @@ -111,9 +108,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { tlsEnable bool tlsKeyFile string tlsMinVersion string - turnRelayAddress string tunnel bool - stunServers []string trace bool secureAuthCookie bool sshKeygenAlgorithmRaw string @@ -292,22 +287,6 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { return xerrors.Errorf("parse ssh keygen algorithm %s: %w", sshKeygenAlgorithmRaw, err) } - turnServer, err := turnconn.New(&turn.RelayAddressGeneratorStatic{ - RelayAddress: net.ParseIP(turnRelayAddress), - Address: turnRelayAddress, - }) - if err != nil { - return xerrors.Errorf("create turn server: %w", err) - } - defer turnServer.Close() - - iceServers := make([]webrtc.ICEServer, 0) - for _, stunServer := range stunServers { - iceServers = append(iceServers, webrtc.ICEServer{ - URLs: []string{stunServer}, - }) - } - // Validate provided auto-import templates. var ( validatedAutoImportTemplates = make([]coderd.AutoImportTemplate, len(autoImportTemplates)) @@ -352,7 +331,6 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { options := &coderd.Options{ AccessURL: accessURLParsed, - ICEServers: iceServers, Logger: logger.Named("coderd"), Database: databasefake.New(), DERPMap: derpMap, @@ -361,8 +339,6 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { GoogleTokenValidator: googleTokenValidator, SecureAuthCookie: secureAuthCookie, SSHKeygenAlgorithm: sshKeygenAlgorithm, - TailscaleEnable: tailscaleEnable, - TURNServer: turnServer, TracerProvider: tracerProvider, Telemetry: telemetry.NewNoop(), AutoImportTemplates: validatedAutoImportTemplates, @@ -470,7 +446,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { OIDCAuth: oidcClientID != "", OIDCIssuerURL: oidcIssuerURL, Prometheus: promEnabled, - STUN: len(stunServers) != 0, + STUN: len(derpServerSTUNAddrs) != 0, Tunnel: tunnel, }) if err != nil { @@ -831,12 +807,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { `Specifies the minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13"`) cliflag.BoolVarP(root.Flags(), &tunnel, "tunnel", "", "CODER_TUNNEL", false, "Workspaces must be able to reach the `access-url`. This overrides your access URL with a public access URL that tunnels your Coder deployment.") - cliflag.StringArrayVarP(root.Flags(), &stunServers, "stun-server", "", "CODER_STUN_SERVERS", []string{ - "stun:stun.l.google.com:19302", - }, "Specify URLs for STUN servers to enable P2P connections.") cliflag.BoolVarP(root.Flags(), &trace, "trace", "", "CODER_TRACE", false, "Specifies if application tracing data is collected") - cliflag.StringVarP(root.Flags(), &turnRelayAddress, "turn-relay-address", "", "CODER_TURN_RELAY_ADDRESS", "127.0.0.1", - "Specifies the address to bind TURN connections.") cliflag.BoolVarP(root.Flags(), &secureAuthCookie, "secure-auth-cookie", "", "CODER_SECURE_AUTH_COOKIE", false, "Specifies if the 'Secure' property is set on browser session cookies") cliflag.StringVarP(root.Flags(), &sshKeygenAlgorithmRaw, "ssh-keygen-algorithm", "", "CODER_SSH_KEYGEN_ALGORITHM", "ed25519", "Specifies the algorithm to use for generating ssh keys. "+ `Accepted values are "ed25519", "ecdsa", or "rsa4096"`) diff --git a/cli/ssh.go b/cli/ssh.go index 7f23cce706c20..97f223fdf1d70 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -22,7 +22,6 @@ import ( "cdr.dev/slog" - "github.com/coder/coder/agent" "github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/autobuild/notify" @@ -43,7 +42,6 @@ func ssh() *cobra.Command { forwardAgent bool identityAgent string wsPollInterval time.Duration - wireguard bool ) cmd := &cobra.Command{ Annotations: workspaceCommand, @@ -88,12 +86,7 @@ func ssh() *cobra.Command { return xerrors.Errorf("await agent: %w", err) } - var conn agent.Conn - if !wireguard { - conn, err = client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil) - } else { - conn, err = client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID) - } + conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID) if err != nil { return err } @@ -214,9 +207,6 @@ func ssh() *cobra.Command { cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK") cliflag.StringVarP(cmd.Flags(), &identityAgent, "identity-agent", "", "CODER_SSH_IDENTITY_AGENT", "", "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled") cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.") - cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_SSH_WIREGUARD", false, "Whether to use Wireguard for SSH tunneling.") - _ = cmd.Flags().MarkHidden("wireguard") - return cmd } diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 410c8429705ac..226909bfdc645 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -90,7 +90,6 @@ func TestSSH(t *testing.T) { agentClient.SessionToken = agentToken agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, - WebRTCDialer: agentClient.ListenWorkspaceAgent, CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, Logger: slogtest.Make(t, nil).Named("agent"), }) @@ -112,7 +111,6 @@ func TestSSH(t *testing.T) { agentClient.SessionToken = agentToken agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, - WebRTCDialer: agentClient.ListenWorkspaceAgent, CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, Logger: slogtest.Make(t, nil).Named("agent"), }) @@ -181,7 +179,6 @@ func TestSSH(t *testing.T) { agentClient.SessionToken = agentToken agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, - WebRTCDialer: agentClient.ListenWorkspaceAgent, CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, Logger: slogtest.Make(t, nil).Named("agent"), }) diff --git a/coderd/coderd.go b/coderd/coderd.go index b4042197d1c2f..acfa271f802f6 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -13,7 +13,6 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/klauspost/compress/zstd" - "github.com/pion/webrtc/v3" "github.com/prometheus/client_golang/prometheus" sdktrace "go.opentelemetry.io/otel/sdk/trace" "golang.org/x/xerrors" @@ -34,7 +33,6 @@ import ( "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/coderd/tracing" - "github.com/coder/coder/coderd/turnconn" "github.com/coder/coder/coderd/wsconncache" "github.com/coder/coder/codersdk" "github.com/coder/coder/site" @@ -64,17 +62,14 @@ type Options struct { GithubOAuth2Config *GithubOAuth2Config OIDCConfig *OIDCConfig PrometheusRegistry *prometheus.Registry - ICEServers []webrtc.ICEServer SecureAuthCookie bool SSHKeygenAlgorithm gitsshkey.Algorithm Telemetry telemetry.Reporter - TURNServer *turnconn.Server TracerProvider *sdktrace.TracerProvider AutoImportTemplates []AutoImportTemplate LicenseHandler http.Handler FeaturesService FeaturesService - TailscaleEnable bool TailnetCoordinator *tailnet.Coordinator DERPMap *tailcfg.DERPMap @@ -142,11 +137,7 @@ func New(options *Options) *API { }, metricsCache: metricsCache, } - if options.TailscaleEnable { - api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0) - } else { - api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgent, 0) - } + api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0) api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger)) oauthConfigs := &httpmw.OAuth2Configs{ Github: options.GithubOAuth2Config, @@ -377,14 +368,8 @@ func New(options *Options) *API { r.Use(httpmw.ExtractWorkspaceAgent(options.Database)) r.Get("/metadata", api.workspaceAgentMetadata) r.Post("/version", api.postWorkspaceAgentVersion) - r.Get("/listen", api.workspaceAgentListen) - r.Get("/gitsshkey", api.agentGitSSHKey) - r.Get("/turn", api.workspaceAgentTurn) - r.Get("/iceservers", api.workspaceAgentICEServers) - r.Get("/coordinate", api.workspaceAgentCoordinate) - r.Get("/report-stats", api.workspaceAgentReportStats) }) r.Route("/{workspaceagent}", func(r chi.Router) { @@ -394,11 +379,7 @@ func New(options *Options) *API { httpmw.ExtractWorkspaceParam(options.Database), ) r.Get("/", api.workspaceAgent) - r.Get("/dial", api.workspaceAgentDial) - r.Get("/turn", api.userWorkspaceAgentTurn) r.Get("/pty", api.workspaceAgentPTY) - r.Get("/iceservers", api.workspaceAgentICEServers) - r.Get("/connection", api.workspaceAgentConnection) r.Get("/coordinate", api.workspaceAgentClientCoordinate) }) diff --git a/coderd/coderdtest/authtest.go b/coderd/coderdtest/authtest.go index 354ea0a353e06..444f64a87a185 100644 --- a/coderd/coderdtest/authtest.go +++ b/coderd/coderdtest/authtest.go @@ -187,18 +187,14 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { "GET:/api/v2/users/oidc/callback": {NoAuthorize: true}, // All workspaceagents endpoints do not use rbac - "POST:/api/v2/workspaceagents/aws-instance-identity": {NoAuthorize: true}, - "POST:/api/v2/workspaceagents/azure-instance-identity": {NoAuthorize: true}, - "POST:/api/v2/workspaceagents/google-instance-identity": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/gitsshkey": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/iceservers": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/listen": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/turn": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/coordinate": {NoAuthorize: true}, - "POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/{workspaceagent}/iceservers": {NoAuthorize: true}, + "POST:/api/v2/workspaceagents/aws-instance-identity": {NoAuthorize: true}, + "POST:/api/v2/workspaceagents/azure-instance-identity": {NoAuthorize: true}, + "POST:/api/v2/workspaceagents/google-instance-identity": {NoAuthorize: true}, + "GET:/api/v2/workspaceagents/me/gitsshkey": {NoAuthorize: true}, + "GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true}, + "GET:/api/v2/workspaceagents/me/coordinate": {NoAuthorize: true}, + "POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true}, + "GET:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true}, // These endpoints have more assertions. This is good, add more endpoints to assert if you can! "GET:/api/v2/organizations/{organization}": {AssertObject: rbac.ResourceOrganization.InOrg(a.Admin.OrganizationID)}, @@ -259,14 +255,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { AssertAction: rbac.ActionRead, AssertObject: workspaceRBACObj, }, - "GET:/api/v2/workspaceagents/{workspaceagent}/dial": { - AssertAction: rbac.ActionCreate, - AssertObject: workspaceExecObj, - }, - "GET:/api/v2/workspaceagents/{workspaceagent}/turn": { - AssertAction: rbac.ActionCreate, - AssertObject: workspaceExecObj, - }, "GET:/api/v2/workspaceagents/{workspaceagent}/pty": { AssertAction: rbac.ActionCreate, AssertObject: workspaceExecObj, diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index a129d0667344c..82106a6ba070d 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -50,7 +50,6 @@ import ( "github.com/coder/coder/coderd/gitsshkey" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/coderd/telemetry" - "github.com/coder/coder/coderd/turnconn" "github.com/coder/coder/coderd/util/ptr" "github.com/coder/coder/codersdk" "github.com/coder/coder/cryptorand" @@ -188,12 +187,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519 } - turnServer, err := turnconn.New(nil) - require.NoError(t, err) - t.Cleanup(func() { - _ = turnServer.Close() - }) - // We set the handler after server creation for the access URL. coderAPI := options.APIBuilder(&coderd.Options{ AgentConnectionUpdateFrequency: 150 * time.Millisecond, @@ -212,7 +205,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c OIDCConfig: options.OIDCConfig, GoogleTokenValidator: options.GoogleTokenValidator, SSHKeygenAlgorithm: options.SSHKeygenAlgorithm, - TURNServer: turnServer, APIRateLimit: options.APIRateLimit, Authorizer: options.Authorizer, Telemetry: telemetry.NewNoop(), diff --git a/coderd/templates_test.go b/coderd/templates_test.go index 2147d0fc4d6a2..26bd80d542a43 100644 --- a/coderd/templates_test.go +++ b/coderd/templates_test.go @@ -584,7 +584,6 @@ func TestTemplateDAUs(t *testing.T) { agentCloser := agent.New(agent.Options{ Logger: slogtest.Make(t, nil), StatsReporter: agentClient.AgentReportStats, - WebRTCDialer: agentClient.ListenWorkspaceAgent, FetchMetadata: agentClient.WorkspaceAgentMetadata, CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, }) diff --git a/coderd/turnconn/turnconn.go b/coderd/turnconn/turnconn.go deleted file mode 100644 index b8231146d3cba..0000000000000 --- a/coderd/turnconn/turnconn.go +++ /dev/null @@ -1,203 +0,0 @@ -package turnconn - -import ( - "io" - "net" - "sync" - - "github.com/pion/logging" - "github.com/pion/turn/v2" - "github.com/pion/webrtc/v3" - "golang.org/x/net/proxy" - "golang.org/x/xerrors" -) - -var ( - // reservedAddress is a magic address that's used exclusively - // for proxying via Coder. We don't proxy all TURN connections, - // because that'd exclude the possibility of a customer using - // their own TURN server. - reservedAddress = "127.0.0.1:12345" - credential = "coder" - localhost = &net.TCPAddr{ - IP: net.IPv4(127, 0, 0, 1), - } - - // Proxy is a an ICE Server that uses a special hostname - // to indicate traffic should be proxied. - Proxy = webrtc.ICEServer{ - URLs: []string{"turns:" + reservedAddress}, - Username: "coder", - Credential: credential, - } -) - -// New constructs a new TURN server binding to the relay address provided. -// The relay address is used to broadcast the location of an accepted connection. -func New(relayAddress *turn.RelayAddressGeneratorStatic) (*Server, error) { - if relayAddress == nil { - relayAddress = &turn.RelayAddressGeneratorStatic{ - RelayAddress: localhost.IP, - Address: "127.0.0.1", - } - } - logger := logging.NewDefaultLoggerFactory() - logger.DefaultLogLevel = logging.LogLevelDisabled - server := &Server{ - conns: make(chan net.Conn, 1), - closed: make(chan struct{}), - } - server.listener = &listener{ - srv: server, - } - var err error - server.turn, err = turn.NewServer(turn.ServerConfig{ - AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { - // TURN connections require credentials. It's not important - // for our use-case, because our listener is entirely in-memory. - return turn.GenerateAuthKey(Proxy.Username, "", credential), true - }, - ListenerConfigs: []turn.ListenerConfig{{ - Listener: server.listener, - RelayAddressGenerator: relayAddress, - }}, - LoggerFactory: logger, - }) - if err != nil { - return nil, xerrors.Errorf("create server: %w", err) - } - - return server, nil -} - -// Server accepts and connects TURN allocations. -// -// This is a thin wrapper around pion/turn that pipes -// connections directly to the in-memory handler. -type Server struct { - listener *listener - turn *turn.Server - - closeMutex sync.Mutex - closed chan (struct{}) - conns chan (net.Conn) -} - -// Accept consumes a new connection into the TURN server. -// A unique remote address must exist per-connection. -// pion/turn indexes allocations based on the address. -func (s *Server) Accept(nc net.Conn, remoteAddress, localAddress *net.TCPAddr) *Conn { - if localAddress == nil { - localAddress = localhost - } - conn := &Conn{ - Conn: nc, - remoteAddress: remoteAddress, - localAddress: localAddress, - closed: make(chan struct{}), - } - s.conns <- conn - return conn -} - -// Close ends the TURN server. -func (s *Server) Close() error { - s.closeMutex.Lock() - defer s.closeMutex.Unlock() - if s.isClosed() { - return nil - } - err := s.turn.Close() - close(s.conns) - close(s.closed) - return err -} - -func (s *Server) isClosed() bool { - select { - case <-s.closed: - return true - default: - return false - } -} - -// listener implements net.Listener for the TURN -// server to consume. -type listener struct { - srv *Server -} - -func (l *listener) Accept() (net.Conn, error) { - conn, ok := <-l.srv.conns - if !ok { - return nil, io.EOF - } - return conn, nil -} - -func (*listener) Close() error { - return nil -} - -func (*listener) Addr() net.Addr { - return nil -} - -type Conn struct { - net.Conn - closed chan struct{} - localAddress *net.TCPAddr - remoteAddress *net.TCPAddr -} - -func (c *Conn) LocalAddr() net.Addr { - return c.localAddress -} - -func (c *Conn) RemoteAddr() net.Addr { - return c.remoteAddress -} - -// Closed returns a channel which is closed when -// the connection is. -func (c *Conn) Closed() <-chan struct{} { - return c.closed -} - -func (c *Conn) Close() error { - err := c.Conn.Close() - select { - case <-c.closed: - default: - close(c.closed) - } - return err -} - -type dialer func(network, addr string) (c net.Conn, err error) - -func (d dialer) Dial(network, addr string) (c net.Conn, err error) { - return d(network, addr) -} - -// ProxyDialer accepts a proxy function that's called when the connection -// address matches the reserved host in the "Proxy" ICE server. -// -// This should be passed to WebRTC connections as an ICE dialer. -func ProxyDialer(proxyFunc func() (c net.Conn, err error)) proxy.Dialer { - return dialer(func(network, addr string) (net.Conn, error) { - if addr != reservedAddress { - return proxy.Direct.Dial(network, addr) - } - netConn, err := proxyFunc() - if err != nil { - return nil, err - } - return &Conn{ - localAddress: localhost, - closed: make(chan struct{}), - Conn: netConn, - }, nil - }) -} diff --git a/coderd/turnconn/turnconn_test.go b/coderd/turnconn/turnconn_test.go deleted file mode 100644 index 6a8d0411cb7b1..0000000000000 --- a/coderd/turnconn/turnconn_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package turnconn_test - -import ( - "net" - "sync" - "testing" - - "github.com/pion/webrtc/v3" - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - - "cdr.dev/slog" - "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/coderd/turnconn" - "github.com/coder/coder/peer" -) - -func TestMain(m *testing.M) { - goleak.VerifyTestMain(m) -} - -func TestTURNConn(t *testing.T) { - t.Parallel() - turnServer, err := turnconn.New(nil) - require.NoError(t, err) - defer turnServer.Close() - - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - - clientDialer, clientTURN := net.Pipe() - turnServer.Accept(clientTURN, &net.TCPAddr{ - IP: net.IPv4(127, 0, 0, 1), - Port: 16000, - }, nil) - require.NoError(t, err) - clientSettings := webrtc.SettingEngine{} - clientSettings.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4, webrtc.NetworkTypeTCP6}) - clientSettings.SetRelayAcceptanceMinWait(0) - clientSettings.SetICEProxyDialer(turnconn.ProxyDialer(func() (net.Conn, error) { - return clientDialer, nil - })) - client, err := peer.Client([]webrtc.ICEServer{turnconn.Proxy}, &peer.ConnOptions{ - SettingEngine: clientSettings, - Logger: logger.Named("client"), - }) - require.NoError(t, err) - defer func() { - _ = client.Close() - }() - - serverDialer, serverTURN := net.Pipe() - turnServer.Accept(serverTURN, &net.TCPAddr{ - IP: net.IPv4(127, 0, 0, 1), - Port: 16001, - }, nil) - require.NoError(t, err) - serverSettings := webrtc.SettingEngine{} - serverSettings.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4, webrtc.NetworkTypeTCP6}) - serverSettings.SetRelayAcceptanceMinWait(0) - serverSettings.SetICEProxyDialer(turnconn.ProxyDialer(func() (net.Conn, error) { - return serverDialer, nil - })) - server, err := peer.Server([]webrtc.ICEServer{turnconn.Proxy}, &peer.ConnOptions{ - SettingEngine: serverSettings, - Logger: logger.Named("server"), - }) - require.NoError(t, err) - defer func() { - _ = server.Close() - }() - exchange(t, client, server) - - _, err = client.Ping() - require.NoError(t, err) -} - -func exchange(t *testing.T, client, server *peer.Conn) { - var wg sync.WaitGroup - wg.Add(2) - t.Cleanup(wg.Wait) - go func() { - defer wg.Done() - for { - select { - case c := <-server.LocalCandidate(): - client.AddRemoteCandidate(c) - case c := <-server.LocalSessionDescription(): - client.SetRemoteSessionDescription(c) - case <-server.Closed(): - return - } - } - }() - go func() { - defer wg.Done() - for { - select { - case c := <-client.LocalCandidate(): - server.AddRemoteCandidate(c) - case c := <-client.LocalSessionDescription(): - server.SetRemoteSessionDescription(c) - case <-client.Closed(): - return - } - } - }() -} diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index bb48da0bab4d5..7654a1e0f0ff8 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -15,7 +15,6 @@ import ( "time" "github.com/google/uuid" - "github.com/hashicorp/yamux" "golang.org/x/mod/semver" "golang.org/x/xerrors" "nhooyr.io/websocket" @@ -29,12 +28,7 @@ import ( "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/coderd/tracing" - "github.com/coder/coder/coderd/turnconn" "github.com/coder/coder/codersdk" - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker" - "github.com/coder/coder/peerbroker/proto" - "github.com/coder/coder/provisionersdk" "github.com/coder/coder/tailnet" ) @@ -65,67 +59,6 @@ func (api *API) workspaceAgent(rw http.ResponseWriter, r *http.Request) { httpapi.Write(rw, http.StatusOK, apiAgent) } -func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) { - api.websocketWaitMutex.Lock() - api.websocketWaitGroup.Add(1) - api.websocketWaitMutex.Unlock() - defer api.websocketWaitGroup.Done() - - workspaceAgent := httpmw.WorkspaceAgentParam(r) - workspace := httpmw.WorkspaceParam(r) - if !api.Authorize(r, rbac.ActionCreate, workspace.ExecutionRBAC()) { - httpapi.ResourceNotFound(rw) - return - } - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error reading workspace agent.", - Detail: err.Error(), - }) - return - } - if apiAgent.Status != codersdk.WorkspaceAgentConnected { - httpapi.Write(rw, http.StatusPreconditionFailed, codersdk.Response{ - Message: fmt.Sprintf("Agent isn't connected! Status: %s.", apiAgent.Status), - }) - return - } - - conn, err := websocket.Accept(rw, r, nil) - if err != nil { - httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ - Message: "Failed to accept websocket.", - Detail: err.Error(), - }) - return - } - - ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary) - defer wsNetConn.Close() // Also closes conn. - - config := yamux.DefaultConfig() - config.LogOutput = io.Discard - session, err := yamux.Server(wsNetConn, config) - if err != nil { - _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) - return - } - - // end span so we don't get long lived trace data - tracing.EndHTTPSpan(r, 200) - - err = peerbroker.ProxyListen(ctx, session, peerbroker.ProxyOptions{ - ChannelID: workspaceAgent.ID.String(), - Logger: api.Logger.Named("peerbroker-proxy-dial"), - Pubsub: api.Pubsub, - }) - if err != nil { - _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err)) - return - } -} - func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) { workspaceAgent := httpmw.WorkspaceAgent(r) apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) @@ -185,230 +118,6 @@ func (api *API) postWorkspaceAgentVersion(rw http.ResponseWriter, r *http.Reques httpapi.Write(rw, http.StatusOK, nil) } -func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { - api.websocketWaitMutex.Lock() - api.websocketWaitGroup.Add(1) - api.websocketWaitMutex.Unlock() - defer api.websocketWaitGroup.Done() - - workspaceAgent := httpmw.WorkspaceAgent(r) - resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID) - if err != nil { - httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ - Message: "Failed to accept websocket.", - Detail: err.Error(), - }) - return - } - - build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID) - if err != nil { - httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ - Message: "Internal error fetching workspace build job.", - Detail: err.Error(), - }) - return - } - // Ensure the resource is still valid! - // We only accept agents for resources on the latest build. - ensureLatestBuild := func() error { - latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID) - if err != nil { - return err - } - if build.ID != latestBuild.ID { - return xerrors.New("build is outdated") - } - return nil - } - - err = ensureLatestBuild() - if err != nil { - api.Logger.Debug(r.Context(), "agent tried to connect from non-latest built", - slog.F("resource", resource), - slog.F("agent", workspaceAgent), - ) - httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ - Message: "Agent trying to connect from non-latest build.", - Detail: err.Error(), - }) - return - } - - conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{ - CompressionMode: websocket.CompressionDisabled, - }) - if err != nil { - httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ - Message: "Failed to accept websocket.", - Detail: err.Error(), - }) - return - } - - ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary) - defer wsNetConn.Close() // Also closes conn. - - config := yamux.DefaultConfig() - config.LogOutput = io.Discard - session, err := yamux.Server(wsNetConn, config) - if err != nil { - _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) - return - } - - closer, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)), peerbroker.ProxyOptions{ - ChannelID: workspaceAgent.ID.String(), - Pubsub: api.Pubsub, - Logger: api.Logger.Named("peerbroker-proxy-listen"), - }) - if err != nil { - _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) - return - } - defer closer.Close() - - firstConnectedAt := workspaceAgent.FirstConnectedAt - if !firstConnectedAt.Valid { - firstConnectedAt = sql.NullTime{ - Time: database.Now(), - Valid: true, - } - } - lastConnectedAt := sql.NullTime{ - Time: database.Now(), - Valid: true, - } - disconnectedAt := workspaceAgent.DisconnectedAt - updateConnectionTimes := func() error { - err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{ - ID: workspaceAgent.ID, - FirstConnectedAt: firstConnectedAt, - LastConnectedAt: lastConnectedAt, - DisconnectedAt: disconnectedAt, - UpdatedAt: database.Now(), - }) - if err != nil { - return err - } - return nil - } - - defer func() { - disconnectedAt = sql.NullTime{ - Time: database.Now(), - Valid: true, - } - _ = updateConnectionTimes() - }() - - err = updateConnectionTimes() - if err != nil { - _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) - return - } - - // end span so we don't get long lived trace data - tracing.EndHTTPSpan(r, 200) - - api.Logger.Info(ctx, "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent)) - - ticker := time.NewTicker(api.AgentConnectionUpdateFrequency) - defer ticker.Stop() - for { - select { - case <-session.CloseChan(): - return - case <-ticker.C: - lastConnectedAt = sql.NullTime{ - Time: database.Now(), - Valid: true, - } - err = updateConnectionTimes() - if err != nil { - _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) - return - } - err = ensureLatestBuild() - if err != nil { - // Disconnect agents that are no longer valid. - _ = conn.Close(websocket.StatusGoingAway, "") - return - } - } - } -} - -func (api *API) workspaceAgentICEServers(rw http.ResponseWriter, _ *http.Request) { - httpapi.Write(rw, http.StatusOK, api.ICEServers) -} - -// userWorkspaceAgentTurn is a user connecting to a remote workspace agent -// through turn. -func (api *API) userWorkspaceAgentTurn(rw http.ResponseWriter, r *http.Request) { - workspace := httpmw.WorkspaceParam(r) - if !api.Authorize(r, rbac.ActionCreate, workspace.ExecutionRBAC()) { - httpapi.ResourceNotFound(rw) - return - } - - // Passed authorization - api.workspaceAgentTurn(rw, r) -} - -// workspaceAgentTurn proxies a WebSocket connection to the TURN server. -func (api *API) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) { - api.websocketWaitMutex.Lock() - api.websocketWaitGroup.Add(1) - api.websocketWaitMutex.Unlock() - defer api.websocketWaitGroup.Done() - - localAddress, _ := r.Context().Value(http.LocalAddrContextKey).(*net.TCPAddr) - remoteAddress := &net.TCPAddr{ - IP: net.ParseIP(r.RemoteAddr), - } - // By default requests have the remote address and port. - host, port, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid remote address.", - Detail: err.Error(), - }) - return - } - remoteAddress.IP = net.ParseIP(host) - remoteAddress.Port, err = strconv.Atoi(port) - if err != nil { - httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ - Message: fmt.Sprintf("Port for remote address %q must be an integer.", r.RemoteAddr), - Detail: err.Error(), - }) - return - } - - wsConn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{ - CompressionMode: websocket.CompressionDisabled, - }) - if err != nil { - httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ - Message: "Failed to accept websocket.", - Detail: err.Error(), - }) - return - } - - ctx, wsNetConn := websocketNetConn(r.Context(), wsConn, websocket.MessageBinary) - defer wsNetConn.Close() // Also closes conn. - tracing.EndHTTPSpan(r, 200) // end span so we don't get long lived trace data - - api.Logger.Debug(ctx, "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress)) - select { - case <-api.TURNServer.Accept(wsNetConn, remoteAddress, localAddress).Closed(): - case <-ctx.Done(): - } - api.Logger.Debug(ctx, "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress)) -} - // workspaceAgentPTY spawns a PTY and pipes it over a WebSocket. // This is used for the web terminal. func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { @@ -490,75 +199,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { _, _ = io.Copy(ptNetConn, wsNetConn) } -// dialWorkspaceAgent connects to a workspace agent by ID. Only rely on -// r.Context() for cancellation if it's use is safe or r.Hijack() has -// not been performed. -func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (agent.Conn, error) { - client, server := provisionersdk.TransportPipe() - ctx, cancelFunc := context.WithCancel(context.Background()) - go func() { - _ = peerbroker.ProxyListen(ctx, server, peerbroker.ProxyOptions{ - ChannelID: agentID.String(), - Logger: api.Logger.Named("peerbroker-proxy-dial"), - Pubsub: api.Pubsub, - }) - _ = client.Close() - _ = server.Close() - }() - - peerClient := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) - stream, err := peerClient.NegotiateConnection(ctx) - if err != nil { - cancelFunc() - return nil, xerrors.Errorf("negotiate: %w", err) - } - options := &peer.ConnOptions{ - Logger: api.Logger.Named("agent-dialer"), - } - options.SettingEngine.SetSrflxAcceptanceMinWait(0) - options.SettingEngine.SetRelayAcceptanceMinWait(0) - // Use the ProxyDialer for the TURN server. - // This is required for connections where P2P is not enabled. - options.SettingEngine.SetICEProxyDialer(turnconn.ProxyDialer(func() (c net.Conn, err error) { - clientPipe, serverPipe := net.Pipe() - go func() { - <-ctx.Done() - _ = clientPipe.Close() - _ = serverPipe.Close() - }() - localAddress, _ := r.Context().Value(http.LocalAddrContextKey).(*net.TCPAddr) - remoteAddress := &net.TCPAddr{ - IP: net.ParseIP(r.RemoteAddr), - } - // By default requests have the remote address and port. - host, port, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return nil, xerrors.Errorf("split remote address: %w", err) - } - remoteAddress.IP = net.ParseIP(host) - remoteAddress.Port, err = strconv.Atoi(port) - if err != nil { - return nil, xerrors.Errorf("convert remote port: %w", err) - } - api.TURNServer.Accept(clientPipe, remoteAddress, localAddress) - return serverPipe, nil - })) - peerConn, err := peerbroker.Dial(stream, append(api.ICEServers, turnconn.Proxy), options) - if err != nil { - cancelFunc() - return nil, xerrors.Errorf("dial: %w", err) - } - go func() { - <-peerConn.Closed() - cancelFunc() - }() - return &agent.WebRTCConn{ - Negotiator: peerClient, - Conn: peerConn, - }, nil -} - -func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (agent.Conn, error) { +func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*agent.Conn, error) { clientConn, serverConn := net.Pipe() go func() { <-r.Context().Done() @@ -585,7 +226,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (a _ = conn.Close() } }() - return &agent.TailnetConn{ + return &agent.Conn{ Conn: conn, }, nil } @@ -607,6 +248,48 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request api.websocketWaitMutex.Unlock() defer api.websocketWaitGroup.Done() workspaceAgent := httpmw.WorkspaceAgent(r) + resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID) + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ + Message: "Failed to accept websocket.", + Detail: err.Error(), + }) + return + } + + build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID) + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ + Message: "Internal error fetching workspace build job.", + Detail: err.Error(), + }) + return + } + // Ensure the resource is still valid! + // We only accept agents for resources on the latest build. + ensureLatestBuild := func() error { + latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID) + if err != nil { + return err + } + if build.ID != latestBuild.ID { + return xerrors.New("build is outdated") + } + return nil + } + + err = ensureLatestBuild() + if err != nil { + api.Logger.Debug(r.Context(), "agent tried to connect from non-latest built", + slog.F("resource", resource), + slog.F("agent", workspaceAgent), + ) + httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ + Message: "Agent trying to connect from non-latest build.", + Detail: err.Error(), + }) + return + } conn, err := websocket.Accept(rw, r, nil) if err != nil { @@ -616,12 +299,88 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request }) return } - defer conn.Close(websocket.StatusNormalClosure, "") - err = api.TailnetCoordinator.ServeAgent(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), workspaceAgent.ID) + ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary) + defer wsNetConn.Close() + + firstConnectedAt := workspaceAgent.FirstConnectedAt + if !firstConnectedAt.Valid { + firstConnectedAt = sql.NullTime{ + Time: database.Now(), + Valid: true, + } + } + lastConnectedAt := sql.NullTime{ + Time: database.Now(), + Valid: true, + } + disconnectedAt := workspaceAgent.DisconnectedAt + updateConnectionTimes := func() error { + err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{ + ID: workspaceAgent.ID, + FirstConnectedAt: firstConnectedAt, + LastConnectedAt: lastConnectedAt, + DisconnectedAt: disconnectedAt, + UpdatedAt: database.Now(), + }) + if err != nil { + return err + } + return nil + } + + defer func() { + disconnectedAt = sql.NullTime{ + Time: database.Now(), + Valid: true, + } + _ = updateConnectionTimes() + }() + + err = updateConnectionTimes() if err != nil { - _ = conn.Close(websocket.StatusInternalError, err.Error()) + _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) return } + + // end span so we don't get long lived trace data + tracing.EndHTTPSpan(r, 200) + api.Logger.Info(ctx, "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent)) + + defer conn.Close(websocket.StatusNormalClosure, "") + + closeChan := make(chan struct{}) + go func() { + defer close(closeChan) + err = api.TailnetCoordinator.ServeAgent(wsNetConn, workspaceAgent.ID) + if err != nil { + _ = conn.Close(websocket.StatusInternalError, err.Error()) + return + } + }() + ticker := time.NewTicker(api.AgentConnectionUpdateFrequency) + defer ticker.Stop() + for { + select { + case <-closeChan: + return + case <-ticker.C: + } + lastConnectedAt = sql.NullTime{ + Time: database.Now(), + Valid: true, + } + err = updateConnectionTimes() + if err != nil { + _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) + return + } + err = ensureLatestBuild() + if err != nil { + // Disconnect agents that are no longer valid. + _ = conn.Close(websocket.StatusGoingAway, "") + return + } + } } // workspaceAgentClientCoordinate accepts a WebSocket that reads node network updates. diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 38d27e26e799a..41161fe397e29 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -10,7 +10,6 @@ import ( "time" "github.com/google/uuid" - "github.com/pion/webrtc/v3" "github.com/stretchr/testify/require" "cdr.dev/slog" @@ -18,7 +17,6 @@ import ( "github.com/coder/coder/agent" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" - "github.com/coder/coder/peer" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" "github.com/coder/coder/testutil" @@ -112,7 +110,6 @@ func TestWorkspaceAgentListen(t *testing.T) { agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, - WebRTCDialer: agentClient.ListenWorkspaceAgent, Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug), }) defer func() { @@ -123,13 +120,15 @@ func TestWorkspaceAgentListen(t *testing.T) { defer cancel() resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) - conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) + conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID) require.NoError(t, err) defer func() { _ = conn.Close() }() - _, err = conn.Ping() - require.NoError(t, err) + require.Eventually(t, func() bool { + _, err := conn.Ping() + return err == nil + }, testutil.WaitMedium, testutil.IntervalFast) }) t.Run("FailNonLatestBuild", func(t *testing.T) { @@ -202,7 +201,7 @@ func TestWorkspaceAgentListen(t *testing.T) { agentClient := codersdk.New(client.URL) agentClient.SessionToken = authToken - _, err = agentClient.ListenWorkspaceAgent(ctx, slogtest.Make(t, nil)) + _, err = agentClient.ListenWorkspaceAgentTailnet(ctx) require.Error(t, err) require.ErrorContains(t, err, "build is outdated") }) @@ -246,7 +245,6 @@ func TestWorkspaceAgentTURN(t *testing.T) { agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, - WebRTCDialer: agentClient.ListenWorkspaceAgent, Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug), }) defer func() { @@ -257,18 +255,15 @@ func TestWorkspaceAgentTURN(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - opts := &peer.ConnOptions{ - Logger: slogtest.Make(t, nil).Named("client"), - } - // Force a TURN connection! - opts.SettingEngine.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4}) - conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, opts) + conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID) require.NoError(t, err) defer func() { _ = conn.Close() }() - _, err = conn.Ping() - require.NoError(t, err) + require.Eventually(t, func() bool { + _, err := conn.Ping() + return err == nil + }, testutil.WaitMedium, testutil.IntervalFast) } func TestWorkspaceAgentTailnet(t *testing.T) { @@ -306,7 +301,6 @@ func TestWorkspaceAgentTailnet(t *testing.T) { agentClient.SessionToken = authToken agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, - WebRTCDialer: agentClient.ListenWorkspaceAgent, CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug), }) @@ -373,7 +367,6 @@ func TestWorkspaceAgentPTY(t *testing.T) { agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, - WebRTCDialer: agentClient.ListenWorkspaceAgent, Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug), }) defer func() { diff --git a/coderd/workspaceapps_test.go b/coderd/workspaceapps_test.go index 7e67f057bb354..010927ccdfc4e 100644 --- a/coderd/workspaceapps_test.go +++ b/coderd/workspaceapps_test.go @@ -84,7 +84,6 @@ func TestWorkspaceAppsProxyPath(t *testing.T) { agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, - WebRTCDialer: agentClient.ListenWorkspaceAgent, Logger: slogtest.Make(t, nil).Named("agent"), }) t.Cleanup(func() { diff --git a/coderd/wsconncache/wsconncache.go b/coderd/wsconncache/wsconncache.go index 698f467a40790..7d3b741a63b7e 100644 --- a/coderd/wsconncache/wsconncache.go +++ b/coderd/wsconncache/wsconncache.go @@ -32,11 +32,11 @@ func New(dialer Dialer, inactiveTimeout time.Duration) *Cache { } // Dialer creates a new agent connection by ID. -type Dialer func(r *http.Request, id uuid.UUID) (agent.Conn, error) +type Dialer func(r *http.Request, id uuid.UUID) (*agent.Conn, error) // Conn wraps an agent connection with a reusable HTTP transport. type Conn struct { - agent.Conn + *agent.Conn locks atomic.Uint64 timeoutMutex sync.Mutex diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go index 80f187ba15ab7..a9ea85a2492ac 100644 --- a/coderd/wsconncache/wsconncache_test.go +++ b/coderd/wsconncache/wsconncache_test.go @@ -35,7 +35,7 @@ func TestCache(t *testing.T) { t.Parallel() t.Run("Same", func(t *testing.T) { t.Parallel() - cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) { + cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) { return setupAgent(t, agent.Metadata{}, 0), nil }, 0) defer func() { @@ -50,7 +50,7 @@ func TestCache(t *testing.T) { t.Run("Expire", func(t *testing.T) { t.Parallel() called := atomic.NewInt32(0) - cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) { + cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) { called.Add(1) return setupAgent(t, agent.Metadata{}, 0), nil }, time.Microsecond) @@ -69,7 +69,7 @@ func TestCache(t *testing.T) { }) t.Run("NoExpireWhenLocked", func(t *testing.T) { t.Parallel() - cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) { + cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) { return setupAgent(t, agent.Metadata{}, 0), nil }, time.Microsecond) defer func() { @@ -102,7 +102,7 @@ func TestCache(t *testing.T) { }() go server.Serve(random) - cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) { + cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) { return setupAgent(t, agent.Metadata{}, 0), nil }, time.Microsecond) defer func() { @@ -139,7 +139,7 @@ func TestCache(t *testing.T) { }) } -func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) agent.Conn { +func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) *agent.Conn { metadata.DERPMap = tailnettest.RunDERPAndSTUN(t) coordinator := tailnet.NewCoordinator() @@ -180,7 +180,7 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) return conn.UpdateNodes(node) }) conn.SetNodeCallback(sendNode) - return &agent.TailnetConn{ + return &agent.Conn{ Conn: conn, } } diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 021c4bcf77865..527a5a07856c0 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -14,9 +14,6 @@ import ( "cloud.google.com/go/compute/metadata" "github.com/google/uuid" - "github.com/hashicorp/yamux" - "github.com/pion/webrtc/v3" - "golang.org/x/net/proxy" "golang.org/x/xerrors" "nhooyr.io/websocket" "nhooyr.io/websocket/wsjson" @@ -25,11 +22,6 @@ import ( "cdr.dev/slog" "github.com/coder/coder/agent" - "github.com/coder/coder/coderd/turnconn" - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker" - "github.com/coder/coder/peerbroker/proto" - "github.com/coder/coder/provisionersdk" "github.com/coder/coder/tailnet" "github.com/coder/retry" ) @@ -206,69 +198,6 @@ func (c *Client) WorkspaceAgentMetadata(ctx context.Context) (agent.Metadata, er return agentMetadata, json.NewDecoder(res.Body).Decode(&agentMetadata) } -// ListenWorkspaceAgent connects as a workspace agent identifying with the session token. -// On each inbound connection request, connection info is fetched. -func (c *Client) ListenWorkspaceAgent(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) { - serverURL, err := c.URL.Parse("/api/v2/workspaceagents/me/listen") - if err != nil { - return nil, xerrors.Errorf("parse url: %w", err) - } - jar, err := cookiejar.New(nil) - if err != nil { - return nil, xerrors.Errorf("create cookie jar: %w", err) - } - jar.SetCookies(serverURL, []*http.Cookie{{ - Name: SessionTokenKey, - Value: c.SessionToken, - }}) - httpClient := &http.Client{ - Jar: jar, - } - conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{ - HTTPClient: httpClient, - // Need to disable compression to avoid a data-race. - CompressionMode: websocket.CompressionDisabled, - }) - if err != nil { - if res == nil { - return nil, err - } - return nil, readBodyAsError(res) - } - config := yamux.DefaultConfig() - config.LogOutput = io.Discard - session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config) - if err != nil { - return nil, xerrors.Errorf("multiplex client: %w", err) - } - return peerbroker.Listen(session, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) { - // This can be cached if it adds to latency too much. - res, err := c.Request(ctx, http.MethodGet, "/api/v2/workspaceagents/me/iceservers", nil) - if err != nil { - return nil, nil, err - } - defer res.Body.Close() - if res.StatusCode != http.StatusOK { - return nil, nil, readBodyAsError(res) - } - var iceServers []webrtc.ICEServer - err = json.NewDecoder(res.Body).Decode(&iceServers) - if err != nil { - return nil, nil, err - } - - options := webrtc.SettingEngine{} - options.SetSrflxAcceptanceMinWait(0) - options.SetRelayAcceptanceMinWait(0) - options.SetICEProxyDialer(c.turnProxyDialer(ctx, httpClient, "/api/v2/workspaceagents/me/turn")) - iceServers = append(iceServers, turnconn.Proxy) - return iceServers, &peer.ConnOptions{ - SettingEngine: options, - Logger: logger, - }, nil - }) -} - func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, error) { coordinateURL, err := c.URL.Parse("/api/v2/workspaceagents/me/coordinate") if err != nil { @@ -286,17 +215,20 @@ func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, err Jar: jar, } // nolint:bodyclose - conn, _, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{ + conn, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{ HTTPClient: httpClient, }) if err != nil { - return nil, err + if res == nil { + return nil, err + } + return nil, readBodyAsError(res) } return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil } -func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (agent.Conn, error) { +func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (*agent.Conn, error) { res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/connection", agentID), nil) if err != nil { return nil, err @@ -370,7 +302,7 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg } } }() - return &agent.TailnetConn{ + return &agent.Conn{ Conn: conn, CloseFunc: func() { cancelFunc() @@ -379,78 +311,6 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg }, nil } -// DialWorkspaceAgent creates a connection to the specified resource. -func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *peer.ConnOptions) (agent.Conn, error) { - serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/dial", agentID.String())) - if err != nil { - return nil, xerrors.Errorf("parse url: %w", err) - } - jar, err := cookiejar.New(nil) - if err != nil { - return nil, xerrors.Errorf("create cookie jar: %w", err) - } - jar.SetCookies(serverURL, []*http.Cookie{{ - Name: SessionTokenKey, - Value: c.SessionToken, - }}) - httpClient := &http.Client{ - Jar: jar, - } - conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{ - HTTPClient: httpClient, - // Need to disable compression to avoid a data-race. - CompressionMode: websocket.CompressionDisabled, - }) - if err != nil { - if res == nil { - return nil, err - } - return nil, readBodyAsError(res) - } - config := yamux.DefaultConfig() - config.LogOutput = io.Discard - session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config) - if err != nil { - return nil, xerrors.Errorf("multiplex client: %w", err) - } - client := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)) - stream, err := client.NegotiateConnection(ctx) - if err != nil { - return nil, xerrors.Errorf("negotiate connection: %w", err) - } - - res, err = c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/iceservers", agentID.String()), nil) - if err != nil { - return nil, err - } - defer res.Body.Close() - if res.StatusCode != http.StatusOK { - return nil, readBodyAsError(res) - } - var iceServers []webrtc.ICEServer - err = json.NewDecoder(res.Body).Decode(&iceServers) - if err != nil { - return nil, err - } - - if options == nil { - options = &peer.ConnOptions{} - } - options.SettingEngine.SetSrflxAcceptanceMinWait(0) - options.SettingEngine.SetRelayAcceptanceMinWait(0) - options.SettingEngine.SetICEProxyDialer(c.turnProxyDialer(ctx, httpClient, fmt.Sprintf("/api/v2/workspaceagents/%s/turn", agentID.String()))) - iceServers = append(iceServers, turnconn.Proxy) - - peerConn, err := peerbroker.Dial(stream, iceServers, options) - if err != nil { - return nil, xerrors.Errorf("dial peer: %w", err) - } - return &agent.WebRTCConn{ - Negotiator: client, - Conn: peerConn, - }, nil -} - // WorkspaceAgent returns an agent by ID. func (c *Client) WorkspaceAgent(ctx context.Context, id uuid.UUID) (WorkspaceAgent, error) { res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s", id), nil) @@ -509,27 +369,6 @@ func (c *Client) WorkspaceAgentReconnectingPTY(ctx context.Context, agentID, rec return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil } -func (c *Client) turnProxyDialer(ctx context.Context, httpClient *http.Client, path string) proxy.Dialer { - return turnconn.ProxyDialer(func() (net.Conn, error) { - turnURL, err := c.URL.Parse(path) - if err != nil { - return nil, xerrors.Errorf("parse url: %w", err) - } - conn, res, err := websocket.Dial(ctx, turnURL.String(), &websocket.DialOptions{ - HTTPClient: httpClient, - // Need to disable compression to avoid a data-race. - CompressionMode: websocket.CompressionDisabled, - }) - if err != nil { - if res == nil { - return nil, err - } - return nil, readBodyAsError(res) - } - return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil - }) -} - // AgentReportStats begins a stat streaming connection with the Coder server. // It is resilient to network failures and intermittent coderd issues. func (c *Client) AgentReportStats( diff --git a/go.mod b/go.mod index 57d68cccaab7f..e86961cfd3209 100644 --- a/go.mod +++ b/go.mod @@ -41,7 +41,7 @@ replace golang.zx2c4.com/wireguard/tun/netstack => github.com/coder/wireguard-go // https://github.com/pion/udp/pull/73 replace github.com/pion/udp => github.com/mafredri/udp v0.1.2-0.20220805105907-b2872e92e98d -// https://github.com/hashicorp/hc-install/pull/68 +// https://github.com/hashicorp/hc-dinstall/pull/68 replace github.com/hashicorp/hc-install => github.com/mafredri/hc-install v0.4.1-0.20220727132613-e91868e28445 // https://github.com/tcnksm/go-httpstat/pull/29 @@ -49,7 +49,7 @@ replace github.com/tcnksm/go-httpstat => github.com/kylecarbs/go-httpstat v0.0.0 // There are a few minor changes we make to Tailscale that we're slowly upstreaming. Compare here: // https://github.com/tailscale/tailscale/compare/main...coder:tailscale:main -replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20220902164407-ae46caa65076 +replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20220905215011-3b49f627e66c require ( cdr.dev/slog v1.4.2-0.20220525200111-18dce5c2cd5f @@ -109,12 +109,7 @@ require ( github.com/nhatthm/otelsql v0.4.0 github.com/open-policy-agent/opa v0.41.0 github.com/ory/dockertest/v3 v3.9.1 - github.com/pion/datachannel v1.5.2 - github.com/pion/logging v0.2.2 - github.com/pion/transport v0.13.1 - github.com/pion/turn/v2 v2.0.8 github.com/pion/udp v0.1.1 - github.com/pion/webrtc/v3 v3.1.43 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e github.com/pkg/sftp v1.13.5 @@ -138,7 +133,6 @@ require ( golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167 golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 - golang.org/x/net v0.0.0-20220630215102-69896b714898 golang.org/x/oauth2 v0.0.0-20220622183110-fd043fe589d2 golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 @@ -160,6 +154,11 @@ require ( tailscale.com v1.26.2 ) +require ( + github.com/pion/transport v0.13.1 // indirect + golang.org/x/net v0.0.0-20220630215102-69896b714898 // indirect +) + require ( filippo.io/edwards25519 v1.0.0-rc.1 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect @@ -247,17 +246,6 @@ require ( github.com/opencontainers/image-spec v1.0.3-0.20220114050600-8b9d41f48198 // indirect github.com/opencontainers/runc v1.1.2 // indirect github.com/pelletier/go-toml/v2 v2.0.2 // indirect - github.com/pion/dtls/v2 v2.1.5 // indirect - github.com/pion/ice/v2 v2.2.6 // indirect - github.com/pion/interceptor v0.1.11 // indirect - github.com/pion/mdns v0.0.5 // indirect - github.com/pion/randutil v0.1.0 // indirect - github.com/pion/rtcp v1.2.9 // indirect - github.com/pion/rtp v1.7.13 // indirect - github.com/pion/sctp v1.8.2 // indirect - github.com/pion/sdp/v3 v3.0.5 // indirect - github.com/pion/srtp/v2 v2.0.10 // indirect - github.com/pion/stun v0.3.5 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.2.0 // indirect diff --git a/go.sum b/go.sum index d1915643f56ac..589f57c044003 100644 --- a/go.sum +++ b/go.sum @@ -352,8 +352,8 @@ github.com/coder/glog v1.0.1-0.20220322161911-7365fe7f2cd1 h1:UqBrPWSYvRI2s5RtOu github.com/coder/glog v1.0.1-0.20220322161911-7365fe7f2cd1/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= github.com/coder/retry v1.3.0 h1:5lAAwt/2Cm6lVmnfBY7sOMXcBOwcwJhmV5QGSELIVWY= github.com/coder/retry v1.3.0/go.mod h1:tXuRgZgWjUnU5LZPT4lJh4ew2elUhexhlnXzrJWdyFY= -github.com/coder/tailscale v1.1.1-0.20220902164407-ae46caa65076 h1:PITEtBolloXfTMGSkL1hQSPBMT4+YJFUgjRQl5osB5k= -github.com/coder/tailscale v1.1.1-0.20220902164407-ae46caa65076/go.mod h1:MO+tWkQp2YIF3KBnnej/mQvgYccRS5Xk/IrEpZ4Z3BU= +github.com/coder/tailscale v1.1.1-0.20220905215011-3b49f627e66c h1:OJkiwMlJtLNUbuiRkzlhzKjRTfNW2B3mTXkl7KeOP+k= +github.com/coder/tailscale v1.1.1-0.20220905215011-3b49f627e66c/go.mod h1:MO+tWkQp2YIF3KBnnej/mQvgYccRS5Xk/IrEpZ4Z3BU= github.com/coder/wireguard-go/tun/netstack v0.0.0-20220823170024-a78136eb0cab h1:9yEvRWXXfyKzXu8AqywCi+tFZAoqCy4wVcsXwuvZNMc= github.com/coder/wireguard-go/tun/netstack v0.0.0-20220823170024-a78136eb0cab/go.mod h1:TCJ66NtXh3urJotTdoYQOHHkyE899vOQl5TuF+WLSes= github.com/containerd/aufs v0.0.0-20200908144142-dab0cbea06f4/go.mod h1:nukgQABAEopAHvB6j7cnP5zJ+/3aVcE7hCYqvIwAHyE= @@ -1451,7 +1451,6 @@ github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108 github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0= github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= -github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/ginkgo/v2 v2.1.3/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c= github.com/onsi/gomega v0.0.0-20151007035656-2152b45fa28a/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA= github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA= @@ -1526,43 +1525,9 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2 github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI= github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= -github.com/pion/datachannel v1.5.2 h1:piB93s8LGmbECrpO84DnkIVWasRMk3IimbcXkTQLE6E= -github.com/pion/datachannel v1.5.2/go.mod h1:FTGQWaHrdCwIJ1rw6xBIfZVkslikjShim5yr05XFuCQ= -github.com/pion/dtls/v2 v2.1.3/go.mod h1:o6+WvyLDAlXF7YiPB/RlskRoeK+/JtuaZa5emwQcWus= -github.com/pion/dtls/v2 v2.1.5 h1:jlh2vtIyUBShchoTDqpCCqiYCyRFJ/lvf/gQ8TALs+c= -github.com/pion/dtls/v2 v2.1.5/go.mod h1:BqCE7xPZbPSubGasRoDFJeTsyJtdD1FanJYL0JGheqY= -github.com/pion/ice/v2 v2.2.6 h1:R/vaLlI1J2gCx141L5PEwtuGAGcyS6e7E0hDeJFq5Ig= -github.com/pion/ice/v2 v2.2.6/go.mod h1:SWuHiOGP17lGromHTFadUe1EuPgFh/oCU6FCMZHooVE= -github.com/pion/interceptor v0.1.11 h1:00U6OlqxA3FFB50HSg25J/8cWi7P6FbSzw4eFn24Bvs= -github.com/pion/interceptor v0.1.11/go.mod h1:tbtKjZY14awXd7Bq0mmWvgtHB5MDaRN7HV3OZ/uy7s8= -github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= -github.com/pion/mdns v0.0.5 h1:Q2oj/JB3NqfzY9xGZ1fPzZzK7sDSD8rZPOvcIQ10BCw= -github.com/pion/mdns v0.0.5/go.mod h1:UgssrvdD3mxpi8tMxAXbsppL3vJ4Jipw1mTCW+al01g= -github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= -github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= -github.com/pion/rtcp v1.2.9 h1:1ujStwg++IOLIEoOiIQ2s+qBuJ1VN81KW+9pMPsif+U= -github.com/pion/rtcp v1.2.9/go.mod h1:qVPhiCzAm4D/rxb6XzKeyZiQK69yJpbUDJSF7TgrqNo= -github.com/pion/rtp v1.7.13 h1:qcHwlmtiI50t1XivvoawdCGTP4Uiypzfrsap+bijcoA= -github.com/pion/rtp v1.7.13/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko= -github.com/pion/sctp v1.8.0/go.mod h1:xFe9cLMZ5Vj6eOzpyiKjT9SwGM4KpK/8Jbw5//jc+0s= -github.com/pion/sctp v1.8.2 h1:yBBCIrUMJ4yFICL3RIvR4eh/H2BTTvlligmSTy+3kiA= -github.com/pion/sctp v1.8.2/go.mod h1:xFe9cLMZ5Vj6eOzpyiKjT9SwGM4KpK/8Jbw5//jc+0s= -github.com/pion/sdp/v3 v3.0.5 h1:ouvI7IgGl+V4CrqskVtr3AaTrPvPisEOxwgpdktctkU= -github.com/pion/sdp/v3 v3.0.5/go.mod h1:iiFWFpQO8Fy3S5ldclBkpXqmWy02ns78NOKoLLL0YQw= -github.com/pion/srtp/v2 v2.0.10 h1:b8ZvEuI+mrL8hbr/f1YiJFB34UMrOac3R3N1yq2UN0w= -github.com/pion/srtp/v2 v2.0.10/go.mod h1:XEeSWaK9PfuMs7zxXyiN252AHPbH12NX5q/CFDWtUuA= -github.com/pion/stun v0.3.5 h1:uLUCBCkQby4S1cf6CGuR9QrVOKcvUwFeemaC865QHDg= -github.com/pion/stun v0.3.5/go.mod h1:gDMim+47EeEtfWogA37n6qXZS88L5V6LqFcf+DZA2UA= -github.com/pion/transport v0.12.2/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q= -github.com/pion/transport v0.12.3/go.mod h1:OViWW9SP2peE/HbwBvARicmAVnesphkNkCVZIWJ6q9A= -github.com/pion/transport v0.13.0/go.mod h1:yxm9uXpK9bpBBWkITk13cLo1y5/ur5VQpG22ny6EP7g= github.com/pion/transport v0.13.1 h1:/UH5yLeQtwm2VZIPjxwnNFxjS4DFhyLfS4GlfuKUzfA= github.com/pion/transport v0.13.1/go.mod h1:EBxbqzyv+ZrmDb82XswEE0BjfQFtuw1Nu6sjnjWCsGg= -github.com/pion/turn/v2 v2.0.8 h1:KEstL92OUN3k5k8qxsXHpr7WWfrdp7iJZHx99ud8muw= -github.com/pion/turn/v2 v2.0.8/go.mod h1:+y7xl719J8bAEVpSXBXvTxStjJv3hbz9YFflvkpcGPw= -github.com/pion/webrtc/v3 v3.1.43 h1:YT3ZTO94UT4kSBvZnRAH82+0jJPUruiKr9CEstdlQzk= -github.com/pion/webrtc/v3 v3.1.43/go.mod h1:G/J8k0+grVsjC/rjCZ24AKoCCxcFFODgh7zThNZGs0M= github.com/pkg/browser v0.0.0-20210706143420-7d21f8c997e2/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= @@ -2034,9 +1999,6 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220131195533-30dcbda58838/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220516162934-403b01795ae8/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167 h1:O8uGbHCqlTp2P6QJSLmCojM4mN6UemYv8K+dCnmHmu0= golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -2148,7 +2110,6 @@ golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20201201195509-5d6afe98e0b7/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= @@ -2172,7 +2133,6 @@ golang.org/x/net v0.0.0-20210903162142-ad29c8ab022f/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20210928044308-7d9f5e0b762b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20211201190559-0a0e4e1bb54c/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220107192237-5cfca573fb4d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -2180,7 +2140,6 @@ golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220325170049-de3da57026de/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.0.0-20220401154927-543a649e0bdd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220412020605-290c469a71a5/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220531201128-c960675eff93/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -2383,7 +2342,6 @@ golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220608164250-635b8c9b7f68/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220610221304-9f5ed59c137d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220624220833-87e55d714810/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/peer/channel.go b/peer/channel.go deleted file mode 100644 index 4945e5d465683..0000000000000 --- a/peer/channel.go +++ /dev/null @@ -1,317 +0,0 @@ -package peer - -import ( - "bufio" - "context" - "io" - "net" - "sync" - - "github.com/pion/datachannel" - "github.com/pion/webrtc/v3" - "golang.org/x/xerrors" - - "cdr.dev/slog" -) - -const ( - bufferedAmountLowThreshold uint64 = 512 * 1024 // 512 KB - maxBufferedAmount uint64 = 1024 * 1024 // 1 MB - // For some reason messages larger just don't work... - // This shouldn't be a huge deal for real-world usage. - // See: https://github.com/pion/datachannel/issues/59 - maxMessageLength = 64 * 1024 // 64 KB -) - -// newChannel creates a new channel and initializes it. -// The initialization overrides listener handles, and detaches -// the channel on open. The datachannel should not be manually -// mutated after being passed to this function. -func newChannel(conn *Conn, dc *webrtc.DataChannel, opts *ChannelOptions) *Channel { - channel := &Channel{ - opts: opts, - conn: conn, - dc: dc, - - opened: make(chan struct{}), - closed: make(chan struct{}), - sendMore: make(chan struct{}, 1), - } - channel.init() - return channel -} - -type ChannelOptions struct { - // ID is a channel ID that should be used when `Negotiated` - // is true. - ID uint16 - - // Negotiated returns whether the data channel will already - // be active on the other end. Defaults to false. - Negotiated bool - - // Arbitrary string that can be parsed on `Accept`. - Protocol string - - // Unordered determines whether the channel acts like - // a UDP connection. Defaults to false. - Unordered bool - - // Whether the channel will be left open on disconnect or not. - // If true, data will be buffered on either end to be sent - // once reconnected. Defaults to false. - OpenOnDisconnect bool -} - -// Channel represents a WebRTC DataChannel. -// -// This struct wraps webrtc.DataChannel to add concurrent-safe usage, -// data bufferring, and standardized errors for connection state. -// -// It modifies the default behavior of a DataChannel by closing on -// WebRTC PeerConnection failure. This is done to emulate TCP connections. -// This option can be changed in the options when creating a Channel. -type Channel struct { - opts *ChannelOptions - - conn *Conn - dc *webrtc.DataChannel - // This field can be nil. It becomes set after the DataChannel - // has been opened and is detached. - rwc datachannel.ReadWriteCloser - reader io.Reader - - closed chan struct{} - closeMutex sync.Mutex - closeError error - - opened chan struct{} - - // sendMore is used to block Write operations on a full buffer. - // It's signaled when the buffer can accept more data. - sendMore chan struct{} - writeMutex sync.Mutex -} - -// init attaches listeners to the DataChannel to detect opening, -// closing, and when the channel is ready to transmit data. -// -// This should only be called once on creation. -func (c *Channel) init() { - // WebRTC connections maintain an internal buffer that can fill when: - // 1. Data is being sent faster than it can flush. - // 2. The connection is disconnected, but data is still being sent. - // - // This applies a maximum in-memory buffer for data, and will cause - // write operations to block once the threshold is set. - c.dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold) - c.dc.OnBufferedAmountLow(func() { - // Grab the lock to protect the sendMore channel from being - // closed in between the isClosed check and the send. - c.closeMutex.Lock() - defer c.closeMutex.Unlock() - if c.isClosed() { - return - } - select { - case <-c.closed: - case c.sendMore <- struct{}{}: - default: - } - }) - c.dc.OnClose(func() { - c.conn.logger().Debug(context.Background(), "datachannel closing from OnClose", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label())) - _ = c.closeWithError(ErrClosed) - }) - c.dc.OnOpen(func() { - c.closeMutex.Lock() - c.conn.logger().Debug(context.Background(), "datachannel opening", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label())) - var err error - c.rwc, err = c.dc.Detach() - if err != nil { - c.closeMutex.Unlock() - _ = c.closeWithError(xerrors.Errorf("detach: %w", err)) - return - } - c.closeMutex.Unlock() - - // pion/webrtc will return an io.ErrShortBuffer when a read - // is triggered with a buffer size less than the chunks written. - // - // This makes sense when considering UDP connections, because - // buffering of data that has no transmit guarantees is likely - // to cause unexpected behavior. - // - // When ordered, this adds a bufio.Reader. This ensures additional - // data on TCP-like connections can be read in parts, while still - // being buffered. - if c.opts.Unordered { - c.reader = c.rwc - } else { - // This must be the max message length otherwise a short - // buffer error can occur. - c.reader = bufio.NewReaderSize(c.rwc, maxMessageLength) - } - close(c.opened) - }) - - c.conn.dcDisconnectListeners.Add(1) - c.conn.dcFailedListeners.Add(1) - c.conn.dcClosedWaitGroup.Add(1) - go func() { - var err error - // A DataChannel can disconnect multiple times, so this needs to loop. - for { - select { - case <-c.conn.closedRTC: - // If this channel was closed, there's no need to close again. - err = c.conn.closeError - case <-c.conn.Closed(): - // If the RTC connection closed with an error, this channel - // should end with the same one. - err = c.conn.closeError - case <-c.conn.dcDisconnectChannel: - // If the RTC connection is disconnected, we need to check if - // the DataChannel is supposed to end on disconnect. - if c.opts.OpenOnDisconnect { - continue - } - err = xerrors.Errorf("rtc disconnected. closing: %w", ErrClosed) - case <-c.conn.dcFailedChannel: - // If the RTC connection failed, close the Channel. - err = ErrFailed - } - if err != nil { - break - } - } - _ = c.closeWithError(err) - }() -} - -// Read blocks until data is received. -// -// This will block until the underlying DataChannel has been opened. -func (c *Channel) Read(bytes []byte) (int, error) { - err := c.waitOpened() - if err != nil { - return 0, err - } - - bytesRead, err := c.reader.Read(bytes) - if err != nil { - if c.isClosed() { - return 0, c.closeError - } - // An EOF always occurs when the connection is closed. - // Alternative close errors will occur first if an unexpected - // close has occurred. - if xerrors.Is(err, io.EOF) { - err = c.closeWithError(ErrClosed) - } - } - return bytesRead, err -} - -// Write sends data to the underlying DataChannel. -// -// This function will block if too much data is being sent. -// Data will buffer if the connection is temporarily disconnected, -// and will be flushed upon reconnection. -// -// If the Channel is setup to close on disconnect, any buffered -// data will be lost. -func (c *Channel) Write(bytes []byte) (n int, err error) { - if len(bytes) > maxMessageLength { - return 0, xerrors.Errorf("outbound packet larger than maximum message size: %d", maxMessageLength) - } - - c.writeMutex.Lock() - defer c.writeMutex.Unlock() - - err = c.waitOpened() - if err != nil { - return 0, err - } - if c.dc.BufferedAmount()+uint64(len(bytes)) >= maxBufferedAmount { - <-c.sendMore - } - - return c.rwc.Write(bytes) -} - -// Close gracefully closes the DataChannel. -func (c *Channel) Close() error { - return c.closeWithError(nil) -} - -// Label returns the label of the underlying DataChannel. -func (c *Channel) Label() string { - return c.dc.Label() -} - -// Protocol returns the protocol of the underlying DataChannel. -func (c *Channel) Protocol() string { - return c.dc.Protocol() -} - -// NetConn wraps the DataChannel in a struct fulfilling net.Conn. -// Read, Write, and Close operations can still be used on the *Channel struct. -func (c *Channel) NetConn() net.Conn { - return &fakeNetConn{ - c: c, - addr: &peerAddr{}, - } -} - -// closeWithError closes the Channel with the error provided. -// If a graceful close occurs, the error will be nil. -func (c *Channel) closeWithError(err error) error { - c.closeMutex.Lock() - defer c.closeMutex.Unlock() - - if c.isClosed() { - return c.closeError - } - - c.conn.logger().Debug(context.Background(), "datachannel closing with error", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()), slog.Error(err)) - if err == nil { - c.closeError = ErrClosed - } else { - c.closeError = err - } - if c.rwc != nil { - _ = c.rwc.Close() - } - _ = c.dc.Close() - - close(c.closed) - close(c.sendMore) - c.conn.dcDisconnectListeners.Sub(1) - c.conn.dcFailedListeners.Sub(1) - c.conn.dcClosedWaitGroup.Done() - - return err -} - -func (c *Channel) isClosed() bool { - select { - case <-c.closed: - return true - default: - return false - } -} - -func (c *Channel) waitOpened() error { - select { - case <-c.opened: - // Re-check the closed channel to prioritize closure. - if c.isClosed() { - return c.closeError - } - return nil - case <-c.closed: - return c.closeError - } -} diff --git a/peer/conn.go b/peer/conn.go deleted file mode 100644 index 2e67b500ee5fd..0000000000000 --- a/peer/conn.go +++ /dev/null @@ -1,616 +0,0 @@ -package peer - -import ( - "bytes" - "context" - "crypto/rand" - "io" - "sync" - "time" - - "github.com/pion/logging" - "github.com/pion/webrtc/v3" - "go.uber.org/atomic" - "golang.org/x/xerrors" - - "cdr.dev/slog" -) - -var ( - // ErrDisconnected occurs when the connection has disconnected. - // The connection will be attempting to reconnect at this point. - ErrDisconnected = xerrors.New("connection is disconnected") - // ErrFailed occurs when the connection has failed. - // The connection will not retry after this point. - ErrFailed = xerrors.New("connection has failed") - // ErrClosed occurs when the connection was closed. It wraps io.EOF - // to fulfill expected read errors from closed pipes. - ErrClosed = xerrors.Errorf("connection was closed: %w", io.EOF) - - // The amount of random bytes sent in a ping. - pingDataLength = 64 -) - -// Client creates a new client connection. -func Client(servers []webrtc.ICEServer, opts *ConnOptions) (*Conn, error) { - return newWithClientOrServer(servers, true, opts) -} - -// Server creates a new server connection. -func Server(servers []webrtc.ICEServer, opts *ConnOptions) (*Conn, error) { - return newWithClientOrServer(servers, false, opts) -} - -// newWithClientOrServer constructs a new connection with the client option. -// nolint:revive -func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOptions) (*Conn, error) { - if opts == nil { - opts = &ConnOptions{} - } - - opts.SettingEngine.DetachDataChannels() - logger := logging.NewDefaultLoggerFactory() - logger.DefaultLogLevel = logging.LogLevelDisabled - opts.SettingEngine.LoggerFactory = logger - api := webrtc.NewAPI(webrtc.WithSettingEngine(opts.SettingEngine)) - rtc, err := api.NewPeerConnection(webrtc.Configuration{ - ICEServers: servers, - }) - if err != nil { - return nil, xerrors.Errorf("create peer connection: %w", err) - } - conn := &Conn{ - pingChannelID: 1, - pingEchoChannelID: 2, - rtc: rtc, - offerer: client, - closed: make(chan struct{}), - closedRTC: make(chan struct{}), - closedICE: make(chan struct{}), - dcOpenChannel: make(chan *webrtc.DataChannel, 8), - dcDisconnectChannel: make(chan struct{}), - dcFailedChannel: make(chan struct{}), - localCandidateChannel: make(chan webrtc.ICECandidateInit), - localSessionDescriptionChannel: make(chan webrtc.SessionDescription, 1), - negotiated: make(chan struct{}), - remoteSessionDescriptionChannel: make(chan webrtc.SessionDescription, 1), - settingEngine: opts.SettingEngine, - } - conn.loggerValue.Store(opts.Logger) - if client { - // If we're the client, we want to flip the echo and - // ping channel IDs so pings don't accidentally hit each other. - conn.pingChannelID, conn.pingEchoChannelID = conn.pingEchoChannelID, conn.pingChannelID - } - err = conn.init() - if err != nil { - return nil, xerrors.Errorf("init: %w", err) - } - return conn, nil -} - -type ConnOptions struct { - Logger slog.Logger - - // Enables customization on the underlying WebRTC connection. - SettingEngine webrtc.SettingEngine -} - -// Conn represents a WebRTC peer connection. -// -// This struct wraps webrtc.PeerConnection to add bidirectional pings, -// concurrent-safe webrtc.DataChannel, and standardized errors for connection state. -type Conn struct { - rtc *webrtc.PeerConnection - // Determines whether this connection will send the offer or the answer. - offerer bool - - closed chan struct{} - closedRTC chan struct{} - closedRTCMutex sync.Mutex - closedICE chan struct{} - closedICEMutex sync.Mutex - closeMutex sync.Mutex - closeError error - - dcCreateMutex sync.Mutex - dcOpenChannel chan *webrtc.DataChannel - dcDisconnectChannel chan struct{} - dcDisconnectListeners atomic.Uint32 - dcFailedChannel chan struct{} - dcFailedListeners atomic.Uint32 - dcClosedWaitGroup sync.WaitGroup - - localCandidateChannel chan webrtc.ICECandidateInit - localSessionDescriptionChannel chan webrtc.SessionDescription - remoteSessionDescriptionChannel chan webrtc.SessionDescription - - negotiated chan struct{} - - loggerValue atomic.Value - settingEngine webrtc.SettingEngine - - pingChannelID uint16 - pingEchoChannelID uint16 - - pingEchoChan *Channel - pingEchoOnce sync.Once - pingEchoError error - pingMutex sync.Mutex - pingOnce sync.Once - pingChan *Channel - pingError error -} - -func (c *Conn) logger() slog.Logger { - log, valid := c.loggerValue.Load().(slog.Logger) - if !valid { - return slog.Logger{} - } - - return log -} - -func (c *Conn) init() error { - c.rtc.OnNegotiationNeeded(c.negotiate) - c.rtc.OnICEConnectionStateChange(func(iceConnectionState webrtc.ICEConnectionState) { - c.closedICEMutex.Lock() - defer c.closedICEMutex.Unlock() - select { - case <-c.closedICE: - // Don't log more state changes if we've already closed. - return - default: - c.logger().Debug(context.Background(), "ice connection state updated", - slog.F("state", iceConnectionState)) - - if iceConnectionState == webrtc.ICEConnectionStateClosed { - // pion/webrtc can update this state multiple times. - // A connection can never become un-closed, so we - // close the channel if it isn't already. - close(c.closedICE) - } - } - }) - c.rtc.OnICEGatheringStateChange(func(iceGatherState webrtc.ICEGathererState) { - c.closedICEMutex.Lock() - defer c.closedICEMutex.Unlock() - select { - case <-c.closedICE: - // Don't log more state changes if we've already closed. - return - default: - c.logger().Debug(context.Background(), "ice gathering state updated", - slog.F("state", iceGatherState)) - - if iceGatherState == webrtc.ICEGathererStateClosed { - // pion/webrtc can update this state multiple times. - // A connection can never become un-closed, so we - // close the channel if it isn't already. - close(c.closedICE) - } - } - }) - c.rtc.OnConnectionStateChange(func(peerConnectionState webrtc.PeerConnectionState) { - go func() { - c.closeMutex.Lock() - defer c.closeMutex.Unlock() - if c.isClosed() { - return - } - c.logger().Debug(context.Background(), "rtc connection updated", - slog.F("state", peerConnectionState)) - }() - - switch peerConnectionState { - case webrtc.PeerConnectionStateDisconnected: - for i := 0; i < int(c.dcDisconnectListeners.Load()); i++ { - select { - case c.dcDisconnectChannel <- struct{}{}: - default: - } - } - case webrtc.PeerConnectionStateFailed: - for i := 0; i < int(c.dcFailedListeners.Load()); i++ { - select { - case c.dcFailedChannel <- struct{}{}: - default: - } - } - case webrtc.PeerConnectionStateClosed: - // pion/webrtc can update this state multiple times. - // A connection can never become un-closed, so we - // close the channel if it isn't already. - c.closedRTCMutex.Lock() - defer c.closedRTCMutex.Unlock() - select { - case <-c.closedRTC: - default: - close(c.closedRTC) - } - } - }) - - // These functions need to check if the conn is closed, because they can be - // called after being closed. - c.rtc.OnSignalingStateChange(func(signalState webrtc.SignalingState) { - c.logger().Debug(context.Background(), "signaling state updated", - slog.F("state", signalState)) - }) - c.rtc.SCTP().Transport().OnStateChange(func(dtlsTransportState webrtc.DTLSTransportState) { - c.logger().Debug(context.Background(), "dtls transport state updated", - slog.F("state", dtlsTransportState)) - }) - c.rtc.SCTP().Transport().ICETransport().OnSelectedCandidatePairChange(func(candidatePair *webrtc.ICECandidatePair) { - c.logger().Debug(context.Background(), "selected candidate pair changed", - slog.F("local", candidatePair.Local), slog.F("remote", candidatePair.Remote)) - }) - c.rtc.OnICECandidate(func(iceCandidate *webrtc.ICECandidate) { - if iceCandidate == nil { - return - } - // Run this in a goroutine so we don't block pion/webrtc - // from continuing. - go func() { - c.logger().Debug(context.Background(), "sending local candidate", slog.F("candidate", iceCandidate.ToJSON().Candidate)) - select { - case <-c.closed: - case c.localCandidateChannel <- iceCandidate.ToJSON(): - } - }() - }) - c.rtc.OnDataChannel(func(dc *webrtc.DataChannel) { - go func() { - select { - case <-c.closed: - case c.dcOpenChannel <- dc: - } - }() - }) - _, err := c.pingChannel() - if err != nil { - return err - } - _, err = c.pingEchoChannel() - if err != nil { - return err - } - - return nil -} - -// negotiate is triggered when a connection is ready to be established. -// See trickle ICE for the expected exchange: https://webrtchacks.com/trickle-ice/ -func (c *Conn) negotiate() { - c.logger().Debug(context.Background(), "negotiating") - // ICE candidates cannot be added until SessionDescriptions have been - // exchanged between peers. - defer func() { - select { - case <-c.negotiated: - default: - close(c.negotiated) - } - }() - - if c.offerer { - offer, err := c.rtc.CreateOffer(&webrtc.OfferOptions{}) - if err != nil { - _ = c.CloseWithError(xerrors.Errorf("create offer: %w", err)) - return - } - // pion/webrtc will panic if Close is called while this - // function is being executed. - c.closeMutex.Lock() - err = c.rtc.SetLocalDescription(offer) - c.closeMutex.Unlock() - if err != nil { - _ = c.CloseWithError(xerrors.Errorf("set local description: %w", err)) - return - } - c.logger().Debug(context.Background(), "sending offer", slog.F("offer", offer)) - select { - case <-c.closed: - return - case c.localSessionDescriptionChannel <- offer: - } - c.logger().Debug(context.Background(), "sent offer") - } - - var sessionDescription webrtc.SessionDescription - c.logger().Debug(context.Background(), "awaiting remote description...") - select { - case <-c.closed: - return - case sessionDescription = <-c.remoteSessionDescriptionChannel: - } - c.logger().Debug(context.Background(), "setting remote description") - - err := c.rtc.SetRemoteDescription(sessionDescription) - if err != nil { - _ = c.CloseWithError(xerrors.Errorf("set remote description (closed %v): %w", c.isClosed(), err)) - return - } - - if !c.offerer { - answer, err := c.rtc.CreateAnswer(&webrtc.AnswerOptions{}) - if err != nil { - _ = c.CloseWithError(xerrors.Errorf("create answer: %w", err)) - return - } - // pion/webrtc will panic if Close is called while this - // function is being executed. - c.closeMutex.Lock() - err = c.rtc.SetLocalDescription(answer) - c.closeMutex.Unlock() - if err != nil { - _ = c.CloseWithError(xerrors.Errorf("set local description: %w", err)) - return - } - c.logger().Debug(context.Background(), "sending answer", slog.F("answer", answer)) - select { - case <-c.closed: - return - case c.localSessionDescriptionChannel <- answer: - } - c.logger().Debug(context.Background(), "sent answer") - } -} - -// AddRemoteCandidate adds a remote candidate to the RTC connection. -func (c *Conn) AddRemoteCandidate(i webrtc.ICECandidateInit) { - if c.isClosed() { - return - } - // This must occur in a goroutine to allow the SessionDescriptions - // to be exchanged first. - go func() { - select { - case <-c.closed: - case <-c.negotiated: - } - if c.isClosed() { - return - } - c.logger().Debug(context.Background(), "accepting candidate", slog.F("candidate", i.Candidate)) - err := c.rtc.AddICECandidate(i) - if err != nil { - if c.rtc.ConnectionState() == webrtc.PeerConnectionStateClosed { - return - } - _ = c.CloseWithError(xerrors.Errorf("accept candidate: %w", err)) - } - }() -} - -// SetRemoteSessionDescription sets the remote description for the WebRTC connection. -func (c *Conn) SetRemoteSessionDescription(sessionDescription webrtc.SessionDescription) { - select { - case <-c.closed: - case c.remoteSessionDescriptionChannel <- sessionDescription: - } -} - -// LocalSessionDescription returns a channel that emits a session description -// when one is required to be exchanged. -func (c *Conn) LocalSessionDescription() <-chan webrtc.SessionDescription { - return c.localSessionDescriptionChannel -} - -// LocalCandidate returns a channel that emits when a local candidate -// needs to be exchanged with a remote connection. -func (c *Conn) LocalCandidate() <-chan webrtc.ICECandidateInit { - return c.localCandidateChannel -} - -func (c *Conn) pingChannel() (*Channel, error) { - c.pingOnce.Do(func() { - c.pingChan, c.pingError = c.dialChannel(context.Background(), "ping", &ChannelOptions{ - ID: c.pingChannelID, - Negotiated: true, - OpenOnDisconnect: true, - }) - if c.pingError != nil { - return - } - }) - return c.pingChan, c.pingError -} - -func (c *Conn) pingEchoChannel() (*Channel, error) { - c.pingEchoOnce.Do(func() { - c.pingEchoChan, c.pingEchoError = c.dialChannel(context.Background(), "echo", &ChannelOptions{ - ID: c.pingEchoChannelID, - Negotiated: true, - OpenOnDisconnect: true, - }) - if c.pingEchoError != nil { - return - } - go func() { - for { - data := make([]byte, pingDataLength) - bytesRead, err := c.pingEchoChan.Read(data) - if err != nil { - _ = c.CloseWithError(xerrors.Errorf("read ping echo channel: %w", err)) - return - } - _, err = c.pingEchoChan.Write(data[:bytesRead]) - if err != nil { - _ = c.CloseWithError(xerrors.Errorf("write ping echo channel: %w", err)) - return - } - } - }() - }) - return c.pingEchoChan, c.pingEchoError -} - -// SetConfiguration applies options to the WebRTC connection. -// Generally used for updating transport options, like ICE servers. -func (c *Conn) SetConfiguration(configuration webrtc.Configuration) error { - return c.rtc.SetConfiguration(configuration) -} - -// Accept blocks waiting for a channel to be opened. -func (c *Conn) Accept(ctx context.Context) (*Channel, error) { - var dataChannel *webrtc.DataChannel - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-c.closed: - return nil, c.closeError - case dataChannel = <-c.dcOpenChannel: - } - - return newChannel(c, dataChannel, &ChannelOptions{}), nil -} - -// CreateChannel creates a new DataChannel. -func (c *Conn) CreateChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) { - if opts == nil { - opts = &ChannelOptions{} - } - if opts.ID == c.pingChannelID || opts.ID == c.pingEchoChannelID { - return nil, xerrors.Errorf("datachannel id %d and %d are reserved for ping", c.pingChannelID, c.pingEchoChannelID) - } - return c.dialChannel(ctx, label, opts) -} - -func (c *Conn) dialChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) { - // pion/webrtc is slower when opening multiple channels - // in parallel than it is sequentially. - c.dcCreateMutex.Lock() - defer c.dcCreateMutex.Unlock() - - c.logger().Debug(ctx, "creating data channel", slog.F("label", label), slog.F("opts", opts)) - var id *uint16 - if opts.ID != 0 { - id = &opts.ID - } - ordered := true - if opts.Unordered { - ordered = false - } - if opts.OpenOnDisconnect && !opts.Negotiated { - return nil, xerrors.New("OpenOnDisconnect is only allowed for Negotiated channels") - } - if c.isClosed() { - return nil, xerrors.Errorf("closed: %w", c.closeError) - } - - dataChannel, err := c.rtc.CreateDataChannel(label, &webrtc.DataChannelInit{ - ID: id, - Negotiated: &opts.Negotiated, - Ordered: &ordered, - Protocol: &opts.Protocol, - }) - if err != nil { - return nil, xerrors.Errorf("create data channel: %w", err) - } - return newChannel(c, dataChannel, opts), nil -} - -// Ping returns the duration it took to round-trip data. -// Multiple pings cannot occur at the same time, so this function will block. -func (c *Conn) Ping() (time.Duration, error) { - // Pings are not async, so we need a mutex. - c.pingMutex.Lock() - defer c.pingMutex.Unlock() - - ping, err := c.pingChannel() - if err != nil { - return 0, xerrors.Errorf("get ping channel: %w", err) - } - pingDataSent := make([]byte, pingDataLength) - _, err = rand.Read(pingDataSent) - if err != nil { - return 0, xerrors.Errorf("read random ping data: %w", err) - } - start := time.Now() - _, err = ping.Write(pingDataSent) - if err != nil { - return 0, xerrors.Errorf("send ping: %w", err) - } - c.logger().Debug(context.Background(), "wrote ping", - slog.F("connection_state", c.rtc.ConnectionState())) - - pingDataReceived := make([]byte, pingDataLength) - _, err = ping.Read(pingDataReceived) - if err != nil { - return 0, xerrors.Errorf("read ping: %w", err) - } - end := time.Now() - if !bytes.Equal(pingDataSent, pingDataReceived) { - return 0, xerrors.Errorf("ping data inconsistency sent != received") - } - return end.Sub(start), nil -} - -func (c *Conn) Closed() <-chan struct{} { - return c.closed -} - -// Close closes the connection and frees all associated resources. -func (c *Conn) Close() error { - return c.CloseWithError(nil) -} - -func (c *Conn) isClosed() bool { - select { - case <-c.closed: - return true - default: - return false - } -} - -// CloseWithError closes the connection; subsequent reads/writes will return the error err. -func (c *Conn) CloseWithError(err error) error { - c.closeMutex.Lock() - defer c.closeMutex.Unlock() - - if c.isClosed() { - return c.closeError - } - - logger := c.logger() - - logger.Debug(context.Background(), "closing conn with error", slog.Error(err)) - if err == nil { - c.closeError = ErrClosed - } else { - c.closeError = err - } - - if ch, _ := c.pingChannel(); ch != nil { - _ = ch.closeWithError(c.closeError) - } - // If the WebRTC connection has already been closed (due to failure or disconnect), - // this call will return an error that isn't typed. We don't check the error because - // closing an already closed connection isn't an issue for us. - _ = c.rtc.Close() - - // Waiting for pion/webrtc to report closed state on both of these - // ensures no goroutine leaks. - if c.rtc.ConnectionState() != webrtc.PeerConnectionStateNew { - logger.Debug(context.Background(), "waiting for rtc connection close...") - <-c.closedRTC - } - if c.rtc.ICEConnectionState() != webrtc.ICEConnectionStateNew { - logger.Debug(context.Background(), "waiting for ice connection close...") - <-c.closedICE - } - - // Waits for all DataChannels to exit before officially labeling as closed. - // All logging, goroutines, and async functionality is cleaned up after this. - c.dcClosedWaitGroup.Wait() - - // Disable logging! - c.loggerValue.Store(slog.Logger{}) - logger.Sync() - - logger.Debug(context.Background(), "closed") - close(c.closed) - return err -} diff --git a/peer/conn_test.go b/peer/conn_test.go deleted file mode 100644 index 992765b940c74..0000000000000 --- a/peer/conn_test.go +++ /dev/null @@ -1,434 +0,0 @@ -package peer_test - -import ( - "context" - "io" - "net" - "net/http" - "os" - "sync" - "testing" - "time" - - "github.com/pion/logging" - "github.com/pion/transport/vnet" - "github.com/pion/webrtc/v3" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - "golang.org/x/xerrors" - - "cdr.dev/slog" - "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/peer" - "github.com/coder/coder/testutil" -) - -var ( - disconnectedTimeout = func() time.Duration { - // Connection state is unfortunately time-based. When resources are - // contended, a connection can take greater than this timeout to - // handshake, which results in a test flake. - // - // During local testing resources are rarely contended. Reducing this - // timeout leads to faster local development. - // - // In CI resources are frequently contended, so increasing this value - // results in less flakes. - if os.Getenv("CI") == "true" { - return time.Second - } - return 100 * time.Millisecond - }() - failedTimeout = disconnectedTimeout * 3 - keepAliveInterval = time.Millisecond * 2 - - // There's a global race in the vnet library allocation code. - // This mutex locks around the creation of the vnet. - vnetMutex = sync.Mutex{} -) - -func TestMain(m *testing.M) { - // pion/ice doesn't properly close immediately. The solution for this isn't yet known. See: - // https://github.com/pion/ice/pull/413 - goleak.VerifyTestMain(m, - goleak.IgnoreTopFunction("github.com/pion/ice/v2.(*Agent).startOnConnectionStateChangeRoutine.func1"), - goleak.IgnoreTopFunction("github.com/pion/ice/v2.(*Agent).startOnConnectionStateChangeRoutine.func2"), - goleak.IgnoreTopFunction("github.com/pion/ice/v2.(*Agent).taskLoop"), - goleak.IgnoreTopFunction("internal/poll.runtime_pollWait"), - ) -} - -func TestConn(t *testing.T) { - t.Parallel() - t.Run("Ping", func(t *testing.T) { - t.Parallel() - client, server, _ := createPair(t) - exchange(t, client, server) - _, err := client.Ping() - require.NoError(t, err) - _, err = server.Ping() - require.NoError(t, err) - }) - - t.Run("PingNetworkOffline", func(t *testing.T) { - t.Parallel() - client, server, wan := createPair(t) - exchange(t, client, server) - _, err := server.Ping() - require.NoError(t, err) - err = wan.Stop() - require.NoError(t, err) - _, err = server.Ping() - require.ErrorIs(t, err, peer.ErrFailed) - }) - - t.Run("PingReconnect", func(t *testing.T) { - t.Parallel() - client, server, wan := createPair(t) - exchange(t, client, server) - _, err := server.Ping() - require.NoError(t, err) - // Create a channel that closes on disconnect. - channel, err := server.CreateChannel(context.Background(), "wow", nil) - assert.NoError(t, err) - defer channel.Close() - - err = wan.Stop() - require.NoError(t, err) - // Once the connection is marked as disconnected, this - // channel will be closed. - _, err = channel.Read(make([]byte, 4)) - assert.ErrorIs(t, err, peer.ErrClosed) - err = wan.Start() - require.NoError(t, err) - _, err = server.Ping() - require.NoError(t, err) - }) - - t.Run("Accept", func(t *testing.T) { - t.Parallel() - client, server, _ := createPair(t) - exchange(t, client, server) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{}) - require.NoError(t, err) - defer cch.Close() - - sch, err := server.Accept(ctx) - require.NoError(t, err) - defer sch.Close() - - _ = cch.Close() - _, err = sch.Read(make([]byte, 4)) - require.ErrorIs(t, err, peer.ErrClosed) - }) - - t.Run("AcceptNetworkOffline", func(t *testing.T) { - t.Parallel() - client, server, wan := createPair(t) - exchange(t, client, server) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{}) - require.NoError(t, err) - defer cch.Close() - sch, err := server.Accept(ctx) - require.NoError(t, err) - defer sch.Close() - - err = wan.Stop() - require.NoError(t, err) - _ = cch.Close() - _, err = sch.Read(make([]byte, 4)) - require.ErrorIs(t, err, peer.ErrClosed) - }) - - t.Run("Buffering", func(t *testing.T) { - t.Parallel() - client, server, _ := createPair(t) - exchange(t, client, server) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{}) - require.NoError(t, err) - defer cch.Close() - - readErr := make(chan error, 1) - go func() { - sch, err := server.Accept(ctx) - if err != nil { - readErr <- err - _ = cch.Close() - return - } - defer sch.Close() - - bytes := make([]byte, 4096) - for { - _, err = sch.Read(bytes) - if err != nil { - readErr <- err - return - } - } - }() - - bytes := make([]byte, 4096) - for i := 0; i < 1024; i++ { - _, err = cch.Write(bytes) - require.NoError(t, err, "write i=%d", i) - } - _ = cch.Close() - - select { - case err = <-readErr: - require.ErrorIs(t, err, peer.ErrClosed, "read error") - case <-ctx.Done(): - require.Fail(t, "timeout waiting for read error") - } - }) - - t.Run("NetConn", func(t *testing.T) { - t.Parallel() - client, server, _ := createPair(t) - exchange(t, client, server) - srv, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - defer srv.Close() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - go func() { - sch, err := server.Accept(ctx) - if err != nil { - assert.NoError(t, err) - return - } - defer sch.Close() - - nc2 := sch.NetConn() - defer nc2.Close() - - nc1, err := net.Dial("tcp", srv.Addr().String()) - if err != nil { - assert.NoError(t, err) - return - } - defer nc1.Close() - - go func() { - defer nc1.Close() - defer nc2.Close() - _, _ = io.Copy(nc1, nc2) - }() - _, _ = io.Copy(nc2, nc1) - }() - go func() { - server := http.Server{ - ReadHeaderTimeout: time.Minute, - Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - rw.WriteHeader(200) - }), - } - defer server.Close() - _ = server.Serve(srv) - }() - - //nolint:forcetypeassert - defaultTransport := http.DefaultTransport.(*http.Transport).Clone() - var cch *peer.Channel - defaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - cch, err = client.CreateChannel(ctx, "hello", &peer.ChannelOptions{}) - if err != nil { - return nil, err - } - return cch.NetConn(), nil - } - c := http.Client{ - Transport: defaultTransport, - } - req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost/", nil) - require.NoError(t, err) - resp, err := c.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - require.Equal(t, resp.StatusCode, 200) - // Triggers any connections to close. - // This test below ensures the DataChannel actually closes. - defaultTransport.CloseIdleConnections() - err = cch.Close() - require.ErrorIs(t, err, peer.ErrClosed) - }) - - t.Run("CloseBeforeNegotiate", func(t *testing.T) { - t.Parallel() - client, server, _ := createPair(t) - exchange(t, client, server) - err := client.Close() - require.NoError(t, err) - err = server.Close() - require.NoError(t, err) - }) - - t.Run("CloseWithError", func(t *testing.T) { - t.Parallel() - conn, err := peer.Client([]webrtc.ICEServer{}, nil) - require.NoError(t, err) - expectedErr := xerrors.New("wow") - _ = conn.CloseWithError(expectedErr) - _, err = conn.CreateChannel(context.Background(), "", nil) - require.ErrorIs(t, err, expectedErr) - }) - - t.Run("PingConcurrent", func(t *testing.T) { - t.Parallel() - client, server, _ := createPair(t) - exchange(t, client, server) - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - _, err := client.Ping() - assert.NoError(t, err) - }() - go func() { - defer wg.Done() - _, err := server.Ping() - assert.NoError(t, err) - }() - wg.Wait() - }) - - t.Run("CandidateBeforeSessionDescription", func(t *testing.T) { - t.Parallel() - client, server, _ := createPair(t) - server.SetRemoteSessionDescription(<-client.LocalSessionDescription()) - sdp := <-server.LocalSessionDescription() - client.AddRemoteCandidate(<-server.LocalCandidate()) - client.SetRemoteSessionDescription(sdp) - server.AddRemoteCandidate(<-client.LocalCandidate()) - _, err := client.Ping() - require.NoError(t, err) - }) - - t.Run("ShortBuffer", func(t *testing.T) { - t.Parallel() - client, server, _ := createPair(t) - exchange(t, client, server) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - go func() { - channel, err := client.CreateChannel(ctx, "test", nil) - if err != nil { - assert.NoError(t, err) - return - } - defer channel.Close() - _, err = channel.Write([]byte{1, 2}) - assert.NoError(t, err) - }() - channel, err := server.Accept(ctx) - require.NoError(t, err) - defer channel.Close() - data := make([]byte, 1) - _, err = channel.Read(data) - require.NoError(t, err) - require.Equal(t, uint8(0x1), data[0]) - _, err = channel.Read(data) - require.NoError(t, err) - require.Equal(t, uint8(0x2), data[0]) - }) -} - -func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.Router) { - loggingFactory := logging.NewDefaultLoggerFactory() - loggingFactory.DefaultLogLevel = logging.LogLevelDisabled - vnetMutex.Lock() - defer vnetMutex.Unlock() - wan, err := vnet.NewRouter(&vnet.RouterConfig{ - CIDR: "1.2.3.0/24", - LoggerFactory: loggingFactory, - }) - require.NoError(t, err) - c1Net := vnet.NewNet(&vnet.NetConfig{ - StaticIPs: []string{"1.2.3.4"}, - }) - err = wan.AddNet(c1Net) - require.NoError(t, err) - c2Net := vnet.NewNet(&vnet.NetConfig{ - StaticIPs: []string{"1.2.3.5"}, - }) - err = wan.AddNet(c2Net) - require.NoError(t, err) - - c1SettingEngine := webrtc.SettingEngine{} - c1SettingEngine.SetVNet(c1Net) - c1SettingEngine.SetPrflxAcceptanceMinWait(0) - c1SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval) - channel1, err := peer.Client([]webrtc.ICEServer{{}}, &peer.ConnOptions{ - SettingEngine: c1SettingEngine, - Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), - }) - require.NoError(t, err) - t.Cleanup(func() { - channel1.Close() - }) - c2SettingEngine := webrtc.SettingEngine{} - c2SettingEngine.SetVNet(c2Net) - c2SettingEngine.SetPrflxAcceptanceMinWait(0) - c2SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval) - channel2, err := peer.Server([]webrtc.ICEServer{{}}, &peer.ConnOptions{ - SettingEngine: c2SettingEngine, - Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug), - }) - require.NoError(t, err) - t.Cleanup(func() { - channel2.Close() - }) - - err = wan.Start() - require.NoError(t, err) - t.Cleanup(func() { - _ = wan.Stop() - }) - - return channel1, channel2, wan -} - -func exchange(t *testing.T, client, server *peer.Conn) { - var wg sync.WaitGroup - wg.Add(2) - t.Cleanup(func() { - _ = client.Close() - _ = server.Close() - - wg.Wait() - }) - go func() { - defer wg.Done() - for { - select { - case c := <-server.LocalCandidate(): - client.AddRemoteCandidate(c) - case c := <-server.LocalSessionDescription(): - client.SetRemoteSessionDescription(c) - case <-server.Closed(): - return - } - } - }() - go func() { - defer wg.Done() - for { - select { - case c := <-client.LocalCandidate(): - server.AddRemoteCandidate(c) - case c := <-client.LocalSessionDescription(): - server.SetRemoteSessionDescription(c) - case <-client.Closed(): - return - } - } - }() -} diff --git a/peer/netconn.go b/peer/netconn.go deleted file mode 100644 index e564c0ecc209c..0000000000000 --- a/peer/netconn.go +++ /dev/null @@ -1,59 +0,0 @@ -package peer - -import ( - "net" - "time" -) - -type peerAddr struct{} - -// Statically checks if we properly implement net.Addr. -var _ net.Addr = &peerAddr{} - -func (*peerAddr) Network() string { - return "peer" -} - -func (*peerAddr) String() string { - return "peer/unknown-addr" -} - -type fakeNetConn struct { - c *Channel - addr *peerAddr -} - -// Statically checks if we properly implement net.Conn. -var _ net.Conn = &fakeNetConn{} - -func (c *fakeNetConn) Read(b []byte) (n int, err error) { - return c.c.Read(b) -} - -func (c *fakeNetConn) Write(b []byte) (n int, err error) { - return c.c.Write(b) -} - -func (c *fakeNetConn) Close() error { - return c.c.Close() -} - -func (c *fakeNetConn) LocalAddr() net.Addr { - return c.addr -} - -func (c *fakeNetConn) RemoteAddr() net.Addr { - return c.addr -} - -func (*fakeNetConn) SetDeadline(_ time.Time) error { - return nil -} - -func (*fakeNetConn) SetReadDeadline(_ time.Time) error { - return nil -} - -func (*fakeNetConn) SetWriteDeadline(_ time.Time) error { - return nil -} diff --git a/peerbroker/dial.go b/peerbroker/dial.go deleted file mode 100644 index 61ef7b409a597..0000000000000 --- a/peerbroker/dial.go +++ /dev/null @@ -1,87 +0,0 @@ -package peerbroker - -import ( - "context" - "errors" - "io" - "reflect" - - "github.com/pion/webrtc/v3" - "golang.org/x/xerrors" - - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker/proto" -) - -// Dial consumes the PeerBroker gRPC connection negotiation stream to produce a WebRTC peered connection. -func Dial(stream proto.DRPCPeerBroker_NegotiateConnectionClient, iceServers []webrtc.ICEServer, opts *peer.ConnOptions) (*peer.Conn, error) { - peerConn, err := peer.Client(iceServers, opts) - if err != nil { - return nil, xerrors.Errorf("create peer connection: %w", err) - } - go func() { - defer stream.Close() - // Exchanging messages from the peer connection to negotiate a connection. - for { - select { - case <-peerConn.Closed(): - return - case sessionDescription := <-peerConn.LocalSessionDescription(): - err = stream.Send(&proto.Exchange{ - Message: &proto.Exchange_Sdp{ - Sdp: &proto.WebRTCSessionDescription{ - SdpType: int32(sessionDescription.Type), - Sdp: sessionDescription.SDP, - }, - }, - }) - if err != nil { - _ = peerConn.CloseWithError(xerrors.Errorf("send local session description: %w", err)) - return - } - case iceCandidate := <-peerConn.LocalCandidate(): - err = stream.Send(&proto.Exchange{ - Message: &proto.Exchange_IceCandidate{ - IceCandidate: iceCandidate.Candidate, - }, - }) - if err != nil { - _ = peerConn.CloseWithError(xerrors.Errorf("send local candidate: %w", err)) - return - } - } - } - }() - go func() { - // Exchanging messages from the server to negotiate a connection. - for { - serverToClientMessage, err := stream.Recv() - if err != nil { - // p2p connections should never die if this stream does due - // to proper closure or context cancellation! - if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) { - return - } - _ = peerConn.CloseWithError(xerrors.Errorf("recv: %w", err)) - return - } - - switch { - case serverToClientMessage.GetSdp() != nil: - peerConn.SetRemoteSessionDescription(webrtc.SessionDescription{ - Type: webrtc.SDPType(serverToClientMessage.GetSdp().SdpType), - SDP: serverToClientMessage.GetSdp().Sdp, - }) - case serverToClientMessage.GetIceCandidate() != "": - peerConn.AddRemoteCandidate(webrtc.ICECandidateInit{ - Candidate: serverToClientMessage.GetIceCandidate(), - }) - default: - _ = peerConn.CloseWithError(xerrors.Errorf("unhandled message: %s", reflect.TypeOf(serverToClientMessage).String())) - return - } - } - }() - - return peerConn, nil -} diff --git a/peerbroker/dial_test.go b/peerbroker/dial_test.go deleted file mode 100644 index efd4e6917ac41..0000000000000 --- a/peerbroker/dial_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package peerbroker_test - -import ( - "context" - "testing" - - "github.com/pion/webrtc/v3" - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - - "cdr.dev/slog" - "cdr.dev/slog/sloggers/slogtest" - - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker" - "github.com/coder/coder/peerbroker/proto" - "github.com/coder/coder/provisionersdk" -) - -func TestMain(m *testing.M) { - goleak.VerifyTestMain(m) -} - -func TestDial(t *testing.T) { - t.Parallel() - - t.Run("Connect", func(t *testing.T) { - t.Parallel() - ctx := context.Background() - client, server := provisionersdk.TransportPipe() - defer client.Close() - defer server.Close() - - settingEngine := webrtc.SettingEngine{} - listener, err := peerbroker.Listen(server, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) { - return []webrtc.ICEServer{{ - URLs: []string{"stun:stun.l.google.com:19302"}, - }}, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug), - SettingEngine: settingEngine, - }, nil - }) - require.NoError(t, err) - - api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) - stream, err := api.NegotiateConnection(ctx) - require.NoError(t, err) - - clientConn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{ - URLs: []string{"stun:stun.l.google.com:19302"}, - }}, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), - SettingEngine: settingEngine, - }) - require.NoError(t, err) - defer clientConn.Close() - - serverConn, err := listener.Accept() - require.NoError(t, err) - defer serverConn.Close() - _, err = serverConn.Ping() - require.NoError(t, err) - - _, err = clientConn.Ping() - require.NoError(t, err) - }) -} diff --git a/peerbroker/listen.go b/peerbroker/listen.go deleted file mode 100644 index 34c91ea6e51a4..0000000000000 --- a/peerbroker/listen.go +++ /dev/null @@ -1,188 +0,0 @@ -package peerbroker - -import ( - "context" - "errors" - "io" - "net" - "reflect" - "sync" - - "github.com/pion/webrtc/v3" - "golang.org/x/xerrors" - "storj.io/drpc/drpcmux" - "storj.io/drpc/drpcserver" - - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker/proto" -) - -// ConnSettingsFunc returns initialization options for a connection -type ConnSettingsFunc func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) - -// Listen consumes the transport as the server-side of the PeerBroker dRPC service. -// The Accept function must be serviced, or new connections will hang. -func Listen(connListener net.Listener, connSettingsFunc ConnSettingsFunc) (*Listener, error) { - if connSettingsFunc == nil { - connSettingsFunc = func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) { - return []webrtc.ICEServer{}, nil, nil - } - } - ctx, cancelFunc := context.WithCancel(context.Background()) - listener := &Listener{ - connectionChannel: make(chan *peer.Conn), - connectionListener: connListener, - - closeFunc: cancelFunc, - closed: make(chan struct{}), - } - - mux := drpcmux.New() - err := proto.DRPCRegisterPeerBroker(mux, &peerBrokerService{ - connSettingsFunc: connSettingsFunc, - - listener: listener, - }) - if err != nil { - return nil, xerrors.Errorf("register peer broker: %w", err) - } - srv := drpcserver.New(mux) - go func() { - err := srv.Serve(ctx, connListener) - _ = listener.closeWithError(err) - }() - - return listener, nil -} - -type Listener struct { - connectionChannel chan *peer.Conn - connectionListener net.Listener - - closeFunc context.CancelFunc - closed chan struct{} - closeMutex sync.Mutex - closeError error -} - -// Accept blocks until a connection arrives or the listener is closed. -func (l *Listener) Accept() (*peer.Conn, error) { - select { - case <-l.closed: - return nil, l.closeError - case conn := <-l.connectionChannel: - return conn, nil - } -} - -// Close ends the listener. This will block all new WebRTC connections -// from establishing, but will not close active connections. -func (l *Listener) Close() error { - return l.closeWithError(io.EOF) -} - -func (l *Listener) closeWithError(err error) error { - l.closeMutex.Lock() - defer l.closeMutex.Unlock() - - if l.isClosed() { - return l.closeError - } - - _ = l.connectionListener.Close() - l.closeError = err - l.closeFunc() - close(l.closed) - - return nil -} - -func (l *Listener) isClosed() bool { - select { - case <-l.closed: - return true - default: - return false - } -} - -// Implements the PeerBroker service protobuf definition. -type peerBrokerService struct { - listener *Listener - - connSettingsFunc ConnSettingsFunc -} - -// NegotiateConnection negotiates a WebRTC connection. -func (b *peerBrokerService) NegotiateConnection(stream proto.DRPCPeerBroker_NegotiateConnectionStream) error { - iceServers, connOptions, err := b.connSettingsFunc(stream.Context()) - if err != nil { - return xerrors.Errorf("get connection settings: %w", err) - } - peerConn, err := peer.Server(iceServers, connOptions) - if err != nil { - return xerrors.Errorf("create peer connection: %w", err) - } - select { - case <-b.listener.closed: - return peerConn.CloseWithError(b.listener.closeError) - case b.listener.connectionChannel <- peerConn: - } - go func() { - defer stream.Close() - for { - select { - case <-peerConn.Closed(): - return - case sessionDescription := <-peerConn.LocalSessionDescription(): - err = stream.Send(&proto.Exchange{ - Message: &proto.Exchange_Sdp{ - Sdp: &proto.WebRTCSessionDescription{ - SdpType: int32(sessionDescription.Type), - Sdp: sessionDescription.SDP, - }, - }, - }) - if err != nil { - _ = peerConn.CloseWithError(xerrors.Errorf("send local session description: %w", err)) - return - } - case iceCandidate := <-peerConn.LocalCandidate(): - err = stream.Send(&proto.Exchange{ - Message: &proto.Exchange_IceCandidate{ - IceCandidate: iceCandidate.Candidate, - }, - }) - if err != nil { - _ = peerConn.CloseWithError(xerrors.Errorf("send local candidate: %w", err)) - return - } - } - } - }() - for { - clientToServerMessage, err := stream.Recv() - if err != nil { - // p2p connections should never die if this stream does due - // to proper closure or context cancellation! - if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) { - return nil - } - return peerConn.CloseWithError(xerrors.Errorf("recv: %w", err)) - } - - switch { - case clientToServerMessage.GetSdp() != nil: - peerConn.SetRemoteSessionDescription(webrtc.SessionDescription{ - Type: webrtc.SDPType(clientToServerMessage.GetSdp().SdpType), - SDP: clientToServerMessage.GetSdp().Sdp, - }) - case clientToServerMessage.GetIceCandidate() != "": - peerConn.AddRemoteCandidate(webrtc.ICECandidateInit{ - Candidate: clientToServerMessage.GetIceCandidate(), - }) - default: - return peerConn.CloseWithError(xerrors.Errorf("unhandled message: %s", reflect.TypeOf(clientToServerMessage).String())) - } - } -} diff --git a/peerbroker/listen_test.go b/peerbroker/listen_test.go deleted file mode 100644 index 81582a91d4b84..0000000000000 --- a/peerbroker/listen_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package peerbroker_test - -import ( - "context" - "io" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/coder/coder/peerbroker" - "github.com/coder/coder/peerbroker/proto" - "github.com/coder/coder/provisionersdk" -) - -func TestListen(t *testing.T) { - t.Parallel() - // Ensures connections blocked on Accept() are - // closed if the listener is. - t.Run("NoAcceptClosed", func(t *testing.T) { - t.Parallel() - ctx := context.Background() - client, server := provisionersdk.TransportPipe() - defer client.Close() - defer server.Close() - - listener, err := peerbroker.Listen(server, nil) - require.NoError(t, err) - - api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) - stream, err := api.NegotiateConnection(ctx) - require.NoError(t, err) - clientConn, err := peerbroker.Dial(stream, nil, nil) - require.NoError(t, err) - defer clientConn.Close() - - _ = listener.Close() - }) - - // Ensures Accept() properly exits when Close() is called. - t.Run("AcceptClosed", func(t *testing.T) { - t.Parallel() - client, server := provisionersdk.TransportPipe() - defer client.Close() - defer server.Close() - - listener, err := peerbroker.Listen(server, nil) - require.NoError(t, err) - go listener.Close() - _, err = listener.Accept() - require.ErrorIs(t, err, io.EOF) - }) -} diff --git a/peerbroker/proto/peerbroker.pb.go b/peerbroker/proto/peerbroker.pb.go deleted file mode 100644 index d4e09f44be118..0000000000000 --- a/peerbroker/proto/peerbroker.pb.go +++ /dev/null @@ -1,269 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. -// versions: -// protoc-gen-go v1.26.0 -// protoc v3.21.5 -// source: peerbroker/proto/peerbroker.proto - -package proto - -import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - reflect "reflect" - sync "sync" -) - -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) - -type WebRTCSessionDescription struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - SdpType int32 `protobuf:"varint,1,opt,name=sdp_type,json=sdpType,proto3" json:"sdp_type,omitempty"` - Sdp string `protobuf:"bytes,2,opt,name=sdp,proto3" json:"sdp,omitempty"` -} - -func (x *WebRTCSessionDescription) Reset() { - *x = WebRTCSessionDescription{} - if protoimpl.UnsafeEnabled { - mi := &file_peerbroker_proto_peerbroker_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *WebRTCSessionDescription) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*WebRTCSessionDescription) ProtoMessage() {} - -func (x *WebRTCSessionDescription) ProtoReflect() protoreflect.Message { - mi := &file_peerbroker_proto_peerbroker_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use WebRTCSessionDescription.ProtoReflect.Descriptor instead. -func (*WebRTCSessionDescription) Descriptor() ([]byte, []int) { - return file_peerbroker_proto_peerbroker_proto_rawDescGZIP(), []int{0} -} - -func (x *WebRTCSessionDescription) GetSdpType() int32 { - if x != nil { - return x.SdpType - } - return 0 -} - -func (x *WebRTCSessionDescription) GetSdp() string { - if x != nil { - return x.Sdp - } - return "" -} - -type Exchange struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // Types that are assignable to Message: - // - // *Exchange_Sdp - // *Exchange_IceCandidate - Message isExchange_Message `protobuf_oneof:"message"` -} - -func (x *Exchange) Reset() { - *x = Exchange{} - if protoimpl.UnsafeEnabled { - mi := &file_peerbroker_proto_peerbroker_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *Exchange) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*Exchange) ProtoMessage() {} - -func (x *Exchange) ProtoReflect() protoreflect.Message { - mi := &file_peerbroker_proto_peerbroker_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use Exchange.ProtoReflect.Descriptor instead. -func (*Exchange) Descriptor() ([]byte, []int) { - return file_peerbroker_proto_peerbroker_proto_rawDescGZIP(), []int{1} -} - -func (m *Exchange) GetMessage() isExchange_Message { - if m != nil { - return m.Message - } - return nil -} - -func (x *Exchange) GetSdp() *WebRTCSessionDescription { - if x, ok := x.GetMessage().(*Exchange_Sdp); ok { - return x.Sdp - } - return nil -} - -func (x *Exchange) GetIceCandidate() string { - if x, ok := x.GetMessage().(*Exchange_IceCandidate); ok { - return x.IceCandidate - } - return "" -} - -type isExchange_Message interface { - isExchange_Message() -} - -type Exchange_Sdp struct { - Sdp *WebRTCSessionDescription `protobuf:"bytes,1,opt,name=sdp,proto3,oneof"` -} - -type Exchange_IceCandidate struct { - IceCandidate string `protobuf:"bytes,2,opt,name=ice_candidate,json=iceCandidate,proto3,oneof"` -} - -func (*Exchange_Sdp) isExchange_Message() {} - -func (*Exchange_IceCandidate) isExchange_Message() {} - -var File_peerbroker_proto_peerbroker_proto protoreflect.FileDescriptor - -var file_peerbroker_proto_peerbroker_proto_rawDesc = []byte{ - 0x0a, 0x21, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2f, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x2f, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x12, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x22, - 0x47, 0x0a, 0x18, 0x57, 0x65, 0x62, 0x52, 0x54, 0x43, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, - 0x44, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x19, 0x0a, 0x08, 0x73, - 0x64, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x73, - 0x64, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x73, 0x64, 0x70, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x03, 0x73, 0x64, 0x70, 0x22, 0x76, 0x0a, 0x08, 0x45, 0x78, 0x63, 0x68, - 0x61, 0x6e, 0x67, 0x65, 0x12, 0x38, 0x0a, 0x03, 0x73, 0x64, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x24, 0x2e, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2e, 0x57, - 0x65, 0x62, 0x52, 0x54, 0x43, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x44, 0x65, 0x73, 0x63, - 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x48, 0x00, 0x52, 0x03, 0x73, 0x64, 0x70, 0x12, 0x25, - 0x0a, 0x0d, 0x69, 0x63, 0x65, 0x5f, 0x63, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0c, 0x69, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, - 0x69, 0x64, 0x61, 0x74, 0x65, 0x42, 0x09, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x32, 0x53, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x42, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x12, 0x45, - 0x0a, 0x13, 0x4e, 0x65, 0x67, 0x6f, 0x74, 0x69, 0x61, 0x74, 0x65, 0x43, 0x6f, 0x6e, 0x6e, 0x65, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x14, 0x2e, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, - 0x65, 0x72, 0x2e, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x14, 0x2e, 0x70, 0x65, - 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2e, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, - 0x65, 0x28, 0x01, 0x30, 0x01, 0x42, 0x29, 0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, - 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2f, - 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -} - -var ( - file_peerbroker_proto_peerbroker_proto_rawDescOnce sync.Once - file_peerbroker_proto_peerbroker_proto_rawDescData = file_peerbroker_proto_peerbroker_proto_rawDesc -) - -func file_peerbroker_proto_peerbroker_proto_rawDescGZIP() []byte { - file_peerbroker_proto_peerbroker_proto_rawDescOnce.Do(func() { - file_peerbroker_proto_peerbroker_proto_rawDescData = protoimpl.X.CompressGZIP(file_peerbroker_proto_peerbroker_proto_rawDescData) - }) - return file_peerbroker_proto_peerbroker_proto_rawDescData -} - -var file_peerbroker_proto_peerbroker_proto_msgTypes = make([]protoimpl.MessageInfo, 2) -var file_peerbroker_proto_peerbroker_proto_goTypes = []interface{}{ - (*WebRTCSessionDescription)(nil), // 0: peerbroker.WebRTCSessionDescription - (*Exchange)(nil), // 1: peerbroker.Exchange -} -var file_peerbroker_proto_peerbroker_proto_depIdxs = []int32{ - 0, // 0: peerbroker.Exchange.sdp:type_name -> peerbroker.WebRTCSessionDescription - 1, // 1: peerbroker.PeerBroker.NegotiateConnection:input_type -> peerbroker.Exchange - 1, // 2: peerbroker.PeerBroker.NegotiateConnection:output_type -> peerbroker.Exchange - 2, // [2:3] is the sub-list for method output_type - 1, // [1:2] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name -} - -func init() { file_peerbroker_proto_peerbroker_proto_init() } -func file_peerbroker_proto_peerbroker_proto_init() { - if File_peerbroker_proto_peerbroker_proto != nil { - return - } - if !protoimpl.UnsafeEnabled { - file_peerbroker_proto_peerbroker_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*WebRTCSessionDescription); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_peerbroker_proto_peerbroker_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Exchange); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } - file_peerbroker_proto_peerbroker_proto_msgTypes[1].OneofWrappers = []interface{}{ - (*Exchange_Sdp)(nil), - (*Exchange_IceCandidate)(nil), - } - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_peerbroker_proto_peerbroker_proto_rawDesc, - NumEnums: 0, - NumMessages: 2, - NumExtensions: 0, - NumServices: 1, - }, - GoTypes: file_peerbroker_proto_peerbroker_proto_goTypes, - DependencyIndexes: file_peerbroker_proto_peerbroker_proto_depIdxs, - MessageInfos: file_peerbroker_proto_peerbroker_proto_msgTypes, - }.Build() - File_peerbroker_proto_peerbroker_proto = out.File - file_peerbroker_proto_peerbroker_proto_rawDesc = nil - file_peerbroker_proto_peerbroker_proto_goTypes = nil - file_peerbroker_proto_peerbroker_proto_depIdxs = nil -} diff --git a/peerbroker/proto/peerbroker.proto b/peerbroker/proto/peerbroker.proto deleted file mode 100644 index f67b338ed3372..0000000000000 --- a/peerbroker/proto/peerbroker.proto +++ /dev/null @@ -1,28 +0,0 @@ - -syntax = "proto3"; -option go_package = "github.com/coder/coder/peerbroker/proto"; - -package peerbroker; - -message WebRTCSessionDescription { - int32 sdp_type = 1; - string sdp = 2; -} - -message Exchange { - oneof message { - WebRTCSessionDescription sdp = 1; - string ice_candidate = 2; - } -} - -// PeerBroker mediates WebRTC connection signaling. -service PeerBroker { - // NegotiateConnection establishes a bidirectional stream to negotiate a new WebRTC connection. - // 1. Client sends WebRTCSessionDescription to the server. - // 2. Server sends WebRTCSessionDescription to the client, exchanging encryption keys. - // 3. Client<->Server exchange ICE Candidates to establish a peered connection. - // - // See: https://davekilian.com/webrtc-the-hard-way.html - rpc NegotiateConnection(stream Exchange) returns (stream Exchange); -} \ No newline at end of file diff --git a/peerbroker/proto/peerbroker_drpc.pb.go b/peerbroker/proto/peerbroker_drpc.pb.go deleted file mode 100644 index ae06f79a01371..0000000000000 --- a/peerbroker/proto/peerbroker_drpc.pb.go +++ /dev/null @@ -1,146 +0,0 @@ -// Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: v0.0.26 -// source: peerbroker/proto/peerbroker.proto - -package proto - -import ( - context "context" - errors "errors" - protojson "google.golang.org/protobuf/encoding/protojson" - proto "google.golang.org/protobuf/proto" - drpc "storj.io/drpc" - drpcerr "storj.io/drpc/drpcerr" -) - -type drpcEncoding_File_peerbroker_proto_peerbroker_proto struct{} - -func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) Marshal(msg drpc.Message) ([]byte, error) { - return proto.Marshal(msg.(proto.Message)) -} - -func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) MarshalAppend(buf []byte, msg drpc.Message) ([]byte, error) { - return proto.MarshalOptions{}.MarshalAppend(buf, msg.(proto.Message)) -} - -func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) Unmarshal(buf []byte, msg drpc.Message) error { - return proto.Unmarshal(buf, msg.(proto.Message)) -} - -func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) JSONMarshal(msg drpc.Message) ([]byte, error) { - return protojson.Marshal(msg.(proto.Message)) -} - -func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) JSONUnmarshal(buf []byte, msg drpc.Message) error { - return protojson.Unmarshal(buf, msg.(proto.Message)) -} - -type DRPCPeerBrokerClient interface { - DRPCConn() drpc.Conn - - NegotiateConnection(ctx context.Context) (DRPCPeerBroker_NegotiateConnectionClient, error) -} - -type drpcPeerBrokerClient struct { - cc drpc.Conn -} - -func NewDRPCPeerBrokerClient(cc drpc.Conn) DRPCPeerBrokerClient { - return &drpcPeerBrokerClient{cc} -} - -func (c *drpcPeerBrokerClient) DRPCConn() drpc.Conn { return c.cc } - -func (c *drpcPeerBrokerClient) NegotiateConnection(ctx context.Context) (DRPCPeerBroker_NegotiateConnectionClient, error) { - stream, err := c.cc.NewStream(ctx, "/peerbroker.PeerBroker/NegotiateConnection", drpcEncoding_File_peerbroker_proto_peerbroker_proto{}) - if err != nil { - return nil, err - } - x := &drpcPeerBroker_NegotiateConnectionClient{stream} - return x, nil -} - -type DRPCPeerBroker_NegotiateConnectionClient interface { - drpc.Stream - Send(*Exchange) error - Recv() (*Exchange, error) -} - -type drpcPeerBroker_NegotiateConnectionClient struct { - drpc.Stream -} - -func (x *drpcPeerBroker_NegotiateConnectionClient) Send(m *Exchange) error { - return x.MsgSend(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{}) -} - -func (x *drpcPeerBroker_NegotiateConnectionClient) Recv() (*Exchange, error) { - m := new(Exchange) - if err := x.MsgRecv(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{}); err != nil { - return nil, err - } - return m, nil -} - -func (x *drpcPeerBroker_NegotiateConnectionClient) RecvMsg(m *Exchange) error { - return x.MsgRecv(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{}) -} - -type DRPCPeerBrokerServer interface { - NegotiateConnection(DRPCPeerBroker_NegotiateConnectionStream) error -} - -type DRPCPeerBrokerUnimplementedServer struct{} - -func (s *DRPCPeerBrokerUnimplementedServer) NegotiateConnection(DRPCPeerBroker_NegotiateConnectionStream) error { - return drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented) -} - -type DRPCPeerBrokerDescription struct{} - -func (DRPCPeerBrokerDescription) NumMethods() int { return 1 } - -func (DRPCPeerBrokerDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) { - switch n { - case 0: - return "/peerbroker.PeerBroker/NegotiateConnection", drpcEncoding_File_peerbroker_proto_peerbroker_proto{}, - func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) { - return nil, srv.(DRPCPeerBrokerServer). - NegotiateConnection( - &drpcPeerBroker_NegotiateConnectionStream{in1.(drpc.Stream)}, - ) - }, DRPCPeerBrokerServer.NegotiateConnection, true - default: - return "", nil, nil, nil, false - } -} - -func DRPCRegisterPeerBroker(mux drpc.Mux, impl DRPCPeerBrokerServer) error { - return mux.Register(impl, DRPCPeerBrokerDescription{}) -} - -type DRPCPeerBroker_NegotiateConnectionStream interface { - drpc.Stream - Send(*Exchange) error - Recv() (*Exchange, error) -} - -type drpcPeerBroker_NegotiateConnectionStream struct { - drpc.Stream -} - -func (x *drpcPeerBroker_NegotiateConnectionStream) Send(m *Exchange) error { - return x.MsgSend(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{}) -} - -func (x *drpcPeerBroker_NegotiateConnectionStream) Recv() (*Exchange, error) { - m := new(Exchange) - if err := x.MsgRecv(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{}); err != nil { - return nil, err - } - return m, nil -} - -func (x *drpcPeerBroker_NegotiateConnectionStream) RecvMsg(m *Exchange) error { - return x.MsgRecv(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{}) -} diff --git a/peerbroker/proxy.go b/peerbroker/proxy.go deleted file mode 100644 index 3e3ccb441776b..0000000000000 --- a/peerbroker/proxy.go +++ /dev/null @@ -1,283 +0,0 @@ -package peerbroker - -import ( - "context" - "encoding/base64" - "errors" - "fmt" - "io" - "net" - "sync" - - "github.com/google/uuid" - "github.com/hashicorp/yamux" - "golang.org/x/xerrors" - protobuf "google.golang.org/protobuf/proto" - "storj.io/drpc/drpcmux" - "storj.io/drpc/drpcserver" - - "cdr.dev/slog" - "github.com/coder/coder/coderd/database" - "github.com/coder/coder/peerbroker/proto" -) - -var ( - // Each NegotiateConnection() function call spawns a new stream. - streamIDLength = len(uuid.NewString()) - // We shouldn't PubSub anything larger than this! - maxPayloadSizeBytes = 8192 -) - -// ProxyOptions provides values to configure a proxy. -type ProxyOptions struct { - ChannelID string - Logger slog.Logger - Pubsub database.Pubsub -} - -// ProxyDial writes client negotiation streams over PubSub. -// -// PubSub is used to geodistribute WebRTC handshakes. All negotiation -// messages are small in size (<=8KB), and we don't require delivery -// guarantees because connections can always be renegotiated. -// -// ┌────────────────────┐ ┌─────────────────────────────┐ -// │ coderd │ │ coderd │ -// -// ┌─────────────────────┐ │//connect │ │ //listen │ -// │ client │ │ │ │ │ ┌─────┐ -// │ ├──►│Creates a stream ID │◄─►│Subscribe() to the │◄──┤agent│ -// │NegotiateConnection()│ │and Publish() to the│ │channel. Parse the stream ID │ └─────┘ -// └─────────────────────┘ │ channel: │ │from payloads to create new │ -// -// │ │ │NegotiateConnection() streams│ -// ││ │or write to existing ones. │ -// └────────────────────┘ └─────────────────────────────┘ -func ProxyDial(client proto.DRPCPeerBrokerClient, options ProxyOptions) (io.Closer, error) { - proxyDial := &proxyDial{ - channelID: options.ChannelID, - logger: options.Logger, - pubsub: options.Pubsub, - connection: client, - streams: make(map[string]proto.DRPCPeerBroker_NegotiateConnectionClient), - } - return proxyDial, proxyDial.listen() -} - -// ProxyListen accepts client negotiation streams over PubSub and writes them to the listener -// as new NegotiateConnection() streams. -func ProxyListen(ctx context.Context, connListener net.Listener, options ProxyOptions) error { - mux := drpcmux.New() - err := proto.DRPCRegisterPeerBroker(mux, &proxyListen{ - channelID: options.ChannelID, - pubsub: options.Pubsub, - logger: options.Logger, - }) - if err != nil { - return xerrors.Errorf("register peer broker: %w", err) - } - server := drpcserver.New(mux) - err = server.Serve(ctx, connListener) - if err != nil { - if errors.Is(err, yamux.ErrSessionShutdown) { - return nil - } - return xerrors.Errorf("serve: %w", err) - } - return nil -} - -type proxyListen struct { - channelID string - pubsub database.Pubsub - logger slog.Logger -} - -func (p *proxyListen) NegotiateConnection(stream proto.DRPCPeerBroker_NegotiateConnectionStream) error { - streamID := uuid.NewString() - var err error - closeSubscribe, err := p.pubsub.Subscribe(proxyInID(p.channelID), func(ctx context.Context, message []byte) { - err := p.onServerToClientMessage(streamID, stream, message) - if err != nil { - p.logger.Debug(ctx, "failed to accept server message", slog.Error(err)) - } - }) - if err != nil { - return xerrors.Errorf("subscribe: %w", err) - } - defer closeSubscribe() - for { - clientToServerMessage, err := stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - return xerrors.Errorf("recv: %w", err) - } - data, err := protobuf.Marshal(clientToServerMessage) - if err != nil { - return xerrors.Errorf("marshal: %w", err) - } - if len(data) > maxPayloadSizeBytes { - return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes) - } - data = append([]byte(streamID), data...) - err = p.pubsub.Publish(proxyOutID(p.channelID), marshal(data)) - if err != nil { - return xerrors.Errorf("publish: %w", err) - } - } - return nil -} - -func (*proxyListen) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionStream, message []byte) error { - var err error - message, err = unmarshal(message) - if err != nil { - return xerrors.Errorf("decode: %w", err) - } - if len(message) < streamIDLength { - return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength) - } - serverStreamID := string(message[0:streamIDLength]) - if serverStreamID != streamID { - // It's not trying to communicate with this stream! - return nil - } - var msg proto.Exchange - err = protobuf.Unmarshal(message[streamIDLength:], &msg) - if err != nil { - return xerrors.Errorf("unmarshal message: %w", err) - } - err = stream.Send(&msg) - if err != nil { - return xerrors.Errorf("send message: %w", err) - } - return nil -} - -type proxyDial struct { - channelID string - pubsub database.Pubsub - logger slog.Logger - - connection proto.DRPCPeerBrokerClient - closeSubscribe func() - streamMutex sync.Mutex - streams map[string]proto.DRPCPeerBroker_NegotiateConnectionClient -} - -func (p *proxyDial) listen() error { - var err error - p.closeSubscribe, err = p.pubsub.Subscribe(proxyOutID(p.channelID), func(ctx context.Context, message []byte) { - err := p.onClientToServerMessage(ctx, message) - if err != nil { - p.logger.Debug(ctx, "failed to accept client message", slog.Error(err)) - } - }) - if err != nil { - return err - } - return nil -} - -func (p *proxyDial) onClientToServerMessage(ctx context.Context, message []byte) error { - var err error - message, err = unmarshal(message) - if err != nil { - return xerrors.Errorf("decode: %w", err) - } - if len(message) < streamIDLength { - return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength) - } - streamID := string(message[0:streamIDLength]) - p.streamMutex.Lock() - stream, ok := p.streams[streamID] - if !ok { - stream, err = p.connection.NegotiateConnection(ctx) - if err != nil { - p.streamMutex.Unlock() - return xerrors.Errorf("negotiate connection: %w", err) - } - p.streams[streamID] = stream - go func() { - defer stream.Close() - - err := p.onServerToClientMessage(streamID, stream) - if err != nil { - p.logger.Debug(ctx, "failed to accept server message", slog.Error(err)) - } - }() - go func() { - <-stream.Context().Done() - p.streamMutex.Lock() - delete(p.streams, streamID) - p.streamMutex.Unlock() - }() - } - p.streamMutex.Unlock() - - var msg proto.Exchange - err = protobuf.Unmarshal(message[streamIDLength:], &msg) - if err != nil { - return xerrors.Errorf("unmarshal message: %w", err) - } - err = stream.Send(&msg) - if err != nil { - return xerrors.Errorf("write message: %w", err) - } - return nil -} - -func (p *proxyDial) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionClient) error { - for { - serverToClientMessage, err := stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - if errors.Is(err, context.Canceled) { - break - } - return xerrors.Errorf("recv: %w", err) - } - data, err := protobuf.Marshal(serverToClientMessage) - if err != nil { - return xerrors.Errorf("marshal: %w", err) - } - if len(data) > maxPayloadSizeBytes { - return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes) - } - data = append([]byte(streamID), data...) - err = p.pubsub.Publish(proxyInID(p.channelID), marshal(data)) - if err != nil { - return xerrors.Errorf("publish: %w", err) - } - } - return nil -} - -func (p *proxyDial) Close() error { - p.streamMutex.Lock() - defer p.streamMutex.Unlock() - p.closeSubscribe() - return nil -} - -// base64 needs to be used here to keep the pubsub messages in UTF-8 range. -// PostgreSQL cannot handle non UTF-8 messages over pubsub. -func marshal(data []byte) []byte { - return []byte(base64.StdEncoding.EncodeToString(data)) -} - -func unmarshal(data []byte) ([]byte, error) { - return base64.StdEncoding.DecodeString(string(data)) -} - -func proxyOutID(channelID string) string { - return fmt.Sprintf("%s-out", channelID) -} - -func proxyInID(channelID string) string { - return fmt.Sprintf("%s-in", channelID) -} diff --git a/peerbroker/proxy_test.go b/peerbroker/proxy_test.go deleted file mode 100644 index 80fe405c24fcf..0000000000000 --- a/peerbroker/proxy_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package peerbroker_test - -import ( - "context" - "sync" - "testing" - - "github.com/pion/webrtc/v3" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "cdr.dev/slog" - "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/coderd/database" - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker" - "github.com/coder/coder/peerbroker/proto" - "github.com/coder/coder/provisionersdk" -) - -func TestProxy(t *testing.T) { - t.Parallel() - ctx := context.Background() - channelID := "hello" - pubsub := database.NewPubsubInMemory() - dialerClient, dialerServer := provisionersdk.TransportPipe() - defer dialerClient.Close() - defer dialerServer.Close() - listenerClient, listenerServer := provisionersdk.TransportPipe() - defer listenerClient.Close() - defer listenerServer.Close() - - listener, err := peerbroker.Listen(listenerServer, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) { - return nil, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug), - }, nil - }) - require.NoError(t, err) - - proxyCloser, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(listenerClient)), peerbroker.ProxyOptions{ - ChannelID: channelID, - Logger: slogtest.Make(t, nil).Named("proxy-listen").Leveled(slog.LevelDebug), - Pubsub: pubsub, - }) - require.NoError(t, err) - defer func() { - _ = proxyCloser.Close() - }() - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - err = peerbroker.ProxyListen(ctx, dialerServer, peerbroker.ProxyOptions{ - ChannelID: channelID, - Logger: slogtest.Make(t, nil).Named("proxy-dial").Leveled(slog.LevelDebug), - Pubsub: pubsub, - }) - assert.NoError(t, err) - }() - - api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(dialerClient)) - stream, err := api.NegotiateConnection(ctx) - require.NoError(t, err) - clientConn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{ - URLs: []string{"stun:stun.l.google.com:19302"}, - }}, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), - }) - require.NoError(t, err) - defer clientConn.Close() - - serverConn, err := listener.Accept() - require.NoError(t, err) - defer serverConn.Close() - _, err = serverConn.Ping() - require.NoError(t, err) - - _, err = clientConn.Ping() - require.NoError(t, err) - - _ = dialerServer.Close() - wg.Wait() -} From 299d30c9ec486bd17cb8c716bf96801036d8e06f Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 19 Sep 2022 18:19:29 +0000 Subject: [PATCH 2/3] Fix race condition --- coderd/workspaceagents.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index ff148e32e85d5..7780ae154fe05 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -352,7 +352,7 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request closeChan := make(chan struct{}) go func() { defer close(closeChan) - err = api.TailnetCoordinator.ServeAgent(wsNetConn, workspaceAgent.ID) + err := api.TailnetCoordinator.ServeAgent(wsNetConn, workspaceAgent.ID) if err != nil { _ = conn.Close(websocket.StatusInternalError, err.Error()) return @@ -375,7 +375,7 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) return } - err = ensureLatestBuild() + err := ensureLatestBuild() if err != nil { // Disconnect agents that are no longer valid. _ = conn.Close(websocket.StatusGoingAway, "") From 0b6b47072ce4df108e7922b76fc1333634f45fc5 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 19 Sep 2022 21:59:01 +0000 Subject: [PATCH 3/3] Fix WebSocket not closing --- agent/agent_test.go | 9 +++++++-- codersdk/provisionerdaemons.go | 1 + codersdk/workspaceagents.go | 7 +++++++ 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/agent/agent_test.go b/agent/agent_test.go index 5b613d72a87cc..08c7918765319 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -522,7 +522,7 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe return } ssh, err := agentConn.SSH() - if !assert.NoError(t, err) { + if err != nil { _ = conn.Close() return } @@ -581,11 +581,16 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) }, CoordinatorDialer: func(ctx context.Context) (net.Conn, error) { clientConn, serverConn := net.Pipe() + closed := make(chan struct{}) t.Cleanup(func() { _ = serverConn.Close() _ = clientConn.Close() + <-closed }) - go coordinator.ServeAgent(serverConn, agentID) + go func() { + _ = coordinator.ServeAgent(serverConn, agentID) + close(closed) + }() return clientConn, nil }, Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), diff --git a/codersdk/provisionerdaemons.go b/codersdk/provisionerdaemons.go index e71fbec7fbee5..296df9b5ac70d 100644 --- a/codersdk/provisionerdaemons.go +++ b/codersdk/provisionerdaemons.go @@ -135,6 +135,7 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after decoder := json.NewDecoder(websocket.NetConn(ctx, conn, websocket.MessageText)) go func() { defer close(logs) + defer conn.Close(websocket.StatusGoingAway, "") var log ProvisionerJobLog for { err = decoder.Decode(&log) diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 527a5a07856c0..2117de03c6ce3 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -281,10 +281,12 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg CompressionMode: websocket.CompressionDisabled, }) if errors.Is(err, context.Canceled) { + _ = ws.Close(websocket.StatusAbnormalClosure, "") return } if err != nil { logger.Debug(ctx, "failed to dial", slog.Error(err)) + _ = ws.Close(websocket.StatusAbnormalClosure, "") continue } sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error { @@ -294,12 +296,15 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg logger.Debug(ctx, "serving coordinator") err = <-errChan if errors.Is(err, context.Canceled) { + _ = ws.Close(websocket.StatusAbnormalClosure, "") return } if err != nil { logger.Debug(ctx, "error serving coordinator", slog.Error(err)) + _ = ws.Close(websocket.StatusAbnormalClosure, "") continue } + _ = ws.Close(websocket.StatusAbnormalClosure, "") } }() return &agent.Conn{ @@ -423,6 +428,7 @@ func (c *Client) AgentReportStats( var req AgentStatsReportRequest err := wsjson.Read(ctx, conn, &req) if err != nil { + _ = conn.Close(websocket.StatusAbnormalClosure, "") return err } @@ -436,6 +442,7 @@ func (c *Client) AgentReportStats( err = wsjson.Write(ctx, conn, resp) if err != nil { + _ = conn.Close(websocket.StatusAbnormalClosure, "") return err } } pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy