From 76911bbab68f191794448c7662a90254c8d84929 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Mon, 2 Dec 2024 10:23:40 +0400 Subject: [PATCH] fix: fix goroutine leak in log streaming over websocket --- coderd/provisionerjobs.go | 9 ++-- coderd/workspaceagents.go | 24 ++++------- codersdk/provisionerdaemons.go | 33 ++------------- codersdk/workspaceagents.go | 29 ++----------- codersdk/wsjson/decoder.go | 75 ++++++++++++++++++++++++++++++++++ codersdk/wsjson/encoder.go | 42 +++++++++++++++++++ 6 files changed, 134 insertions(+), 78 deletions(-) create mode 100644 codersdk/wsjson/decoder.go create mode 100644 codersdk/wsjson/encoder.go diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index df832b810e696..3db5d7c20a4bf 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -15,6 +15,7 @@ import ( "nhooyr.io/websocket" "cdr.dev/slog" + "github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" @@ -312,6 +313,7 @@ type logFollower struct { r *http.Request rw http.ResponseWriter conn *websocket.Conn + enc *wsjson.Encoder[codersdk.ProvisionerJobLog] jobID uuid.UUID after int64 @@ -391,6 +393,7 @@ func (f *logFollower) follow() { } defer f.conn.Close(websocket.StatusNormalClosure, "done") go httpapi.Heartbeat(f.ctx, f.conn) + f.enc = wsjson.NewEncoder[codersdk.ProvisionerJobLog](f.conn, websocket.MessageText) // query for logs once right away, so we can get historical data from before // subscription @@ -488,11 +491,7 @@ func (f *logFollower) query() error { return xerrors.Errorf("error fetching logs: %w", err) } for _, log := range logs { - logB, err := json.Marshal(convertProvisionerJobLog(log)) - if err != nil { - return xerrors.Errorf("error marshaling log: %w", err) - } - err = f.conn.Write(f.ctx, websocket.MessageText, logB) + err := f.enc.Encode(convertProvisionerJobLog(log)) if err != nil { return xerrors.Errorf("error writing to websocket: %w", err) } diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 922d80f0e8085..6bc09e0e770f6 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -39,6 +39,7 @@ import ( "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet/proto" ) @@ -396,11 +397,9 @@ func (api *API) workspaceAgentLogs(rw http.ResponseWriter, r *http.Request) { } go httpapi.Heartbeat(ctx, conn) - ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText) - defer wsNetConn.Close() // Also closes conn. + encoder := wsjson.NewEncoder[[]codersdk.WorkspaceAgentLog](conn, websocket.MessageText) + defer encoder.Close(websocket.StatusNormalClosure) - // The Go stdlib JSON encoder appends a newline character after message write. - encoder := json.NewEncoder(wsNetConn) err = encoder.Encode(convertWorkspaceAgentLogs(logs)) if err != nil { return @@ -740,16 +739,8 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) { }) return } - ctx, nconn := codersdk.WebsocketNetConn(ctx, ws, websocket.MessageBinary) - defer nconn.Close() - - // Slurp all packets from the connection into io.Discard so pongs get sent - // by the websocket package. We don't do any reads ourselves so this is - // necessary. - go func() { - _, _ = io.Copy(io.Discard, nconn) - _ = nconn.Close() - }() + encoder := wsjson.NewEncoder[*tailcfg.DERPMap](ws, websocket.MessageBinary) + defer encoder.Close(websocket.StatusGoingAway) go func(ctx context.Context) { // TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout? @@ -767,7 +758,7 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) { err := ws.Ping(ctx) cancel() if err != nil { - _ = nconn.Close() + _ = ws.Close(websocket.StatusGoingAway, "ping failed") return } } @@ -780,9 +771,8 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) { for { derpMap := api.DERPMap() if lastDERPMap == nil || !tailnet.CompareDERPMaps(lastDERPMap, derpMap) { - err := json.NewEncoder(nconn).Encode(derpMap) + err := encoder.Encode(derpMap) if err != nil { - _ = nconn.Close() return } lastDERPMap = derpMap diff --git a/codersdk/provisionerdaemons.go b/codersdk/provisionerdaemons.go index 27d2766a7cd13..fb588ef8ba468 100644 --- a/codersdk/provisionerdaemons.go +++ b/codersdk/provisionerdaemons.go @@ -19,6 +19,7 @@ import ( "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/codersdk/drpc" + "github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/coder/v2/provisionerd/proto" "github.com/coder/coder/v2/provisionerd/runner" ) @@ -162,36 +163,8 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after } return nil, nil, ReadBodyAsError(res) } - logs := make(chan ProvisionerJobLog) - closed := make(chan struct{}) - go func() { - defer close(closed) - defer close(logs) - defer conn.Close(websocket.StatusGoingAway, "") - var log ProvisionerJobLog - for { - msgType, msg, err := conn.Read(ctx) - if err != nil { - return - } - if msgType != websocket.MessageText { - return - } - err = json.Unmarshal(msg, &log) - if err != nil { - return - } - select { - case <-ctx.Done(): - return - case logs <- log: - } - } - }() - return logs, closeFunc(func() error { - <-closed - return nil - }), nil + d := wsjson.NewDecoder[ProvisionerJobLog](conn, websocket.MessageText, c.logger) + return d.Chan(), d, nil } // ServeProvisionerDaemonRequest are the parameters to call ServeProvisionerDaemon with diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index eeb335b130cdd..b4aec16a83190 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -15,6 +15,7 @@ import ( "nhooyr.io/websocket" "github.com/coder/coder/v2/coderd/tracing" + "github.com/coder/coder/v2/codersdk/wsjson" ) type WorkspaceAgentStatus string @@ -454,30 +455,6 @@ func (c *Client) WorkspaceAgentLogsAfter(ctx context.Context, agentID uuid.UUID, } return nil, nil, ReadBodyAsError(res) } - logChunks := make(chan []WorkspaceAgentLog, 1) - closed := make(chan struct{}) - ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageText) - decoder := json.NewDecoder(wsNetConn) - go func() { - defer close(closed) - defer close(logChunks) - defer conn.Close(websocket.StatusGoingAway, "") - for { - var logs []WorkspaceAgentLog - err = decoder.Decode(&logs) - if err != nil { - return - } - select { - case <-ctx.Done(): - return - case logChunks <- logs: - } - } - }() - return logChunks, closeFunc(func() error { - _ = wsNetConn.Close() - <-closed - return nil - }), nil + d := wsjson.NewDecoder[[]WorkspaceAgentLog](conn, websocket.MessageText, c.logger) + return d.Chan(), d, nil } diff --git a/codersdk/wsjson/decoder.go b/codersdk/wsjson/decoder.go new file mode 100644 index 0000000000000..4cc7ff380a73a --- /dev/null +++ b/codersdk/wsjson/decoder.go @@ -0,0 +1,75 @@ +package wsjson + +import ( + "context" + "encoding/json" + "sync/atomic" + + "nhooyr.io/websocket" + + "cdr.dev/slog" +) + +type Decoder[T any] struct { + conn *websocket.Conn + typ websocket.MessageType + ctx context.Context + cancel context.CancelFunc + chanCalled atomic.Bool + logger slog.Logger +} + +// Chan starts the decoder reading from the websocket and returns a channel for reading the +// resulting values. The chan T is closed if the underlying websocket is closed, or we encounter an +// error. We also close the underlying websocket if we encounter an error reading or decoding. +func (d *Decoder[T]) Chan() <-chan T { + if !d.chanCalled.CompareAndSwap(false, true) { + panic("chan called more than once") + } + values := make(chan T, 1) + go func() { + defer close(values) + defer d.conn.Close(websocket.StatusGoingAway, "") + for { + // we don't use d.ctx here because it only gets canceled after closing the connection + // and a "connection closed" type error is more clear than context canceled. + typ, b, err := d.conn.Read(context.Background()) + if err != nil { + // might be benign like EOF, so just log at debug + d.logger.Debug(d.ctx, "error reading from websocket", slog.Error(err)) + return + } + if typ != d.typ { + d.logger.Error(d.ctx, "websocket type mismatch while decoding") + return + } + var value T + err = json.Unmarshal(b, &value) + if err != nil { + d.logger.Error(d.ctx, "error unmarshalling", slog.Error(err)) + return + } + select { + case values <- value: + // OK + case <-d.ctx.Done(): + return + } + } + }() + return values +} + +// nolint: revive // complains that Encoder has the same function name +func (d *Decoder[T]) Close() error { + err := d.conn.Close(websocket.StatusNormalClosure, "") + d.cancel() + return err +} + +// NewDecoder creates a JSON-over-websocket decoder for type T, which must be deserializable from +// JSON. +func NewDecoder[T any](conn *websocket.Conn, typ websocket.MessageType, logger slog.Logger) *Decoder[T] { + ctx, cancel := context.WithCancel(context.Background()) + return &Decoder[T]{conn: conn, ctx: ctx, cancel: cancel, typ: typ, logger: logger} +} diff --git a/codersdk/wsjson/encoder.go b/codersdk/wsjson/encoder.go new file mode 100644 index 0000000000000..4cde05984e690 --- /dev/null +++ b/codersdk/wsjson/encoder.go @@ -0,0 +1,42 @@ +package wsjson + +import ( + "context" + "encoding/json" + + "golang.org/x/xerrors" + "nhooyr.io/websocket" +) + +type Encoder[T any] struct { + conn *websocket.Conn + typ websocket.MessageType +} + +func (e *Encoder[T]) Encode(v T) error { + w, err := e.conn.Writer(context.Background(), e.typ) + if err != nil { + return xerrors.Errorf("get websocket writer: %w", err) + } + defer w.Close() + j := json.NewEncoder(w) + err = j.Encode(v) + if err != nil { + return xerrors.Errorf("encode json: %w", err) + } + return nil +} + +func (e *Encoder[T]) Close(c websocket.StatusCode) error { + return e.conn.Close(c, "") +} + +// NewEncoder creates a JSON-over websocket encoder for the type T, which must be JSON-serializable. +// You may then call Encode() to send objects over the websocket. Creating an Encoder closes the +// websocket for reading, turning it into a unidirectional write stream of JSON-encoded objects. +func NewEncoder[T any](conn *websocket.Conn, typ websocket.MessageType) *Encoder[T] { + // Here we close the websocket for reading, so that the websocket library will handle pings and + // close frames. + _ = conn.CloseRead(context.Background()) + return &Encoder[T]{conn: conn, typ: typ} +} pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy