Skip to content

Commit b4f5920

Browse files
authored
fix: Avoid use of r.Context() after r.Hijack() (#1978)
1 parent 61aacff commit b4f5920

File tree

1 file changed

+74
-30
lines changed

1 file changed

+74
-30
lines changed

coderd/workspaceagents.go

Lines changed: 74 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package coderd
22

33
import (
4+
"context"
45
"database/sql"
56
"encoding/json"
67
"fmt"
@@ -16,6 +17,7 @@ import (
1617
"nhooyr.io/websocket"
1718

1819
"cdr.dev/slog"
20+
1921
"github.com/coder/coder/agent"
2022
"github.com/coder/coder/coderd/database"
2123
"github.com/coder/coder/coderd/httpapi"
@@ -69,17 +71,18 @@ func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
6971
})
7072
return
7173
}
72-
defer func() {
73-
_ = conn.Close(websocket.StatusNormalClosure, "")
74-
}()
74+
75+
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
76+
defer wsNetConn.Close() // Also closes conn.
77+
7578
config := yamux.DefaultConfig()
7679
config.LogOutput = io.Discard
77-
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
80+
session, err := yamux.Server(wsNetConn, config)
7881
if err != nil {
7982
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
8083
return
8184
}
82-
err = peerbroker.ProxyListen(r.Context(), session, peerbroker.ProxyOptions{
85+
err = peerbroker.ProxyListen(ctx, session, peerbroker.ProxyOptions{
8386
ChannelID: workspaceAgent.ID.String(),
8487
Logger: api.Logger.Named("peerbroker-proxy-dial"),
8588
Pubsub: api.Pubsub,
@@ -193,13 +196,12 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
193196
return
194197
}
195198

196-
defer func() {
197-
_ = conn.Close(websocket.StatusNormalClosure, "")
198-
}()
199+
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
200+
defer wsNetConn.Close() // Also closes conn.
199201

200202
config := yamux.DefaultConfig()
201203
config.LogOutput = io.Discard
202-
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
204+
session, err := yamux.Server(wsNetConn, config)
203205
if err != nil {
204206
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
205207
return
@@ -229,7 +231,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
229231
}
230232
disconnectedAt := workspaceAgent.DisconnectedAt
231233
updateConnectionTimes := func() error {
232-
err = api.Database.UpdateWorkspaceAgentConnectionByID(r.Context(), database.UpdateWorkspaceAgentConnectionByIDParams{
234+
err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{
233235
ID: workspaceAgent.ID,
234236
FirstConnectedAt: firstConnectedAt,
235237
LastConnectedAt: lastConnectedAt,
@@ -255,7 +257,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
255257
return
256258
}
257259

258-
api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
260+
api.Logger.Info(ctx, "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
259261

260262
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
261263
defer ticker.Stop()
@@ -324,16 +326,16 @@ func (api *API) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) {
324326
})
325327
return
326328
}
327-
defer func() {
328-
_ = wsConn.Close(websocket.StatusNormalClosure, "")
329-
}()
330-
netConn := websocket.NetConn(r.Context(), wsConn, websocket.MessageBinary)
331-
api.Logger.Debug(r.Context(), "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
329+
330+
ctx, wsNetConn := websocketNetConn(r.Context(), wsConn, websocket.MessageBinary)
331+
defer wsNetConn.Close() // Also closes conn.
332+
333+
api.Logger.Debug(ctx, "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
332334
select {
333-
case <-api.TURNServer.Accept(netConn, remoteAddress, localAddress).Closed():
334-
case <-r.Context().Done():
335+
case <-api.TURNServer.Accept(wsNetConn, remoteAddress, localAddress).Closed():
336+
case <-ctx.Done():
335337
}
336-
api.Logger.Debug(r.Context(), "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
338+
api.Logger.Debug(ctx, "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
337339
}
338340

339341
// workspaceAgentPTY spawns a PTY and pipes it over a WebSocket.
@@ -384,12 +386,11 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
384386
})
385387
return
386388
}
387-
defer func() {
388-
_ = conn.Close(websocket.StatusNormalClosure, "ended")
389-
}()
390-
// Accept text connections, because it's more developer friendly.
391-
wsNetConn := websocket.NetConn(r.Context(), conn, websocket.MessageBinary)
392-
agentConn, err := api.dialWorkspaceAgent(r, workspaceAgent.ID)
389+
390+
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
391+
defer wsNetConn.Close() // Also closes conn.
392+
393+
agentConn, err := api.dialWorkspaceAgent(ctx, r, workspaceAgent.ID)
393394
if err != nil {
394395
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
395396
return
@@ -408,11 +409,13 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
408409
_, _ = io.Copy(ptNetConn, wsNetConn)
409410
}
410411

411-
// dialWorkspaceAgent connects to a workspace agent by ID.
412-
func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.Conn, error) {
412+
// dialWorkspaceAgent connects to a workspace agent by ID. Only rely on
413+
// r.Context() for cancellation if it's use is safe or r.Hijack() has
414+
// not been performed.
415+
func (api *API) dialWorkspaceAgent(ctx context.Context, r *http.Request, agentID uuid.UUID) (*agent.Conn, error) {
413416
client, server := provisionersdk.TransportPipe()
414417
go func() {
415-
_ = peerbroker.ProxyListen(r.Context(), server, peerbroker.ProxyOptions{
418+
_ = peerbroker.ProxyListen(ctx, server, peerbroker.ProxyOptions{
416419
ChannelID: agentID.String(),
417420
Logger: api.Logger.Named("peerbroker-proxy-dial"),
418421
Pubsub: api.Pubsub,
@@ -422,7 +425,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
422425
}()
423426

424427
peerClient := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
425-
stream, err := peerClient.NegotiateConnection(r.Context())
428+
stream, err := peerClient.NegotiateConnection(ctx)
426429
if err != nil {
427430
return nil, xerrors.Errorf("negotiate: %w", err)
428431
}
@@ -434,7 +437,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
434437
options.SettingEngine.SetICEProxyDialer(turnconn.ProxyDialer(func() (c net.Conn, err error) {
435438
clientPipe, serverPipe := net.Pipe()
436439
go func() {
437-
<-r.Context().Done()
440+
<-ctx.Done()
438441
_ = clientPipe.Close()
439442
_ = serverPipe.Close()
440443
}()
@@ -515,3 +518,44 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency
515518

516519
return workspaceAgent, nil
517520
}
521+
522+
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
523+
// is called if a read or write error is encountered.
524+
type wsNetConn struct {
525+
cancel context.CancelFunc
526+
net.Conn
527+
}
528+
529+
func (c *wsNetConn) Read(b []byte) (n int, err error) {
530+
n, err = c.Conn.Read(b)
531+
if err != nil {
532+
c.cancel()
533+
}
534+
return n, err
535+
}
536+
537+
func (c *wsNetConn) Write(b []byte) (n int, err error) {
538+
n, err = c.Conn.Write(b)
539+
if err != nil {
540+
c.cancel()
541+
}
542+
return n, err
543+
}
544+
545+
func (c *wsNetConn) Close() error {
546+
defer c.cancel()
547+
return c.Conn.Close()
548+
}
549+
550+
// websocketNetConn wraps websocket.NetConn and returns a context that
551+
// is tied to the parent context and the lifetime of the conn. Any error
552+
// during read or write will cancel the context, but not close the
553+
// conn. Close should be called to release context resources.
554+
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
555+
ctx, cancel := context.WithCancel(ctx)
556+
nc := websocket.NetConn(ctx, conn, msgType)
557+
return ctx, &wsNetConn{
558+
cancel: cancel,
559+
Conn: nc,
560+
}
561+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy