diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index a543e5b716e8f..e553f66d7a9a5 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -8618,6 +8618,7 @@ const docTemplate = `{ ], "summary": "Watch for workspace agent metadata updates", "operationId": "watch-for-workspace-agent-metadata-updates", + "deprecated": true, "parameters": [ { "type": "string", @@ -8638,6 +8639,44 @@ const docTemplate = `{ } } }, + "/workspaceagents/{workspaceagent}/watch-metadata-ws": { + "get": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "produces": [ + "application/json" + ], + "tags": [ + "Agents" + ], + "summary": "Watch for workspace agent metadata updates via WebSockets", + "operationId": "watch-for-workspace-agent-metadata-updates-via-websockets", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "Workspace agent ID", + "name": "workspaceagent", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.ServerSentEvent" + } + } + }, + "x-apidocgen": { + "skip": true + } + } + }, "/workspacebuilds/{workspacebuild}": { "get": { "security": [ @@ -10049,6 +10088,7 @@ const docTemplate = `{ ], "summary": "Watch workspace by ID", "operationId": "watch-workspace-by-id", + "deprecated": true, "parameters": [ { "type": "string", @@ -10068,6 +10108,41 @@ const docTemplate = `{ } } } + }, + "/workspaces/{workspace}/watch-ws": { + "get": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "produces": [ + "application/json" + ], + "tags": [ + "Workspaces" + ], + "summary": "Watch workspace by ID via WebSockets", + "operationId": "watch-workspace-by-id-via-websockets", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "Workspace ID", + "name": "workspace", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.ServerSentEvent" + } + } + } + } } }, "definitions": { @@ -14621,6 +14696,28 @@ const docTemplate = `{ } } }, + "codersdk.ServerSentEvent": { + "type": "object", + "properties": { + "data": {}, + "type": { + "$ref": "#/definitions/codersdk.ServerSentEventType" + } + } + }, + "codersdk.ServerSentEventType": { + "type": "string", + "enum": [ + "ping", + "data", + "error" + ], + "x-enum-varnames": [ + "ServerSentEventTypePing", + "ServerSentEventTypeData", + "ServerSentEventTypeError" + ] + }, "codersdk.SessionCountDeploymentStats": { "type": "object", "properties": { diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index 586f63e5c6d6f..9765d79218c5e 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -7627,6 +7627,7 @@ "tags": ["Agents"], "summary": "Watch for workspace agent metadata updates", "operationId": "watch-for-workspace-agent-metadata-updates", + "deprecated": true, "parameters": [ { "type": "string", @@ -7647,6 +7648,40 @@ } } }, + "/workspaceagents/{workspaceagent}/watch-metadata-ws": { + "get": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "produces": ["application/json"], + "tags": ["Agents"], + "summary": "Watch for workspace agent metadata updates via WebSockets", + "operationId": "watch-for-workspace-agent-metadata-updates-via-websockets", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "Workspace agent ID", + "name": "workspaceagent", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.ServerSentEvent" + } + } + }, + "x-apidocgen": { + "skip": true + } + } + }, "/workspacebuilds/{workspacebuild}": { "get": { "security": [ @@ -8900,6 +8935,7 @@ "tags": ["Workspaces"], "summary": "Watch workspace by ID", "operationId": "watch-workspace-by-id", + "deprecated": true, "parameters": [ { "type": "string", @@ -8919,6 +8955,37 @@ } } } + }, + "/workspaces/{workspace}/watch-ws": { + "get": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "produces": ["application/json"], + "tags": ["Workspaces"], + "summary": "Watch workspace by ID via WebSockets", + "operationId": "watch-workspace-by-id-via-websockets", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "Workspace ID", + "name": "workspace", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.ServerSentEvent" + } + } + } + } } }, "definitions": { @@ -13265,6 +13332,24 @@ } } }, + "codersdk.ServerSentEvent": { + "type": "object", + "properties": { + "data": {}, + "type": { + "$ref": "#/definitions/codersdk.ServerSentEventType" + } + } + }, + "codersdk.ServerSentEventType": { + "type": "string", + "enum": ["ping", "data", "error"], + "x-enum-varnames": [ + "ServerSentEventTypePing", + "ServerSentEventTypeData", + "ServerSentEventTypeError" + ] + }, "codersdk.SessionCountDeploymentStats": { "type": "object", "properties": { diff --git a/coderd/coderd.go b/coderd/coderd.go index f68ddeadb6e6b..20982de70a741 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1248,7 +1248,8 @@ func New(options *Options) *API { httpmw.ExtractWorkspaceParam(options.Database), ) r.Get("/", api.workspaceAgent) - r.Get("/watch-metadata", api.watchWorkspaceAgentMetadata) + r.Get("/watch-metadata", api.watchWorkspaceAgentMetadataSSE) + r.Get("/watch-metadata-ws", api.watchWorkspaceAgentMetadataWS) r.Get("/startup-logs", api.workspaceAgentLogsDeprecated) r.Get("/logs", api.workspaceAgentLogs) r.Get("/listening-ports", api.workspaceAgentListeningPorts) @@ -1280,7 +1281,8 @@ func New(options *Options) *API { r.Route("/ttl", func(r chi.Router) { r.Put("/", api.putWorkspaceTTL) }) - r.Get("/watch", api.watchWorkspace) + r.Get("/watch", api.watchWorkspaceSSE) + r.Get("/watch-ws", api.watchWorkspaceWS) r.Put("/extend", api.putExtendWorkspace) r.Post("/usage", api.postWorkspaceUsage) r.Put("/dormant", api.putWorkspaceDormant) diff --git a/coderd/httpapi/httpapi.go b/coderd/httpapi/httpapi.go index d5895dcbf86f0..c70290ffe56b0 100644 --- a/coderd/httpapi/httpapi.go +++ b/coderd/httpapi/httpapi.go @@ -16,6 +16,9 @@ import ( "github.com/go-playground/validator/v10" "golang.org/x/xerrors" + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + "github.com/coder/coder/v2/coderd/httpapi/httpapiconstraints" "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/codersdk" @@ -282,7 +285,25 @@ func WebsocketCloseSprintf(format string, vars ...any) string { return msg } -func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent func(ctx context.Context, sse codersdk.ServerSentEvent) error, closed chan struct{}, err error) { +type EventSender func(rw http.ResponseWriter, r *http.Request) ( + sendEvent func(sse codersdk.ServerSentEvent) error, + done <-chan struct{}, + err error, +) + +// ServerSentEventSender establishes a Server-Sent Event connection and allows +// the consumer to send messages to the client. +// +// The function returned allows you to send a single message to the client, +// while the channel lets you listen for when the connection closes. +// +// As much as possible, this function should be avoided in favor of using the +// OneWayWebSocket function. See OneWayWebSocket for more context. +func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) ( + func(sse codersdk.ServerSentEvent) error, + <-chan struct{}, + error, +) { h := rw.Header() h.Set("Content-Type", "text/event-stream") h.Set("Cache-Control", "no-cache") @@ -294,7 +315,8 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent f panic("http.ResponseWriter is not http.Flusher") } - closed = make(chan struct{}) + ctx := r.Context() + closed := make(chan struct{}) type sseEvent struct { payload []byte errC chan error @@ -304,16 +326,13 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent f // Synchronized handling of events (no guarantee of order). go func() { defer close(closed) - - // Send a heartbeat every 15 seconds to avoid the connection being killed. - ticker := time.NewTicker(time.Second * 15) + ticker := time.NewTicker(HeartbeatInterval) defer ticker.Stop() for { var event sseEvent - select { - case <-r.Context().Done(): + case <-ctx.Done(): return case event = <-eventC: case <-ticker.C: @@ -333,21 +352,21 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent f } }() - sendEvent = func(ctx context.Context, sse codersdk.ServerSentEvent) error { + sendEvent := func(newEvent codersdk.ServerSentEvent) error { buf := &bytes.Buffer{} - enc := json.NewEncoder(buf) - - _, err := buf.WriteString(fmt.Sprintf("event: %s\n", sse.Type)) + _, err := buf.WriteString(fmt.Sprintf("event: %s\n", newEvent.Type)) if err != nil { return err } - if sse.Data != nil { + if newEvent.Data != nil { _, err = buf.WriteString("data: ") if err != nil { return err } - err = enc.Encode(sse.Data) + + enc := json.NewEncoder(buf) + err = enc.Encode(newEvent.Data) if err != nil { return err } @@ -364,8 +383,6 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent f } select { - case <-r.Context().Done(): - return r.Context().Err() case <-ctx.Done(): return ctx.Err() case <-closed: @@ -375,8 +392,6 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent f // for early exit. We don't check closed here because it // can't happen while processing the event. select { - case <-r.Context().Done(): - return r.Context().Err() case <-ctx.Done(): return ctx.Err() case err := <-event.errC: @@ -387,3 +402,90 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent f return sendEvent, closed, nil } + +// OneWayWebSocketEventSender establishes a new WebSocket connection that +// enforces one-way communication from the server to the client. +// +// The function returned allows you to send a single message to the client, +// while the channel lets you listen for when the connection closes. +// +// We must use an approach like this instead of Server-Sent Events for the +// browser, because on HTTP/1.1 connections, browsers are locked to no more than +// six HTTP connections for a domain total, across all tabs. If a user were to +// open a workspace in multiple tabs, the entire UI can start to lock up. +// WebSockets have no such limitation, no matter what HTTP protocol was used to +// establish the connection. +func OneWayWebSocketEventSender(rw http.ResponseWriter, r *http.Request) ( + func(event codersdk.ServerSentEvent) error, + <-chan struct{}, + error, +) { + ctx, cancel := context.WithCancel(r.Context()) + r = r.WithContext(ctx) + socket, err := websocket.Accept(rw, r, nil) + if err != nil { + cancel() + return nil, nil, xerrors.Errorf("cannot establish connection: %w", err) + } + go Heartbeat(ctx, socket) + + eventC := make(chan codersdk.ServerSentEvent) + socketErrC := make(chan websocket.CloseError, 1) + closed := make(chan struct{}) + go func() { + defer cancel() + defer close(closed) + + for { + select { + case event := <-eventC: + writeCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + err := wsjson.Write(writeCtx, socket, event) + cancel() + if err == nil { + continue + } + _ = socket.Close(websocket.StatusInternalError, "Unable to send newest message") + case err := <-socketErrC: + _ = socket.Close(err.Code, err.Reason) + case <-ctx.Done(): + _ = socket.Close(websocket.StatusNormalClosure, "Connection closed") + } + return + } + }() + + // We have some tools in the UI code to help enforce one-way WebSocket + // connections, but there's still the possibility that the client could send + // a message when it's not supposed to. If that happens, the client likely + // forgot to use those tools, and communication probably can't be trusted. + // Better to just close the socket and force the UI to fix its mess + go func() { + _, _, err := socket.Read(ctx) + if errors.Is(err, context.Canceled) { + return + } + if err != nil { + socketErrC <- websocket.CloseError{ + Code: websocket.StatusInternalError, + Reason: "Unable to process invalid message from client", + } + return + } + socketErrC <- websocket.CloseError{ + Code: websocket.StatusProtocolError, + Reason: "Clients cannot send messages for one-way WebSockets", + } + }() + + sendEvent := func(event codersdk.ServerSentEvent) error { + select { + case eventC <- event: + case <-ctx.Done(): + return ctx.Err() + } + return nil + } + + return sendEvent, closed, nil +} diff --git a/coderd/httpapi/httpapi_test.go b/coderd/httpapi/httpapi_test.go index eb3f23e6ca346..44675e78a255d 100644 --- a/coderd/httpapi/httpapi_test.go +++ b/coderd/httpapi/httpapi_test.go @@ -1,14 +1,18 @@ package httpapi_test import ( + "bufio" "bytes" "context" "encoding/json" "fmt" + "io" + "net" "net/http" "net/http/httptest" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -16,6 +20,7 @@ import ( "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" ) func TestInternalServerError(t *testing.T) { @@ -155,3 +160,436 @@ func TestWebsocketCloseMsg(t *testing.T) { assert.Equal(t, len(trunc), 123) }) } + +// Our WebSocket library accepts any arbitrary ResponseWriter at the type level, +// but the writer must also implement http.Hijacker for long-lived connections. +type mockOneWaySocketWriter struct { + serverRecorder *httptest.ResponseRecorder + serverConn net.Conn + clientConn net.Conn + serverReadWriter *bufio.ReadWriter + testContext *testing.T +} + +func (m mockOneWaySocketWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return m.serverConn, m.serverReadWriter, nil +} + +func (m mockOneWaySocketWriter) Flush() { + err := m.serverReadWriter.Flush() + require.NoError(m.testContext, err) +} + +func (m mockOneWaySocketWriter) Header() http.Header { + return m.serverRecorder.Header() +} + +func (m mockOneWaySocketWriter) Write(b []byte) (int, error) { + return m.serverReadWriter.Write(b) +} + +func (m mockOneWaySocketWriter) WriteHeader(code int) { + m.serverRecorder.WriteHeader(code) +} + +type mockEventSenderWrite func(b []byte) (int, error) + +func (w mockEventSenderWrite) Write(b []byte) (int, error) { + return w(b) +} + +func TestOneWayWebSocketEventSender(t *testing.T) { + t.Parallel() + + newBaseRequest := func(ctx context.Context) *http.Request { + url := "ws://www.fake-website.com/logs" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + h := req.Header + h.Add("Connection", "Upgrade") + h.Add("Upgrade", "websocket") + h.Add("Sec-WebSocket-Version", "13") + h.Add("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") // Just need any string + + return req + } + + newOneWayWriter := func(t *testing.T) mockOneWaySocketWriter { + mockServer, mockClient := net.Pipe() + recorder := httptest.NewRecorder() + + var write mockEventSenderWrite = func(b []byte) (int, error) { + serverCount, err := mockServer.Write(b) + if err != nil { + return 0, err + } + recorderCount, err := recorder.Write(b) + if err != nil { + return 0, err + } + return min(serverCount, recorderCount), nil + } + + return mockOneWaySocketWriter{ + testContext: t, + serverConn: mockServer, + clientConn: mockClient, + serverRecorder: recorder, + serverReadWriter: bufio.NewReadWriter( + bufio.NewReader(mockServer), + bufio.NewWriter(write), + ), + } + } + + t.Run("Produces error if the socket connection could not be established", func(t *testing.T) { + t.Parallel() + + incorrectProtocols := []struct { + major int + minor int + proto string + }{ + {0, 9, "HTTP/0.9"}, + {1, 0, "HTTP/1.0"}, + } + for _, p := range incorrectProtocols { + ctx := testutil.Context(t, testutil.WaitShort) + req := newBaseRequest(ctx) + req.ProtoMajor = p.major + req.ProtoMinor = p.minor + req.Proto = p.proto + + writer := newOneWayWriter(t) + _, _, err := httpapi.OneWayWebSocketEventSender(writer, req) + require.ErrorContains(t, err, p.proto) + } + }) + + t.Run("Returned callback can publish new event to WebSocket connection", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + req := newBaseRequest(ctx) + writer := newOneWayWriter(t) + send, _, err := httpapi.OneWayWebSocketEventSender(writer, req) + require.NoError(t, err) + + serverPayload := codersdk.ServerSentEvent{ + Type: codersdk.ServerSentEventTypeData, + Data: "Blah", + } + err = send(serverPayload) + require.NoError(t, err) + + // The client connection will receive a little bit of additional data on + // top of the main payload. Have to make sure check has tolerance for + // extra data being present + serverBytes, err := json.Marshal(serverPayload) + require.NoError(t, err) + clientBytes, err := io.ReadAll(writer.clientConn) + require.NoError(t, err) + require.True(t, bytes.Contains(clientBytes, serverBytes)) + }) + + t.Run("Signals to outside consumer when socket has been closed", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort)) + req := newBaseRequest(ctx) + writer := newOneWayWriter(t) + _, done, err := httpapi.OneWayWebSocketEventSender(writer, req) + require.NoError(t, err) + + successC := make(chan bool) + ticker := time.NewTicker(testutil.WaitShort) + go func() { + select { + case <-done: + successC <- true + case <-ticker.C: + successC <- false + } + }() + + cancel() + require.True(t, <-successC) + }) + + t.Run("Socket will immediately close if client sends any message", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + req := newBaseRequest(ctx) + writer := newOneWayWriter(t) + _, done, err := httpapi.OneWayWebSocketEventSender(writer, req) + require.NoError(t, err) + + successC := make(chan bool) + ticker := time.NewTicker(testutil.WaitShort) + go func() { + select { + case <-done: + successC <- true + case <-ticker.C: + successC <- false + } + }() + + type JunkClientEvent struct { + Value string + } + b, err := json.Marshal(JunkClientEvent{"Hi :)"}) + require.NoError(t, err) + _, err = writer.clientConn.Write(b) + require.NoError(t, err) + require.True(t, <-successC) + }) + + t.Run("Renders the socket inert if the request context cancels", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort)) + req := newBaseRequest(ctx) + writer := newOneWayWriter(t) + send, done, err := httpapi.OneWayWebSocketEventSender(writer, req) + require.NoError(t, err) + + successC := make(chan bool) + ticker := time.NewTicker(testutil.WaitShort) + go func() { + select { + case <-done: + successC <- true + case <-ticker.C: + successC <- false + } + }() + + cancel() + require.True(t, <-successC) + err = send(codersdk.ServerSentEvent{ + Type: codersdk.ServerSentEventTypeData, + Data: "Didn't realize you were closed - sorry! I'll try coming back tomorrow.", + }) + require.Equal(t, err, ctx.Err()) + _, open := <-done + require.False(t, open) + _, err = writer.serverConn.Write([]byte{}) + require.Equal(t, err, io.ErrClosedPipe) + _, err = writer.clientConn.Read([]byte{}) + require.Equal(t, err, io.EOF) + }) + + t.Run("Sends a heartbeat to the socket on a fixed internal of time to keep connections alive", func(t *testing.T) { + t.Parallel() + + // Need add at least three heartbeats for something to be reliably + // counted as an interval, but also need some wiggle room + heartbeatCount := 3 + hbDuration := time.Duration(heartbeatCount) * httpapi.HeartbeatInterval + timeout := hbDuration + (5 * time.Second) + + ctx := testutil.Context(t, timeout) + req := newBaseRequest(ctx) + writer := newOneWayWriter(t) + _, _, err := httpapi.OneWayWebSocketEventSender(writer, req) + require.NoError(t, err) + + type Result struct { + Err error + Success bool + } + resultC := make(chan Result) + go func() { + err := writer. + clientConn. + SetReadDeadline(time.Now().Add(timeout)) + if err != nil { + resultC <- Result{err, false} + return + } + for range heartbeatCount { + pingBuffer := make([]byte, 1) + pingSize, err := writer.clientConn.Read(pingBuffer) + if err != nil || pingSize != 1 { + resultC <- Result{err, false} + return + } + } + resultC <- Result{nil, true} + }() + + result := <-resultC + require.NoError(t, result.Err) + require.True(t, result.Success) + }) +} + +// ServerSentEventSender accepts any arbitrary ResponseWriter at the type level, +// but the writer must also implement http.Flusher for long-lived connections +type mockServerSentWriter struct { + serverRecorder *httptest.ResponseRecorder + serverConn net.Conn + clientConn net.Conn + buffer *bytes.Buffer + testContext *testing.T +} + +func (m mockServerSentWriter) Flush() { + b := m.buffer.Bytes() + _, err := m.serverConn.Write(b) + require.NoError(m.testContext, err) + m.buffer.Reset() + + // Must close server connection to indicate EOF for any reads from the + // client connection; otherwise reads block forever. This is a testing + // limitation compared to the one-way websockets, since we have no way to + // frame the data and auto-indicate EOF for each message + err = m.serverConn.Close() + require.NoError(m.testContext, err) +} + +func (m mockServerSentWriter) Header() http.Header { + return m.serverRecorder.Header() +} + +func (m mockServerSentWriter) Write(b []byte) (int, error) { + return m.buffer.Write(b) +} + +func (m mockServerSentWriter) WriteHeader(code int) { + m.serverRecorder.WriteHeader(code) +} + +func TestServerSentEventSender(t *testing.T) { + t.Parallel() + + newBaseRequest := func(ctx context.Context) *http.Request { + url := "ws://www.fake-website.com/logs" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + return req + } + + newServerSentWriter := func(t *testing.T) mockServerSentWriter { + mockServer, mockClient := net.Pipe() + return mockServerSentWriter{ + testContext: t, + serverRecorder: httptest.NewRecorder(), + clientConn: mockClient, + serverConn: mockServer, + buffer: &bytes.Buffer{}, + } + } + + t.Run("Mutates response headers to support SSE connections", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + req := newBaseRequest(ctx) + writer := newServerSentWriter(t) + _, _, err := httpapi.ServerSentEventSender(writer, req) + require.NoError(t, err) + + h := writer.Header() + require.Equal(t, h.Get("Content-Type"), "text/event-stream") + require.Equal(t, h.Get("Cache-Control"), "no-cache") + require.Equal(t, h.Get("Connection"), "keep-alive") + require.Equal(t, h.Get("X-Accel-Buffering"), "no") + }) + + t.Run("Returned callback can publish new event to SSE connection", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + req := newBaseRequest(ctx) + writer := newServerSentWriter(t) + send, _, err := httpapi.ServerSentEventSender(writer, req) + require.NoError(t, err) + + serverPayload := codersdk.ServerSentEvent{ + Type: codersdk.ServerSentEventTypeData, + Data: "Blah", + } + err = send(serverPayload) + require.NoError(t, err) + + clientBytes, err := io.ReadAll(writer.clientConn) + require.NoError(t, err) + require.Equal( + t, + string(clientBytes), + "event: data\ndata: \"Blah\"\n\n", + ) + }) + + t.Run("Signals to outside consumer when connection has been closed", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort)) + req := newBaseRequest(ctx) + writer := newServerSentWriter(t) + _, done, err := httpapi.ServerSentEventSender(writer, req) + require.NoError(t, err) + + successC := make(chan bool) + ticker := time.NewTicker(testutil.WaitShort) + go func() { + select { + case <-done: + successC <- true + case <-ticker.C: + successC <- false + } + }() + + cancel() + require.True(t, <-successC) + }) + + t.Run("Sends a heartbeat to the client on a fixed internal of time to keep connections alive", func(t *testing.T) { + t.Parallel() + + // Need add at least three heartbeats for something to be reliably + // counted as an interval, but also need some wiggle room + heartbeatCount := 3 + hbDuration := time.Duration(heartbeatCount) * httpapi.HeartbeatInterval + timeout := hbDuration + (5 * time.Second) + + ctx := testutil.Context(t, timeout) + req := newBaseRequest(ctx) + writer := newServerSentWriter(t) + _, _, err := httpapi.ServerSentEventSender(writer, req) + require.NoError(t, err) + + type Result struct { + Err error + Success bool + } + resultC := make(chan Result) + go func() { + err := writer. + clientConn. + SetReadDeadline(time.Now().Add(timeout)) + if err != nil { + resultC <- Result{err, false} + return + } + for range heartbeatCount { + pingBuffer := make([]byte, 1) + pingSize, err := writer.clientConn.Read(pingBuffer) + if err != nil || pingSize != 1 { + resultC <- Result{err, false} + return + } + } + resultC <- Result{nil, true} + }() + + result := <-resultC + require.NoError(t, result.Err) + require.True(t, result.Success) + }) +} diff --git a/coderd/httpapi/websocket.go b/coderd/httpapi/websocket.go index 20c780f6bffa0..3a71c9c9ae8b0 100644 --- a/coderd/httpapi/websocket.go +++ b/coderd/httpapi/websocket.go @@ -11,11 +11,13 @@ import ( "github.com/coder/websocket" ) +const HeartbeatInterval time.Duration = 15 * time.Second + // Heartbeat loops to ping a WebSocket to keep it alive. // Default idle connection timeouts are typically 60 seconds. // See: https://docs.aws.amazon.com/elasticloadbalancing/latest/application/application-load-balancers.html#connection-idle-timeout func Heartbeat(ctx context.Context, conn *websocket.Conn) { - ticker := time.NewTicker(15 * time.Second) + ticker := time.NewTicker(HeartbeatInterval) defer ticker.Stop() for { select { @@ -33,8 +35,7 @@ func Heartbeat(ctx context.Context, conn *websocket.Conn) { // Heartbeat loops to ping a WebSocket to keep it alive. It calls `exit` on ping // failure. func HeartbeatClose(ctx context.Context, logger slog.Logger, exit func(), conn *websocket.Conn) { - interval := 15 * time.Second - ticker := time.NewTicker(interval) + ticker := time.NewTicker(HeartbeatInterval) defer ticker.Stop() for { @@ -43,7 +44,7 @@ func HeartbeatClose(ctx context.Context, logger slog.Logger, exit func(), conn * return case <-ticker.C: } - err := pingWithTimeout(ctx, conn, interval) + err := pingWithTimeout(ctx, conn, HeartbeatInterval) if err != nil { // context.DeadlineExceeded is expected when the client disconnects without sending a close frame if !errors.Is(err, context.DeadlineExceeded) { diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 975803cb5e1d1..c76d029f43d7c 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -1098,7 +1098,29 @@ func convertScripts(dbScripts []database.WorkspaceAgentScript) []codersdk.Worksp // @Param workspaceagent path string true "Workspace agent ID" format(uuid) // @Router /workspaceagents/{workspaceagent}/watch-metadata [get] // @x-apidocgen {"skip": true} -func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) { +// @Deprecated Use /workspaceagents/{workspaceagent}/watch-metadata-ws instead +func (api *API) watchWorkspaceAgentMetadataSSE(rw http.ResponseWriter, r *http.Request) { + api.watchWorkspaceAgentMetadata(rw, r, httpapi.ServerSentEventSender) +} + +// @Summary Watch for workspace agent metadata updates via WebSockets +// @ID watch-for-workspace-agent-metadata-updates-via-websockets +// @Security CoderSessionToken +// @Produce json +// @Tags Agents +// @Success 200 {object} codersdk.ServerSentEvent +// @Param workspaceagent path string true "Workspace agent ID" format(uuid) +// @Router /workspaceagents/{workspaceagent}/watch-metadata-ws [get] +// @x-apidocgen {"skip": true} +func (api *API) watchWorkspaceAgentMetadataWS(rw http.ResponseWriter, r *http.Request) { + api.watchWorkspaceAgentMetadata(rw, r, httpapi.OneWayWebSocketEventSender) +} + +func (api *API) watchWorkspaceAgentMetadata( + rw http.ResponseWriter, + r *http.Request, + connect httpapi.EventSender, +) { // Allow us to interrupt watch via cancel. ctx, cancel := context.WithCancel(r.Context()) defer cancel() @@ -1163,7 +1185,7 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ //nolint:ineffassign // Release memory. initialMD = nil - sseSendEvent, sseSenderClosed, err := httpapi.ServerSentEventSender(rw, r) + sendEvent, senderClosed, err := connect(rw, r) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error setting up server-sent events.", @@ -1174,14 +1196,14 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ // Prevent handler from returning until the sender is closed. defer func() { cancel() - <-sseSenderClosed + <-senderClosed }() // Synchronize cancellation from SSE -> context, this lets us simplify the // cancellation logic. go func() { select { case <-ctx.Done(): - case <-sseSenderClosed: + case <-senderClosed: cancel() } }() @@ -1193,7 +1215,7 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ log.Debug(ctx, "sending metadata", "num", len(values)) - _ = sseSendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeData, Data: convertWorkspaceAgentMetadata(values), }) @@ -1225,7 +1247,7 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ if err != nil { if !database.IsQueryCanceledError(err) { log.Error(ctx, "failed to get metadata", slog.Error(err)) - _ = sseSendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeError, Data: codersdk.Response{ Message: "Failed to get metadata.", diff --git a/coderd/workspaces.go b/coderd/workspaces.go index 7022938062c64..d57481aa12f90 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -1719,12 +1719,33 @@ func (api *API) resolveAutostart(rw http.ResponseWriter, r *http.Request) { // @Param workspace path string true "Workspace ID" format(uuid) // @Success 200 {object} codersdk.Response // @Router /workspaces/{workspace}/watch [get] -func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { +// @Deprecated Use /workspaces/{workspace}/watch-ws instead +func (api *API) watchWorkspaceSSE(rw http.ResponseWriter, r *http.Request) { + api.watchWorkspace(rw, r, httpapi.ServerSentEventSender) +} + +// @Summary Watch workspace by ID via WebSockets +// @ID watch-workspace-by-id-via-websockets +// @Security CoderSessionToken +// @Produce json +// @Tags Workspaces +// @Param workspace path string true "Workspace ID" format(uuid) +// @Success 200 {object} codersdk.ServerSentEvent +// @Router /workspaces/{workspace}/watch-ws [get] +func (api *API) watchWorkspaceWS(rw http.ResponseWriter, r *http.Request) { + api.watchWorkspace(rw, r, httpapi.OneWayWebSocketEventSender) +} + +func (api *API) watchWorkspace( + rw http.ResponseWriter, + r *http.Request, + connect httpapi.EventSender, +) { ctx := r.Context() workspace := httpmw.WorkspaceParam(r) apiKey := httpmw.APIKey(r) - sendEvent, senderClosed, err := httpapi.ServerSentEventSender(rw, r) + sendEvent, senderClosed, err := connect(rw, r) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error setting up server-sent events.", @@ -1740,7 +1761,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { sendUpdate := func(_ context.Context, _ []byte) { workspace, err := api.Database.GetWorkspaceByID(ctx, workspace.ID) if err != nil { - _ = sendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeError, Data: codersdk.Response{ Message: "Internal error fetching workspace.", @@ -1752,7 +1773,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { data, err := api.workspaceData(ctx, []database.Workspace{workspace}) if err != nil { - _ = sendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeError, Data: codersdk.Response{ Message: "Internal error fetching workspace data.", @@ -1762,7 +1783,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { return } if len(data.templates) == 0 { - _ = sendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeError, Data: codersdk.Response{ Message: "Forbidden reading template of selected workspace.", @@ -1779,7 +1800,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { api.Options.AllowWorkspaceRenames, ) if err != nil { - _ = sendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeError, Data: codersdk.Response{ Message: "Internal error converting workspace.", @@ -1787,7 +1808,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { }, }) } - _ = sendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeData, Data: w, }) @@ -1805,7 +1826,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { sendUpdate(ctx, nil) })) if err != nil { - _ = sendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeError, Data: codersdk.Response{ Message: "Internal error subscribing to workspace events.", @@ -1819,7 +1840,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { // This is required to show whether the workspace is up-to-date. cancelTemplateSubscribe, err := api.Pubsub.Subscribe(watchTemplateChannel(workspace.TemplateID), sendUpdate) if err != nil { - _ = sendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeError, Data: codersdk.Response{ Message: "Internal error subscribing to template events.", @@ -1832,7 +1853,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { // An initial ping signals to the request that the server is now ready // and the client can begin servicing a channel with data. - _ = sendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypePing, }) // Send updated workspace info after connection is established. This avoids diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index 4fee5c57d5100..652c1274751e9 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -5735,6 +5735,38 @@ Git clone makes use of this by parsing the URL from: 'Username for "https://gith | `ssh_config_options` | object | false | | | | » `[any property]` | string | false | | | +## codersdk.ServerSentEvent + +```json +{ + "data": null, + "type": "ping" +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|--------|--------------------------------------------------------------|----------|--------------|-------------| +| `data` | any | false | | | +| `type` | [codersdk.ServerSentEventType](#codersdkserversenteventtype) | false | | | + +## codersdk.ServerSentEventType + +```json +"ping" +``` + +### Properties + +#### Enumerated Values + +| Value | +|---------| +| `ping` | +| `data` | +| `error` | + ## codersdk.SessionCountDeploymentStats ```json diff --git a/docs/reference/api/workspaces.md b/docs/reference/api/workspaces.md index 7264b6dbb3939..18500158567ae 100644 --- a/docs/reference/api/workspaces.md +++ b/docs/reference/api/workspaces.md @@ -1979,3 +1979,41 @@ curl -X GET http://coder-server:8080/api/v2/workspaces/{workspace}/watch \ | 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.Response](schemas.md#codersdkresponse) | To perform this operation, you must be authenticated. [Learn more](authentication.md). + +## Watch workspace by ID via WebSockets + +### Code samples + +```shell +# Example request using curl +curl -X GET http://coder-server:8080/api/v2/workspaces/{workspace}/watch-ws \ + -H 'Accept: application/json' \ + -H 'Coder-Session-Token: API_KEY' +``` + +`GET /workspaces/{workspace}/watch-ws` + +### Parameters + +| Name | In | Type | Required | Description | +|-------------|------|--------------|----------|--------------| +| `workspace` | path | string(uuid) | true | Workspace ID | + +### Example responses + +> 200 Response + +```json +{ + "data": null, + "type": "ping" +} +``` + +### Responses + +| Status | Meaning | Description | Schema | +|--------|---------------------------------------------------------|-------------|----------------------------------------------------------------| +| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.ServerSentEvent](schemas.md#codersdkserversentevent) | + +To perform this operation, you must be authenticated. [Learn more](authentication.md). diff --git a/site/package.json b/site/package.json index 7f45637237cf7..51ec024ae2fa1 100644 --- a/site/package.json +++ b/site/package.json @@ -166,7 +166,6 @@ "@vitejs/plugin-react": "4.3.4", "autoprefixer": "10.4.20", "chromatic": "11.25.2", - "eventsourcemock": "2.0.0", "express": "4.21.2", "jest": "29.7.0", "jest-canvas-mock": "2.5.2", diff --git a/site/pnpm-lock.yaml b/site/pnpm-lock.yaml index d08ab3c523083..fc5dbb43876f6 100644 --- a/site/pnpm-lock.yaml +++ b/site/pnpm-lock.yaml @@ -403,9 +403,6 @@ importers: chromatic: specifier: 11.25.2 version: 11.25.2 - eventsourcemock: - specifier: 2.0.0 - version: 2.0.0 express: specifier: 4.21.2 version: 4.21.2 @@ -3796,9 +3793,6 @@ packages: eventemitter3@4.0.7: resolution: {integrity: sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw==, tarball: https://registry.npmjs.org/eventemitter3/-/eventemitter3-4.0.7.tgz} - eventsourcemock@2.0.0: - resolution: {integrity: sha512-tSmJnuE+h6A8/hLRg0usf1yL+Q8w01RQtmg0Uzgoxk/HIPZrIUeAr/A4es/8h1wNsoG8RdiESNQLTKiNwbSC3Q==, tarball: https://registry.npmjs.org/eventsourcemock/-/eventsourcemock-2.0.0.tgz} - execa@5.1.1: resolution: {integrity: sha512-8uSpZZocAZRBAPIEINJj3Lo9HyGitllczc27Eh5YYojjMFMn8yHMDMaUHE2Jqfq05D/wucwI4JGURyXt1vchyg==, tarball: https://registry.npmjs.org/execa/-/execa-5.1.1.tgz} engines: {node: '>=10'} @@ -10017,8 +10011,6 @@ snapshots: eventemitter3@4.0.7: {} - eventsourcemock@2.0.0: {} - execa@5.1.1: dependencies: cross-spawn: 7.0.6 diff --git a/site/src/@types/eventsourcemock.d.ts b/site/src/@types/eventsourcemock.d.ts deleted file mode 100644 index 296c4f19c33ce..0000000000000 --- a/site/src/@types/eventsourcemock.d.ts +++ /dev/null @@ -1 +0,0 @@ -declare module "eventsourcemock"; diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 85953bbce736f..3a43772a02657 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -22,9 +22,10 @@ import globalAxios, { type AxiosInstance, isAxiosError } from "axios"; import type dayjs from "dayjs"; import userAgentParser from "ua-parser-js"; +import { OneWayWebSocket } from "utils/OneWayWebSocket"; import { delay } from "../utils/delay"; -import * as TypesGen from "./typesGenerated"; import type { PostWorkspaceUsageRequest } from "./typesGenerated"; +import * as TypesGen from "./typesGenerated"; const getMissingParameters = ( oldBuildParameters: TypesGen.WorkspaceBuildParameter[], @@ -101,61 +102,40 @@ const getMissingParameters = ( }; /** - * * @param agentId - * @returns An EventSource that emits agent metadata event objects - * (ServerSentEvent) + * @returns {OneWayWebSocket} A OneWayWebSocket that emits Server-Sent Events. */ -export const watchAgentMetadata = (agentId: string): EventSource => { - return new EventSource( - `${location.protocol}//${location.host}/api/v2/workspaceagents/${agentId}/watch-metadata`, - { withCredentials: true }, - ); +export const watchAgentMetadata = ( + agentId: string, +): OneWayWebSocket => { + return new OneWayWebSocket({ + apiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata-ws`, + }); }; /** - * @returns {EventSource} An EventSource that emits workspace event objects - * (ServerSentEvent) + * @returns {OneWayWebSocket} A OneWayWebSocket that emits Server-Sent Events. */ -export const watchWorkspace = (workspaceId: string): EventSource => { - return new EventSource( - `${location.protocol}//${location.host}/api/v2/workspaces/${workspaceId}/watch`, - { withCredentials: true }, - ); +export const watchWorkspace = ( + workspaceId: string, +): OneWayWebSocket => { + return new OneWayWebSocket({ + apiRoute: `/api/v2/workspaces/${workspaceId}/watch-ws`, + }); }; -type WatchInboxNotificationsParams = { +type WatchInboxNotificationsParams = Readonly<{ read_status?: "read" | "unread" | "all"; -}; +}>; -export const watchInboxNotifications = ( - onNewNotification: (res: TypesGen.GetInboxNotificationResponse) => void, +export function watchInboxNotifications( params?: WatchInboxNotificationsParams, -) => { - const searchParams = new URLSearchParams(params); - const socket = createWebSocket( - "/api/v2/notifications/inbox/watch", - searchParams, - ); - - socket.addEventListener("message", (event) => { - try { - const res = JSON.parse( - event.data, - ) as TypesGen.GetInboxNotificationResponse; - onNewNotification(res); - } catch (error) { - console.warn("Error parsing inbox notification: ", error); - } - }); - - socket.addEventListener("error", (event) => { - console.warn("Watch inbox notifications error: ", event); - socket.close(); +): OneWayWebSocket { + return new OneWayWebSocket({ + apiRoute: "/api/v2/notifications/inbox/watch", + searchParams: params, }); - - return socket; -}; +} export const getURLWithSearchParams = ( basePath: string, @@ -1125,7 +1105,7 @@ class ApiMethods { }; getWorkspaceByOwnerAndName = async ( - username = "me", + username: string, workspaceName: string, params?: TypesGen.WorkspaceOptions, ): Promise => { @@ -1138,7 +1118,7 @@ class ApiMethods { }; getWorkspaceBuildByNumber = async ( - username = "me", + username: string, workspaceName: string, buildNumber: number, ): Promise => { @@ -1324,7 +1304,7 @@ class ApiMethods { }; createWorkspace = async ( - userId = "me", + userId: string, workspace: TypesGen.CreateWorkspaceRequest, ): Promise => { const response = await this.axios.post( @@ -2542,7 +2522,7 @@ function createWebSocket( ) { const protocol = location.protocol === "https:" ? "wss:" : "ws:"; const socket = new WebSocket( - `${protocol}//${location.host}${path}?${params.toString()}`, + `${protocol}//${location.host}${path}?${params}`, ); socket.binaryType = "blob"; return socket; diff --git a/site/src/modules/notifications/NotificationsInbox/NotificationsInbox.tsx b/site/src/modules/notifications/NotificationsInbox/NotificationsInbox.tsx index 656d87fbe31d3..cdbf0941b7fdb 100644 --- a/site/src/modules/notifications/NotificationsInbox/NotificationsInbox.tsx +++ b/site/src/modules/notifications/NotificationsInbox/NotificationsInbox.tsx @@ -61,21 +61,31 @@ export const NotificationsInbox: FC = ({ ); useEffect(() => { - const socket = watchInboxNotifications( - (res) => { - updateNotificationsCache((prev) => { - return { - unread_count: res.unread_count, - notifications: [res.notification, ...prev.notifications], - }; - }); - }, - { read_status: "unread" }, - ); + const socket = watchInboxNotifications({ read_status: "unread" }); - return () => { + socket.addEventListener("message", (e) => { + if (e.parseError) { + console.warn("Error parsing inbox notification: ", e.parseError); + return; + } + + const msg = e.parsedMessage; + updateNotificationsCache((current) => { + return { + unread_count: msg.unread_count, + notifications: [msg.notification, ...current.notifications], + }; + }); + }); + + socket.addEventListener("error", () => { + displayError( + "Unable to retrieve latest inbox notifications. Please try refreshing the browser.", + ); socket.close(); - }; + }); + + return () => socket.close(); }, [updateNotificationsCache]); const { diff --git a/site/src/modules/resources/AgentMetadata.tsx b/site/src/modules/resources/AgentMetadata.tsx index 81b5a14994e81..5e5501809ee49 100644 --- a/site/src/modules/resources/AgentMetadata.tsx +++ b/site/src/modules/resources/AgentMetadata.tsx @@ -3,9 +3,11 @@ import Skeleton from "@mui/material/Skeleton"; import Tooltip from "@mui/material/Tooltip"; import { watchAgentMetadata } from "api/api"; import type { + ServerSentEvent, WorkspaceAgent, WorkspaceAgentMetadata, } from "api/typesGenerated"; +import { displayError } from "components/GlobalSnackbar/utils"; import { Stack } from "components/Stack/Stack"; import dayjs from "dayjs"; import { @@ -17,6 +19,7 @@ import { useState, } from "react"; import { MONOSPACE_FONT_FAMILY } from "theme/constants"; +import type { OneWayWebSocket } from "utils/OneWayWebSocket"; type ItemStatus = "stale" | "valid" | "loading"; @@ -42,50 +45,82 @@ interface AgentMetadataProps { storybookMetadata?: WorkspaceAgentMetadata[]; } +const maxSocketErrorRetryCount = 3; + export const AgentMetadata: FC = ({ agent, storybookMetadata, }) => { - const [metadata, setMetadata] = useState< - WorkspaceAgentMetadata[] | undefined - >(undefined); - + const [activeMetadata, setActiveMetadata] = useState(storybookMetadata); useEffect(() => { + // This is an unfortunate pitfall with this component's testing setup, + // but even though we use the value of storybookMetadata as the initial + // value of the activeMetadata, we cannot put activeMetadata itself into + // the dependency array. If we did, we would destroy and rebuild each + // connection every single time a new message comes in from the socket, + // because the socket has to be wired up to the state setter if (storybookMetadata !== undefined) { - setMetadata(storybookMetadata); return; } - let timeout: ReturnType | undefined = undefined; - - const connect = (): (() => void) => { - const source = watchAgentMetadata(agent.id); + let timeoutId: number | undefined = undefined; + let activeSocket: OneWayWebSocket | null = null; + let retries = 0; + + const createNewConnection = () => { + const socket = watchAgentMetadata(agent.id); + activeSocket = socket; + + socket.addEventListener("error", () => { + setActiveMetadata(undefined); + window.clearTimeout(timeoutId); + + // The error event is supposed to fire when an error happens + // with the connection itself, which implies that the connection + // would auto-close. Couldn't find a definitive answer on MDN, + // though, so closing it manually just to be safe + socket.close(); + activeSocket = null; + + retries++; + if (retries >= maxSocketErrorRetryCount) { + displayError( + "Unexpected disconnect while watching Metadata changes. Please try refreshing the page.", + ); + return; + } - source.onerror = (e) => { - console.error("received error in watch stream", e); - setMetadata(undefined); - source.close(); + displayError( + "Unexpected disconnect while watching Metadata changes. Creating new connection...", + ); + timeoutId = window.setTimeout(() => { + createNewConnection(); + }, 3_000); + }); - timeout = setTimeout(() => { - connect(); - }, 3000); - }; + socket.addEventListener("message", (e) => { + if (e.parseError) { + displayError( + "Unable to process newest response from server. Please try refreshing the page.", + ); + return; + } - source.addEventListener("data", (e) => { - const data = JSON.parse(e.data); - setMetadata(data); - }); - return () => { - if (timeout !== undefined) { - clearTimeout(timeout); + const msg = e.parsedMessage; + if (msg.type === "data") { + setActiveMetadata(msg.data as WorkspaceAgentMetadata[]); } - source.close(); - }; + }); + }; + + createNewConnection(); + return () => { + window.clearTimeout(timeoutId); + activeSocket?.close(); }; - return connect(); }, [agent.id, storybookMetadata]); - if (metadata === undefined) { + if (activeMetadata === undefined) { return (
@@ -93,7 +128,7 @@ export const AgentMetadata: FC = ({ ); } - return ; + return ; }; export const AgentMetadataSkeleton: FC = () => { diff --git a/site/src/modules/templates/useWatchVersionLogs.ts b/site/src/modules/templates/useWatchVersionLogs.ts index 5574e083a9849..1e77b0eb1b073 100644 --- a/site/src/modules/templates/useWatchVersionLogs.ts +++ b/site/src/modules/templates/useWatchVersionLogs.ts @@ -1,46 +1,38 @@ import { watchBuildLogsByTemplateVersionId } from "api/api"; import type { ProvisionerJobLog, TemplateVersion } from "api/typesGenerated"; +import { useEffectEvent } from "hooks/hookPolyfills"; import { useEffect, useState } from "react"; export const useWatchVersionLogs = ( templateVersion: TemplateVersion | undefined, options?: { onDone: () => Promise }, ) => { - const [logs, setLogs] = useState(); + const [logs, setLogs] = useState(); const templateVersionId = templateVersion?.id; - const templateVersionStatus = templateVersion?.job.status; + const [cachedVersionId, setCachedVersionId] = useState(templateVersionId); + if (cachedVersionId !== templateVersionId) { + setCachedVersionId(templateVersionId); + setLogs([]); + } - // biome-ignore lint/correctness/useExhaustiveDependencies: consider refactoring + const stableOnDone = useEffectEvent(() => options?.onDone()); + const status = templateVersion?.job.status; + const canWatch = status === "running" || status === "pending"; useEffect(() => { - setLogs(undefined); - }, [templateVersionId]); - - useEffect(() => { - if (!templateVersionId || !templateVersionStatus) { - return; - } - - if ( - templateVersionStatus !== "running" && - templateVersionStatus !== "pending" - ) { + if (!templateVersionId || !canWatch) { return; } const socket = watchBuildLogsByTemplateVersionId(templateVersionId, { - onMessage: (log) => { - setLogs((logs) => (logs ? [...logs, log] : [log])); - }, - onDone: options?.onDone, - onError: (error) => { - console.error(error); + onError: (error) => console.error(error), + onDone: stableOnDone, + onMessage: (newLog) => { + setLogs((current) => [...(current ?? []), newLog]); }, }); - return () => { - socket.close(); - }; - }, [options?.onDone, templateVersionId, templateVersionStatus]); + return () => socket.close(); + }, [stableOnDone, canWatch, templateVersionId]); return logs; }; diff --git a/site/src/pages/WorkspacePage/WorkspacePage.test.tsx b/site/src/pages/WorkspacePage/WorkspacePage.test.tsx index 50f47a4721320..d120ad5546c17 100644 --- a/site/src/pages/WorkspacePage/WorkspacePage.test.tsx +++ b/site/src/pages/WorkspacePage/WorkspacePage.test.tsx @@ -2,7 +2,7 @@ import { screen, waitFor, within } from "@testing-library/react"; import userEvent from "@testing-library/user-event"; import * as apiModule from "api/api"; import type { TemplateVersionParameter, Workspace } from "api/typesGenerated"; -import EventSourceMock from "eventsourcemock"; +import MockServerSocket from "jest-websocket-mock"; import { DashboardContext, type DashboardProvider, @@ -84,23 +84,11 @@ const testButton = async ( const user = userEvent.setup(); await user.click(button); - expect(actionMock).toBeCalled(); + expect(actionMock).toHaveBeenCalled(); }; -let originalEventSource: typeof window.EventSource; - -beforeAll(() => { - originalEventSource = window.EventSource; - // mocking out EventSource for SSE - window.EventSource = EventSourceMock; -}); - -beforeEach(() => { - jest.resetAllMocks(); -}); - -afterAll(() => { - window.EventSource = originalEventSource; +afterEach(() => { + MockServerSocket.clean(); }); describe("WorkspacePage", () => { diff --git a/site/src/pages/WorkspacePage/WorkspacePage.tsx b/site/src/pages/WorkspacePage/WorkspacePage.tsx index cd2b5f48cb6d3..a55971abfb576 100644 --- a/site/src/pages/WorkspacePage/WorkspacePage.tsx +++ b/site/src/pages/WorkspacePage/WorkspacePage.tsx @@ -5,6 +5,7 @@ import { workspaceBuildsKey } from "api/queries/workspaceBuilds"; import { workspaceByOwnerAndName } from "api/queries/workspaces"; import type { Workspace } from "api/typesGenerated"; import { ErrorAlert } from "components/Alert/ErrorAlert"; +import { displayError } from "components/GlobalSnackbar/utils"; import { Loader } from "components/Loader/Loader"; import { Margins } from "components/Margins/Margins"; import { useEffectEvent } from "hooks/hookPolyfills"; @@ -82,20 +83,26 @@ export const WorkspacePage: FC = () => { return; } - const eventSource = watchWorkspace(workspaceId); + const socket = watchWorkspace(workspaceId); + socket.addEventListener("message", (event) => { + if (event.parseError) { + displayError( + "Unable to process latest data from the server. Please try refreshing the page.", + ); + return; + } - eventSource.addEventListener("data", async (event) => { - const newWorkspaceData = JSON.parse(event.data) as Workspace; - await updateWorkspaceData(newWorkspaceData); + if (event.parsedMessage.type === "data") { + updateWorkspaceData(event.parsedMessage.data as Workspace); + } }); - - eventSource.addEventListener("error", (event) => { - console.error("Error on getting workspace changes.", event); + socket.addEventListener("error", () => { + displayError( + "Unable to get workspace changes. Connection has been closed.", + ); }); - return () => { - eventSource.close(); - }; + return () => socket.close(); }, [updateWorkspaceData, workspaceId]); // Page statuses diff --git a/site/src/utils/OneWayWebSocket.test.ts b/site/src/utils/OneWayWebSocket.test.ts new file mode 100644 index 0000000000000..c6b00b593111f --- /dev/null +++ b/site/src/utils/OneWayWebSocket.test.ts @@ -0,0 +1,492 @@ +/** + * @file Sets up unit tests for OneWayWebSocket. + * + * 2025-03-18 - Really wanted to define these as integration tests with MSW, but + * getting it set up correctly for Jest and JSDOM got a little screwy. That can + * be revisited in the future, but in the meantime, we're assuming that the base + * WebSocket class doesn't have any bugs, and can safely be mocked out. + */ + +import { + type OneWayMessageEvent, + OneWayWebSocket, + type WebSocketEventType, +} from "./OneWayWebSocket"; + +type MockPublisher = Readonly<{ + publishMessage: (event: MessageEvent) => void; + publishError: (event: ErrorEvent) => void; + publishClose: (event: CloseEvent) => void; + publishOpen: (event: Event) => void; +}>; + +function createMockWebSocket( + url: string, + protocols?: string | string[], +): readonly [WebSocket, MockPublisher] { + type EventMap = { + message: MessageEvent; + error: ErrorEvent; + close: CloseEvent; + open: Event; + }; + type CallbackStore = { + [K in keyof EventMap]: ((event: EventMap[K]) => void)[]; + }; + + let activeProtocol: string; + if (Array.isArray(protocols)) { + activeProtocol = protocols[0] ?? ""; + } else if (typeof protocols === "string") { + activeProtocol = protocols; + } else { + activeProtocol = ""; + } + + let closed = false; + const store: CallbackStore = { + message: [], + error: [], + close: [], + open: [], + }; + + const mockSocket: WebSocket = { + CONNECTING: 0, + OPEN: 1, + CLOSING: 2, + CLOSED: 3, + + url, + protocol: activeProtocol, + readyState: 1, + binaryType: "blob", + bufferedAmount: 0, + extensions: "", + onclose: null, + onerror: null, + onmessage: null, + onopen: null, + send: jest.fn(), + dispatchEvent: jest.fn(), + + addEventListener: ( + eventType: E, + callback: WebSocketEventMap[E], + ) => { + if (closed) { + return; + } + + const subscribers = store[eventType]; + const cb = callback as unknown as CallbackStore[E][0]; + if (!subscribers.includes(cb)) { + subscribers.push(cb); + } + }, + + removeEventListener: ( + eventType: E, + callback: WebSocketEventMap[E], + ) => { + if (closed) { + return; + } + + const subscribers = store[eventType]; + const cb = callback as unknown as CallbackStore[E][0]; + if (subscribers.includes(cb)) { + const updated = store[eventType].filter((c) => c !== cb); + store[eventType] = updated as unknown as CallbackStore[E]; + } + }, + + close: () => { + closed = true; + }, + }; + + const publisher: MockPublisher = { + publishOpen: (event) => { + if (closed) { + return; + } + for (const sub of store.open) { + sub(event); + } + }, + + publishError: (event) => { + if (closed) { + return; + } + for (const sub of store.error) { + sub(event); + } + }, + + publishMessage: (event) => { + if (closed) { + return; + } + for (const sub of store.message) { + sub(event); + } + }, + + publishClose: (event) => { + if (closed) { + return; + } + for (const sub of store.close) { + sub(event); + } + }, + }; + + return [mockSocket, publisher] as const; +} + +describe(OneWayWebSocket.name, () => { + const dummyRoute = "/api/v2/blah"; + + it("Errors out if API route does not start with '/api/v2/'", () => { + const testRoutes: string[] = ["blah", "", "/", "/api", "/api/v225"]; + + for (const r of testRoutes) { + expect(() => { + new OneWayWebSocket({ + apiRoute: r, + websocketInit: (url, protocols) => { + const [socket] = createMockWebSocket(url, protocols); + return socket; + }, + }); + }).toThrow(Error); + } + }); + + it("Lets a consumer add an event listener of each type", () => { + let publisher!: MockPublisher; + const oneWay = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket, pub] = createMockWebSocket(url, protocols); + publisher = pub; + return socket; + }, + }); + + const onOpen = jest.fn(); + const onClose = jest.fn(); + const onError = jest.fn(); + const onMessage = jest.fn(); + + oneWay.addEventListener("open", onOpen); + oneWay.addEventListener("close", onClose); + oneWay.addEventListener("error", onError); + oneWay.addEventListener("message", onMessage); + + publisher.publishOpen(new Event("open")); + publisher.publishClose(new CloseEvent("close")); + publisher.publishError( + new ErrorEvent("error", { + error: new Error("Whoops - connection broke"), + }), + ); + publisher.publishMessage( + new MessageEvent("message", { + data: "null", + }), + ); + + expect(onOpen).toHaveBeenCalledTimes(1); + expect(onClose).toHaveBeenCalledTimes(1); + expect(onError).toHaveBeenCalledTimes(1); + expect(onMessage).toHaveBeenCalledTimes(1); + }); + + it("Lets a consumer remove an event listener of each type", () => { + let publisher!: MockPublisher; + const oneWay = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket, pub] = createMockWebSocket(url, protocols); + publisher = pub; + return socket; + }, + }); + + const onOpen = jest.fn(); + const onClose = jest.fn(); + const onError = jest.fn(); + const onMessage = jest.fn(); + + oneWay.addEventListener("open", onOpen); + oneWay.addEventListener("close", onClose); + oneWay.addEventListener("error", onError); + oneWay.addEventListener("message", onMessage); + + oneWay.removeEventListener("open", onOpen); + oneWay.removeEventListener("close", onClose); + oneWay.removeEventListener("error", onError); + oneWay.removeEventListener("message", onMessage); + + publisher.publishOpen(new Event("open")); + publisher.publishClose(new CloseEvent("close")); + publisher.publishError( + new ErrorEvent("error", { + error: new Error("Whoops - connection broke"), + }), + ); + publisher.publishMessage( + new MessageEvent("message", { + data: "null", + }), + ); + + expect(onOpen).toHaveBeenCalledTimes(0); + expect(onClose).toHaveBeenCalledTimes(0); + expect(onError).toHaveBeenCalledTimes(0); + expect(onMessage).toHaveBeenCalledTimes(0); + }); + + it("Only calls each callback once if callback is added multiple times", () => { + let publisher!: MockPublisher; + const oneWay = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket, pub] = createMockWebSocket(url, protocols); + publisher = pub; + return socket; + }, + }); + + const onOpen = jest.fn(); + const onClose = jest.fn(); + const onError = jest.fn(); + const onMessage = jest.fn(); + + for (let i = 0; i < 10; i++) { + oneWay.addEventListener("open", onOpen); + oneWay.addEventListener("close", onClose); + oneWay.addEventListener("error", onError); + oneWay.addEventListener("message", onMessage); + } + + publisher.publishOpen(new Event("open")); + publisher.publishClose(new CloseEvent("close")); + publisher.publishError( + new ErrorEvent("error", { + error: new Error("Whoops - connection broke"), + }), + ); + publisher.publishMessage( + new MessageEvent("message", { + data: "null", + }), + ); + + expect(onOpen).toHaveBeenCalledTimes(1); + expect(onClose).toHaveBeenCalledTimes(1); + expect(onError).toHaveBeenCalledTimes(1); + expect(onMessage).toHaveBeenCalledTimes(1); + }); + + it("Lets consumers register multiple callbacks for each event type", () => { + let publisher!: MockPublisher; + const oneWay = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket, pub] = createMockWebSocket(url, protocols); + publisher = pub; + return socket; + }, + }); + + const onOpen1 = jest.fn(); + const onClose1 = jest.fn(); + const onError1 = jest.fn(); + const onMessage1 = jest.fn(); + oneWay.addEventListener("open", onOpen1); + oneWay.addEventListener("close", onClose1); + oneWay.addEventListener("error", onError1); + oneWay.addEventListener("message", onMessage1); + + const onOpen2 = jest.fn(); + const onClose2 = jest.fn(); + const onError2 = jest.fn(); + const onMessage2 = jest.fn(); + oneWay.addEventListener("open", onOpen2); + oneWay.addEventListener("close", onClose2); + oneWay.addEventListener("error", onError2); + oneWay.addEventListener("message", onMessage2); + + publisher.publishOpen(new Event("open")); + publisher.publishClose(new CloseEvent("close")); + publisher.publishError( + new ErrorEvent("error", { + error: new Error("Whoops - connection broke"), + }), + ); + publisher.publishMessage( + new MessageEvent("message", { + data: "null", + }), + ); + + expect(onOpen1).toHaveBeenCalledTimes(1); + expect(onClose1).toHaveBeenCalledTimes(1); + expect(onError1).toHaveBeenCalledTimes(1); + expect(onMessage1).toHaveBeenCalledTimes(1); + + expect(onOpen2).toHaveBeenCalledTimes(1); + expect(onClose2).toHaveBeenCalledTimes(1); + expect(onError2).toHaveBeenCalledTimes(1); + expect(onMessage2).toHaveBeenCalledTimes(1); + }); + + it("Computes the socket protocol based on the browser location protocol", () => { + const oneWay1 = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket] = createMockWebSocket(url, protocols); + return socket; + }, + location: { + protocol: "https:", + host: "www.cool.com", + }, + }); + const oneWay2 = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket] = createMockWebSocket(url, protocols); + return socket; + }, + location: { + protocol: "http:", + host: "www.cool.com", + }, + }); + + expect(oneWay1.url).toMatch(/^wss:\/\//); + expect(oneWay2.url).toMatch(/^ws:\/\//); + }); + + it("Gives consumers pre-parsed versions of message events", () => { + let publisher!: MockPublisher; + const oneWay = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket, pub] = createMockWebSocket(url, protocols); + publisher = pub; + return socket; + }, + }); + + const onMessage = jest.fn(); + oneWay.addEventListener("message", onMessage); + + const payload = { + value: 5, + cool: "yes", + }; + const event = new MessageEvent("message", { + data: JSON.stringify(payload), + }); + + publisher.publishMessage(event); + expect(onMessage).toHaveBeenCalledWith({ + sourceEvent: event, + parsedMessage: payload, + parseError: undefined, + }); + }); + + it("Exposes parsing error if message payload could not be parsed as JSON", () => { + let publisher!: MockPublisher; + const oneWay = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket, pub] = createMockWebSocket(url, protocols); + publisher = pub; + return socket; + }, + }); + + const onMessage = jest.fn(); + oneWay.addEventListener("message", onMessage); + + const payload = "definitely not valid JSON"; + const event = new MessageEvent("message", { + data: payload, + }); + publisher.publishMessage(event); + + const arg: OneWayMessageEvent = onMessage.mock.lastCall[0]; + expect(arg.sourceEvent).toEqual(event); + expect(arg.parsedMessage).toEqual(undefined); + expect(arg.parseError).toBeInstanceOf(Error); + }); + + it("Passes all search param values through Websocket URL", () => { + const input1: Record = { + cool: "yeah", + yeah: "cool", + blah: "5", + }; + const oneWay1 = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket] = createMockWebSocket(url, protocols); + return socket; + }, + searchParams: input1, + location: { + protocol: "https:", + host: "www.blah.com", + }, + }); + let [base, params] = oneWay1.url.split("?"); + expect(base).toBe("wss://www.blah.com/api/v2/blah"); + for (const [key, value] of Object.entries(input1)) { + expect(params).toContain(`${key}=${value}`); + } + + const input2 = new URLSearchParams(input1); + const oneWay2 = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket] = createMockWebSocket(url, protocols); + return socket; + }, + searchParams: input2, + location: { + protocol: "https:", + host: "www.blah.com", + }, + }); + [base, params] = oneWay2.url.split("?"); + expect(base).toBe("wss://www.blah.com/api/v2/blah"); + for (const [key, value] of Object.entries(input2)) { + expect(params).toContain(`${key}=${value}`); + } + + const oneWay3 = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket] = createMockWebSocket(url, protocols); + return socket; + }, + searchParams: undefined, + location: { + protocol: "https:", + host: "www.blah.com", + }, + }); + [base, params] = oneWay3.url.split("?"); + expect(base).toBe("wss://www.blah.com/api/v2/blah"); + expect(params).toBe(undefined); + }); +}); diff --git a/site/src/utils/OneWayWebSocket.ts b/site/src/utils/OneWayWebSocket.ts new file mode 100644 index 0000000000000..94ed1f1efc868 --- /dev/null +++ b/site/src/utils/OneWayWebSocket.ts @@ -0,0 +1,198 @@ +/** + * @file A wrapper over WebSockets that (1) enforces one-way communication, and + * (2) supports automatically parsing JSON messages as they come in. + * + * This should ALWAYS be favored in favor of using Server-Sent Events and the + * built-in EventSource class for doing one-way communication. SSEs have a hard + * limitation on HTTP/1.1 and below where there is a maximum number of 6 ports + * that can ever be used for a domain (sometimes less depending on the browser). + * Not only is this limit shared with short-lived REST requests, but it also + * applies across tabs and windows. So if a user opens Coder in multiple tabs, + * there is a very real possibility that parts of the app will start to lock up + * without it being clear why. + * + * WebSockets do not have this limitation, even on HTTP/1.1 – all modern + * browsers implement at least some degree of multiplexing for them. + */ + +// Not bothering with trying to borrow methods from the base WebSocket type +// because it's already a mess of inheritance and generics, and we're going to +// have to add a few more +export type WebSocketEventType = "close" | "error" | "message" | "open"; + +export type OneWayMessageEvent = Readonly< + | { + sourceEvent: MessageEvent; + parsedMessage: TData; + parseError: undefined; + } + | { + sourceEvent: MessageEvent; + parsedMessage: undefined; + parseError: Error; + } +>; + +type OneWayEventPayloadMap = { + close: CloseEvent; + error: Event; + message: OneWayMessageEvent; + open: Event; +}; + +type WebSocketMessageCallback = (payload: MessageEvent) => void; + +type OneWayEventCallback = ( + payload: OneWayEventPayloadMap[TEvent], +) => void; + +interface OneWayWebSocketApi { + get url(): string; + + addEventListener: ( + eventType: TEvent, + callback: OneWayEventCallback, + ) => void; + + removeEventListener: ( + eventType: TEvent, + callback: OneWayEventCallback, + ) => void; + + close: (closeCode?: number, reason?: string) => void; +} + +type OneWayWebSocketInit = Readonly<{ + apiRoute: string; + serverProtocols?: string | string[]; + searchParams?: Record | URLSearchParams; + binaryType?: BinaryType; + websocketInit?: (url: string, protocols?: string | string[]) => WebSocket; + location?: Readonly<{ + protocol: string; + host: string; + }>; +}>; + +function defaultInit(url: string, protocols?: string | string[]): WebSocket { + return new WebSocket(url, protocols); +} + +export class OneWayWebSocket + implements OneWayWebSocketApi +{ + readonly #socket: WebSocket; + readonly #messageCallbackWrappers = new Map< + OneWayEventCallback, + WebSocketMessageCallback + >(); + + constructor(init: OneWayWebSocketInit) { + const { + apiRoute, + searchParams, + serverProtocols, + binaryType = "blob", + location = window.location, + websocketInit = defaultInit, + } = init; + + if (!apiRoute.startsWith("/api/v2/")) { + throw new Error(`API route '${apiRoute}' does not begin with a slash`); + } + + const formattedParams = + searchParams instanceof URLSearchParams + ? searchParams + : new URLSearchParams(searchParams); + const paramsString = formattedParams.toString(); + const paramsSuffix = paramsString ? `?${paramsString}` : ""; + const wsProtocol = location.protocol === "https:" ? "wss:" : "ws:"; + const url = `${wsProtocol}//${location.host}${apiRoute}${paramsSuffix}`; + + this.#socket = websocketInit(url, serverProtocols); + this.#socket.binaryType = binaryType; + } + + get url(): string { + return this.#socket.url; + } + + addEventListener( + event: TEvent, + callback: OneWayEventCallback, + ): void { + // Not happy about all the type assertions, but there are some nasty + // type contravariance issues if you try to resolve the function types + // properly. This is actually the lesser of two evils + const looseCallback = callback as OneWayEventCallback< + TData, + WebSocketEventType + >; + + if (this.#messageCallbackWrappers.has(looseCallback)) { + return; + } + if (event !== "message") { + this.#socket.addEventListener(event, looseCallback); + return; + } + + const wrapped = (event: MessageEvent): void => { + const messageCallback = looseCallback as OneWayEventCallback< + TData, + "message" + >; + + try { + const message = JSON.parse(event.data) as TData; + messageCallback({ + sourceEvent: event, + parseError: undefined, + parsedMessage: message, + }); + } catch (err) { + messageCallback({ + sourceEvent: event, + parseError: err as Error, + parsedMessage: undefined, + }); + } + }; + + this.#socket.addEventListener(event as "message", wrapped); + this.#messageCallbackWrappers.set(looseCallback, wrapped); + } + + removeEventListener( + event: TEvent, + callback: OneWayEventCallback, + ): void { + const looseCallback = callback as OneWayEventCallback< + TData, + WebSocketEventType + >; + + if (event !== "message") { + this.#socket.removeEventListener(event, looseCallback); + return; + } + if (!this.#messageCallbackWrappers.has(looseCallback)) { + return; + } + + const wrapper = this.#messageCallbackWrappers.get(looseCallback); + if (wrapper === undefined) { + throw new Error( + `Cannot unregister callback for event ${event}. This is likely an issue with the browser itself.`, + ); + } + + this.#socket.removeEventListener(event as "message", wrapper); + this.#messageCallbackWrappers.delete(looseCallback); + } + + close(closeCode?: number, reason?: string): void { + this.#socket.close(closeCode, reason); + } +} pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy