Skip to content

Commit 0d455cd

Browse files
committed
fix: Use sync.WaitGroup to await hijacked HTTP connections
WebSockets hijack the HTTP connection from the server, causing server.Close() to not wait for these connections to fully cleanup. This adds a global wait-group to the coderd API, which ensures all WebSocket HTTP handlers have properly exited before returning.
1 parent 8f843d2 commit 0d455cd

File tree

4 files changed

+23
-13
lines changed

4 files changed

+23
-13
lines changed

coderd/cmd/root.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func Root() *cobra.Command {
3333
Use: "coderd",
3434
RunE: func(cmd *cobra.Command, args []string) error {
3535
logger := slog.Make(sloghuman.Sink(os.Stderr))
36-
handler := coderd.New(&coderd.Options{
36+
handler, closeCoderd := coderd.New(&coderd.Options{
3737
Logger: logger,
3838
Database: databasefake.New(),
3939
Pubsub: database.NewPubsubInMemory(),
@@ -49,18 +49,19 @@ func Root() *cobra.Command {
4949
Scheme: "http",
5050
Host: address,
5151
})
52-
closer, err := newProvisionerDaemon(cmd.Context(), client, logger)
52+
daemonClose, err := newProvisionerDaemon(cmd.Context(), client, logger)
5353
if err != nil {
5454
return xerrors.Errorf("create provisioner daemon: %w", err)
5555
}
56-
defer closer.Close()
56+
defer daemonClose.Close()
5757

5858
errCh := make(chan error)
5959
go func() {
6060
defer close(errCh)
6161
errCh <- http.Serve(listener, handler)
6262
}()
6363

64+
closeCoderd()
6465
select {
6566
case <-cmd.Context().Done():
6667
return cmd.Context().Err()

coderd/coderd.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package coderd
22

33
import (
44
"net/http"
5+
"sync"
56

67
"github.com/go-chi/chi/v5"
78

@@ -20,11 +21,12 @@ type Options struct {
2021
}
2122

2223
// New constructs the Coder API into an HTTP handler.
23-
func New(options *Options) http.Handler {
24+
//
25+
// A wait function is returned to handle awaiting closure
26+
// of hijacked HTTP requests.
27+
func New(options *Options) (http.Handler, func()) {
2428
api := &api{
25-
Database: options.Database,
26-
Logger: options.Logger,
27-
Pubsub: options.Pubsub,
29+
Options: options,
2830
}
2931

3032
r := chi.NewRouter()
@@ -144,13 +146,13 @@ func New(options *Options) http.Handler {
144146
})
145147
})
146148
r.NotFound(site.Handler(options.Logger).ServeHTTP)
147-
return r
149+
return r, api.websocketWaitGroup.Wait
148150
}
149151

150152
// API contains all route handlers. Only HTTP handlers should
151153
// be added to this struct for code clarity.
152154
type api struct {
153-
Database database.Store
154-
Logger slog.Logger
155-
Pubsub database.Pubsub
155+
*Options
156+
157+
websocketWaitGroup sync.WaitGroup
156158
}

coderd/coderdtest/coderdtest.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func New(t *testing.T) *codersdk.Client {
5555
})
5656
}
5757

58-
handler := coderd.New(&coderd.Options{
58+
handler, closeWait := coderd.New(&coderd.Options{
5959
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
6060
Database: db,
6161
Pubsub: pubsub,
@@ -69,7 +69,10 @@ func New(t *testing.T) *codersdk.Client {
6969
srv.Start()
7070
serverURL, err := url.Parse(srv.URL)
7171
require.NoError(t, err)
72-
t.Cleanup(srv.Close)
72+
t.Cleanup(func() {
73+
srv.Close()
74+
closeWait()
75+
})
7376

7477
return codersdk.New(serverURL)
7578
}

coderd/provisionerdaemons.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ func (api *api) provisionerDaemonsServe(rw http.ResponseWriter, r *http.Request)
6262
})
6363
return
6464
}
65+
api.websocketWaitGroup.Add(1)
66+
defer api.websocketWaitGroup.Done()
6567

6668
daemon, err := api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{
6769
ID: uuid.New(),
@@ -100,7 +102,9 @@ func (api *api) provisionerDaemonsServe(rw http.ResponseWriter, r *http.Request)
100102
err = server.Serve(r.Context(), session)
101103
if err != nil {
102104
_ = conn.Close(websocket.StatusInternalError, fmt.Sprintf("serve: %s", err))
105+
return
103106
}
107+
_ = conn.Close(websocket.StatusGoingAway, "")
104108
}
105109

106110
// The input for a "workspace_provision" job.

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy