From 0d455cd0e18f1204655251c3e5b7f6306306b4d9 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Sun, 20 Feb 2022 19:50:57 +0000 Subject: [PATCH] fix: Use sync.WaitGroup to await hijacked HTTP connections WebSockets hijack the HTTP connection from the server, causing server.Close() to not wait for these connections to fully cleanup. This adds a global wait-group to the coderd API, which ensures all WebSocket HTTP handlers have properly exited before returning. --- coderd/cmd/root.go | 7 ++++--- coderd/coderd.go | 18 ++++++++++-------- coderd/coderdtest/coderdtest.go | 7 +++++-- coderd/provisionerdaemons.go | 4 ++++ 4 files changed, 23 insertions(+), 13 deletions(-) diff --git a/coderd/cmd/root.go b/coderd/cmd/root.go index 9087ac435b715..2778cb320a17e 100644 --- a/coderd/cmd/root.go +++ b/coderd/cmd/root.go @@ -33,7 +33,7 @@ func Root() *cobra.Command { Use: "coderd", RunE: func(cmd *cobra.Command, args []string) error { logger := slog.Make(sloghuman.Sink(os.Stderr)) - handler := coderd.New(&coderd.Options{ + handler, closeCoderd := coderd.New(&coderd.Options{ Logger: logger, Database: databasefake.New(), Pubsub: database.NewPubsubInMemory(), @@ -49,11 +49,11 @@ func Root() *cobra.Command { Scheme: "http", Host: address, }) - closer, err := newProvisionerDaemon(cmd.Context(), client, logger) + daemonClose, err := newProvisionerDaemon(cmd.Context(), client, logger) if err != nil { return xerrors.Errorf("create provisioner daemon: %w", err) } - defer closer.Close() + defer daemonClose.Close() errCh := make(chan error) go func() { @@ -61,6 +61,7 @@ func Root() *cobra.Command { errCh <- http.Serve(listener, handler) }() + closeCoderd() select { case <-cmd.Context().Done(): return cmd.Context().Err() diff --git a/coderd/coderd.go b/coderd/coderd.go index 1213b04aa0a86..765e4b0f1951c 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -2,6 +2,7 @@ package coderd import ( "net/http" + "sync" "github.com/go-chi/chi/v5" @@ -20,11 +21,12 @@ type Options struct { } // New constructs the Coder API into an HTTP handler. -func New(options *Options) http.Handler { +// +// A wait function is returned to handle awaiting closure +// of hijacked HTTP requests. +func New(options *Options) (http.Handler, func()) { api := &api{ - Database: options.Database, - Logger: options.Logger, - Pubsub: options.Pubsub, + Options: options, } r := chi.NewRouter() @@ -144,13 +146,13 @@ func New(options *Options) http.Handler { }) }) r.NotFound(site.Handler(options.Logger).ServeHTTP) - return r + return r, api.websocketWaitGroup.Wait } // API contains all route handlers. Only HTTP handlers should // be added to this struct for code clarity. type api struct { - Database database.Store - Logger slog.Logger - Pubsub database.Pubsub + *Options + + websocketWaitGroup sync.WaitGroup } diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 40eba2fb53942..ff05abb63a78a 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -55,7 +55,7 @@ func New(t *testing.T) *codersdk.Client { }) } - handler := coderd.New(&coderd.Options{ + handler, closeWait := coderd.New(&coderd.Options{ Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), Database: db, Pubsub: pubsub, @@ -69,7 +69,10 @@ func New(t *testing.T) *codersdk.Client { srv.Start() serverURL, err := url.Parse(srv.URL) require.NoError(t, err) - t.Cleanup(srv.Close) + t.Cleanup(func() { + srv.Close() + closeWait() + }) return codersdk.New(serverURL) } diff --git a/coderd/provisionerdaemons.go b/coderd/provisionerdaemons.go index a75730dc0a876..d95e6e62c0f9d 100644 --- a/coderd/provisionerdaemons.go +++ b/coderd/provisionerdaemons.go @@ -62,6 +62,8 @@ func (api *api) provisionerDaemonsServe(rw http.ResponseWriter, r *http.Request) }) return } + api.websocketWaitGroup.Add(1) + defer api.websocketWaitGroup.Done() daemon, err := api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{ ID: uuid.New(), @@ -100,7 +102,9 @@ func (api *api) provisionerDaemonsServe(rw http.ResponseWriter, r *http.Request) err = server.Serve(r.Context(), session) if err != nil { _ = conn.Close(websocket.StatusInternalError, fmt.Sprintf("serve: %s", err)) + return } + _ = conn.Close(websocket.StatusGoingAway, "") } // The input for a "workspace_provision" job. pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy