Skip to content

Commit cd2d12e

Browse files
committed
Merge branch 'main' into apps
2 parents 5b9194f + b4f5920 commit cd2d12e

File tree

1 file changed

+70
-27
lines changed

1 file changed

+70
-27
lines changed

coderd/workspaceagents.go

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"nhooyr.io/websocket"
1818

1919
"cdr.dev/slog"
20+
2021
"github.com/coder/coder/agent"
2122
"github.com/coder/coder/coderd/database"
2223
"github.com/coder/coder/coderd/httpapi"
@@ -77,17 +78,18 @@ func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
7778
})
7879
return
7980
}
80-
defer func() {
81-
_ = conn.Close(websocket.StatusNormalClosure, "")
82-
}()
81+
82+
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
83+
defer wsNetConn.Close() // Also closes conn.
84+
8385
config := yamux.DefaultConfig()
8486
config.LogOutput = io.Discard
85-
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
87+
session, err := yamux.Server(wsNetConn, config)
8688
if err != nil {
8789
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
8890
return
8991
}
90-
err = peerbroker.ProxyListen(r.Context(), session, peerbroker.ProxyOptions{
92+
err = peerbroker.ProxyListen(ctx, session, peerbroker.ProxyOptions{
9193
ChannelID: workspaceAgent.ID.String(),
9294
Logger: api.Logger.Named("peerbroker-proxy-dial"),
9395
Pubsub: api.Pubsub,
@@ -201,13 +203,12 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
201203
return
202204
}
203205

204-
defer func() {
205-
_ = conn.Close(websocket.StatusNormalClosure, "")
206-
}()
206+
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
207+
defer wsNetConn.Close() // Also closes conn.
207208

208209
config := yamux.DefaultConfig()
209210
config.LogOutput = io.Discard
210-
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
211+
session, err := yamux.Server(wsNetConn, config)
211212
if err != nil {
212213
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
213214
return
@@ -237,7 +238,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
237238
}
238239
disconnectedAt := workspaceAgent.DisconnectedAt
239240
updateConnectionTimes := func() error {
240-
err = api.Database.UpdateWorkspaceAgentConnectionByID(r.Context(), database.UpdateWorkspaceAgentConnectionByIDParams{
241+
err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{
241242
ID: workspaceAgent.ID,
242243
FirstConnectedAt: firstConnectedAt,
243244
LastConnectedAt: lastConnectedAt,
@@ -263,7 +264,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
263264
return
264265
}
265266

266-
api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
267+
api.Logger.Info(ctx, "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
267268

268269
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
269270
defer ticker.Stop()
@@ -332,16 +333,16 @@ func (api *API) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) {
332333
})
333334
return
334335
}
335-
defer func() {
336-
_ = wsConn.Close(websocket.StatusNormalClosure, "")
337-
}()
338-
netConn := websocket.NetConn(r.Context(), wsConn, websocket.MessageBinary)
339-
api.Logger.Debug(r.Context(), "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
336+
337+
ctx, wsNetConn := websocketNetConn(r.Context(), wsConn, websocket.MessageBinary)
338+
defer wsNetConn.Close() // Also closes conn.
339+
340+
api.Logger.Debug(ctx, "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
340341
select {
341-
case <-api.TURNServer.Accept(netConn, remoteAddress, localAddress).Closed():
342-
case <-r.Context().Done():
342+
case <-api.TURNServer.Accept(wsNetConn, remoteAddress, localAddress).Closed():
343+
case <-ctx.Done():
343344
}
344-
api.Logger.Debug(r.Context(), "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
345+
api.Logger.Debug(ctx, "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
345346
}
346347

347348
// workspaceAgentPTY spawns a PTY and pipes it over a WebSocket.
@@ -392,11 +393,10 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
392393
})
393394
return
394395
}
395-
defer func() {
396-
_ = conn.Close(websocket.StatusNormalClosure, "ended")
397-
}()
398-
// Accept text connections, because it's more developer friendly.
399-
wsNetConn := websocket.NetConn(r.Context(), conn, websocket.MessageBinary)
396+
397+
_, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
398+
defer wsNetConn.Close() // Also closes conn.
399+
400400
agentConn, release, err := api.workspaceAgentCache.Acquire(r, workspaceAgent.ID)
401401
if err != nil {
402402
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
@@ -416,8 +416,10 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
416416
_, _ = io.Copy(ptNetConn, wsNetConn)
417417
}
418418

419-
// dialWorkspaceAgent connects to a workspace agent by ID.
420-
func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.Conn, error) {
419+
// dialWorkspaceAgent connects to a workspace agent by ID. Only rely on
420+
// r.Context() for cancellation if it's use is safe or r.Hijack() has
421+
// not been performed.
422+
func (api *API) dialWorkspaceAgent(ctx context.Context, r *http.Request, agentID uuid.UUID) (*agent.Conn, error) {
421423
client, server := provisionersdk.TransportPipe()
422424
ctx, cancelFunc := context.WithCancel(context.Background())
423425
go func() {
@@ -446,7 +448,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
446448
options.SettingEngine.SetICEProxyDialer(turnconn.ProxyDialer(func() (c net.Conn, err error) {
447449
clientPipe, serverPipe := net.Pipe()
448450
go func() {
449-
<-r.Context().Done()
451+
<-ctx.Done()
450452
_ = clientPipe.Close()
451453
_ = serverPipe.Close()
452454
}()
@@ -546,3 +548,44 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.Work
546548

547549
return workspaceAgent, nil
548550
}
551+
552+
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
553+
// is called if a read or write error is encountered.
554+
type wsNetConn struct {
555+
cancel context.CancelFunc
556+
net.Conn
557+
}
558+
559+
func (c *wsNetConn) Read(b []byte) (n int, err error) {
560+
n, err = c.Conn.Read(b)
561+
if err != nil {
562+
c.cancel()
563+
}
564+
return n, err
565+
}
566+
567+
func (c *wsNetConn) Write(b []byte) (n int, err error) {
568+
n, err = c.Conn.Write(b)
569+
if err != nil {
570+
c.cancel()
571+
}
572+
return n, err
573+
}
574+
575+
func (c *wsNetConn) Close() error {
576+
defer c.cancel()
577+
return c.Conn.Close()
578+
}
579+
580+
// websocketNetConn wraps websocket.NetConn and returns a context that
581+
// is tied to the parent context and the lifetime of the conn. Any error
582+
// during read or write will cancel the context, but not close the
583+
// conn. Close should be called to release context resources.
584+
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
585+
ctx, cancel := context.WithCancel(ctx)
586+
nc := websocket.NetConn(ctx, conn, msgType)
587+
return ctx, &wsNetConn{
588+
cancel: cancel,
589+
Conn: nc,
590+
}
591+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy