diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 90396235993cb..f26ebe92d8283 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -1,6 +1,7 @@ package coderd import ( + "context" "database/sql" "encoding/json" "fmt" @@ -16,6 +17,7 @@ import ( "nhooyr.io/websocket" "cdr.dev/slog" + "github.com/coder/coder/agent" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/httpapi" @@ -69,17 +71,18 @@ func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) { }) return } - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "") - }() + + ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary) + defer wsNetConn.Close() // Also closes conn. + config := yamux.DefaultConfig() config.LogOutput = io.Discard - session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config) + session, err := yamux.Server(wsNetConn, config) if err != nil { _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) return } - err = peerbroker.ProxyListen(r.Context(), session, peerbroker.ProxyOptions{ + err = peerbroker.ProxyListen(ctx, session, peerbroker.ProxyOptions{ ChannelID: workspaceAgent.ID.String(), Logger: api.Logger.Named("peerbroker-proxy-dial"), Pubsub: api.Pubsub, @@ -193,13 +196,12 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { return } - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "") - }() + ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary) + defer wsNetConn.Close() // Also closes conn. config := yamux.DefaultConfig() config.LogOutput = io.Discard - session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config) + session, err := yamux.Server(wsNetConn, config) if err != nil { _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) return @@ -229,7 +231,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { } disconnectedAt := workspaceAgent.DisconnectedAt updateConnectionTimes := func() error { - err = api.Database.UpdateWorkspaceAgentConnectionByID(r.Context(), database.UpdateWorkspaceAgentConnectionByIDParams{ + err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{ ID: workspaceAgent.ID, FirstConnectedAt: firstConnectedAt, LastConnectedAt: lastConnectedAt, @@ -255,7 +257,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { return } - api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent)) + api.Logger.Info(ctx, "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent)) ticker := time.NewTicker(api.AgentConnectionUpdateFrequency) defer ticker.Stop() @@ -324,16 +326,16 @@ func (api *API) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) { }) return } - defer func() { - _ = wsConn.Close(websocket.StatusNormalClosure, "") - }() - netConn := websocket.NetConn(r.Context(), wsConn, websocket.MessageBinary) - api.Logger.Debug(r.Context(), "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress)) + + ctx, wsNetConn := websocketNetConn(r.Context(), wsConn, websocket.MessageBinary) + defer wsNetConn.Close() // Also closes conn. + + api.Logger.Debug(ctx, "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress)) select { - case <-api.TURNServer.Accept(netConn, remoteAddress, localAddress).Closed(): - case <-r.Context().Done(): + case <-api.TURNServer.Accept(wsNetConn, remoteAddress, localAddress).Closed(): + case <-ctx.Done(): } - api.Logger.Debug(r.Context(), "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress)) + api.Logger.Debug(ctx, "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress)) } // workspaceAgentPTY spawns a PTY and pipes it over a WebSocket. @@ -384,12 +386,11 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { }) return } - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "ended") - }() - // Accept text connections, because it's more developer friendly. - wsNetConn := websocket.NetConn(r.Context(), conn, websocket.MessageBinary) - agentConn, err := api.dialWorkspaceAgent(r, workspaceAgent.ID) + + ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary) + defer wsNetConn.Close() // Also closes conn. + + agentConn, err := api.dialWorkspaceAgent(ctx, r, workspaceAgent.ID) if err != nil { _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err)) return @@ -408,11 +409,13 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { _, _ = io.Copy(ptNetConn, wsNetConn) } -// dialWorkspaceAgent connects to a workspace agent by ID. -func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.Conn, error) { +// dialWorkspaceAgent connects to a workspace agent by ID. Only rely on +// r.Context() for cancellation if it's use is safe or r.Hijack() has +// not been performed. +func (api *API) dialWorkspaceAgent(ctx context.Context, r *http.Request, agentID uuid.UUID) (*agent.Conn, error) { client, server := provisionersdk.TransportPipe() go func() { - _ = peerbroker.ProxyListen(r.Context(), server, peerbroker.ProxyOptions{ + _ = peerbroker.ProxyListen(ctx, server, peerbroker.ProxyOptions{ ChannelID: agentID.String(), Logger: api.Logger.Named("peerbroker-proxy-dial"), Pubsub: api.Pubsub, @@ -422,7 +425,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C }() peerClient := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) - stream, err := peerClient.NegotiateConnection(r.Context()) + stream, err := peerClient.NegotiateConnection(ctx) if err != nil { return nil, xerrors.Errorf("negotiate: %w", err) } @@ -434,7 +437,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C options.SettingEngine.SetICEProxyDialer(turnconn.ProxyDialer(func() (c net.Conn, err error) { clientPipe, serverPipe := net.Pipe() go func() { - <-r.Context().Done() + <-ctx.Done() _ = clientPipe.Close() _ = serverPipe.Close() }() @@ -515,3 +518,44 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency return workspaceAgent, nil } + +// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func +// is called if a read or write error is encountered. +type wsNetConn struct { + cancel context.CancelFunc + net.Conn +} + +func (c *wsNetConn) Read(b []byte) (n int, err error) { + n, err = c.Conn.Read(b) + if err != nil { + c.cancel() + } + return n, err +} + +func (c *wsNetConn) Write(b []byte) (n int, err error) { + n, err = c.Conn.Write(b) + if err != nil { + c.cancel() + } + return n, err +} + +func (c *wsNetConn) Close() error { + defer c.cancel() + return c.Conn.Close() +} + +// websocketNetConn wraps websocket.NetConn and returns a context that +// is tied to the parent context and the lifetime of the conn. Any error +// during read or write will cancel the context, but not close the +// conn. Close should be called to release context resources. +func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) { + ctx, cancel := context.WithCancel(ctx) + nc := websocket.NetConn(ctx, conn, msgType) + return ctx, &wsNetConn{ + cancel: cancel, + Conn: nc, + } +} pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy