diff --git a/cli/agent.go b/cli/agent.go index af54bfc969bce..b900db48d0c2d 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -125,7 +125,7 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd { args := append(os.Args, "--no-reap") err := reaper.ForkReap( reaper.WithExecArgs(args...), - reaper.WithCatchSignals(InterruptSignals...), + reaper.WithCatchSignals(StopSignals...), ) if err != nil { logger.Error(ctx, "agent process reaper unable to fork", slog.Error(err)) @@ -144,7 +144,7 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd { // Note that we don't want to handle these signals in the // process that runs as PID 1, that's why we do this after // the reaper forked. - ctx, stopNotify := inv.SignalNotifyContext(ctx, InterruptSignals...) + ctx, stopNotify := inv.SignalNotifyContext(ctx, StopSignals...) defer stopNotify() // DumpHandler does signal handling, so we call it after the diff --git a/cli/exp_scaletest.go b/cli/exp_scaletest.go index fc8f503f3a30a..fe93fe26cd8fb 100644 --- a/cli/exp_scaletest.go +++ b/cli/exp_scaletest.go @@ -890,7 +890,7 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *clibase.Cmd { Handler: func(inv *clibase.Invocation) (err error) { ctx := inv.Context() - notifyCtx, stop := signal.NotifyContext(ctx, InterruptSignals...) // Checked later. + notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...) // Checked later. defer stop() ctx = notifyCtx diff --git a/cli/externalauth.go b/cli/externalauth.go index 675d795491346..52b897b64b971 100644 --- a/cli/externalauth.go +++ b/cli/externalauth.go @@ -65,7 +65,7 @@ fi Handler: func(inv *clibase.Invocation) error { ctx := inv.Context() - ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...) + ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...) 
defer stop() client, err := r.createAgentClient() diff --git a/cli/gitaskpass.go b/cli/gitaskpass.go index ddfd05af9d1f9..30e82ab90aced 100644 --- a/cli/gitaskpass.go +++ b/cli/gitaskpass.go @@ -25,7 +25,7 @@ func (r *RootCmd) gitAskpass() *clibase.Cmd { Handler: func(inv *clibase.Invocation) error { ctx := inv.Context() - ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...) + ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...) defer stop() user, host, err := gitauth.ParseAskpass(inv.Args[0]) diff --git a/cli/gitssh.go b/cli/gitssh.go index b627b3911b820..479ec094f0c5a 100644 --- a/cli/gitssh.go +++ b/cli/gitssh.go @@ -29,7 +29,7 @@ func (r *RootCmd) gitssh() *clibase.Cmd { // Catch interrupt signals to ensure the temporary private // key file is cleaned up on most cases. - ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...) + ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...) defer stop() // Early check so errors are reported immediately. diff --git a/cli/server.go b/cli/server.go index e02a891022e02..937f290aff13e 100644 --- a/cli/server.go +++ b/cli/server.go @@ -337,7 +337,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. // Register signals early on so that graceful shutdown can't // be interrupted by additional signals. Note that we avoid - // shadowing cancel() (from above) here because notifyStop() + // shadowing cancel() (from above) here because stopCancel() // restores default behavior for the signals. This protects // the shutdown sequence from abruptly terminating things // like: database migrations, provisioner work, workspace @@ -345,8 +345,10 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. // // To get out of a graceful shutdown, the user can send // SIGQUIT with ctrl+\ or SIGKILL with `kill -9`. - notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, InterruptSignals...) 
- defer notifyStop() + stopCtx, stopCancel := signalNotifyContext(ctx, inv, StopSignalsNoInterrupt...) + defer stopCancel() + interruptCtx, interruptCancel := signalNotifyContext(ctx, inv, InterruptSignals...) + defer interruptCancel() cacheDir := vals.CacheDir.String() err = os.MkdirAll(cacheDir, 0o700) @@ -1028,13 +1030,18 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. hangDetector.Start() defer hangDetector.Close() + waitForProvisionerJobs := false // Currently there is no way to ask the server to shut // itself down, so any exit signal will result in a non-zero // exit of the server. var exitErr error select { - case <-notifyCtx.Done(): - exitErr = notifyCtx.Err() + case <-stopCtx.Done(): + exitErr = stopCtx.Err() + waitForProvisionerJobs = true + _, _ = io.WriteString(inv.Stdout, cliui.Bold("Stop caught, waiting for provisioner jobs to complete and gracefully exiting. Use ctrl+\\ to force quit")) + case <-interruptCtx.Done(): + exitErr = interruptCtx.Err() _, _ = io.WriteString(inv.Stdout, cliui.Bold("Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit")) case <-tunnelDone: exitErr = xerrors.New("dev tunnel closed unexpectedly") @@ -1082,7 +1089,16 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. defer wg.Done() r.Verbosef(inv, "Shutting down provisioner daemon %d...", id) - err := shutdownWithTimeout(provisionerDaemon.Shutdown, 5*time.Second) + timeout := 5 * time.Second + if waitForProvisionerJobs { + // It can last for a long time... + timeout = 30 * time.Minute + } + + err := shutdownWithTimeout(func(ctx context.Context) error { + // We only want to cancel active jobs if we aren't exiting gracefully. 
+ return provisionerDaemon.Shutdown(ctx, !waitForProvisionerJobs) + }, timeout) if err != nil { cliui.Errorf(inv.Stderr, "Failed to shut down provisioner daemon %d: %s\n", id, err) return @@ -2512,3 +2528,12 @@ func escapePostgresURLUserInfo(v string) (string, error) { return v, nil } + +func signalNotifyContext(ctx context.Context, inv *clibase.Invocation, sig ...os.Signal) (context.Context, context.CancelFunc) { + // On Windows, some of our signal functions lack support. + // If we pass in no signals, we should just return the context as-is. + if len(sig) == 0 { + return context.WithCancel(ctx) + } + return inv.SignalNotifyContext(ctx, sig...) +} diff --git a/cli/server_createadminuser.go b/cli/server_createadminuser.go index 7491afac3c3f8..43f78ea784fc8 100644 --- a/cli/server_createadminuser.go +++ b/cli/server_createadminuser.go @@ -47,7 +47,7 @@ func (r *RootCmd) newCreateAdminUserCommand() *clibase.Cmd { logger = logger.Leveled(slog.LevelDebug) } - ctx, cancel := inv.SignalNotifyContext(ctx, InterruptSignals...) + ctx, cancel := inv.SignalNotifyContext(ctx, StopSignals...) defer cancel() if newUserDBURL == "" { diff --git a/cli/server_test.go b/cli/server_test.go index 9699e8a48e8f7..4ce4d2b5f583c 100644 --- a/cli/server_test.go +++ b/cli/server_test.go @@ -21,6 +21,7 @@ import ( "net/url" "os" "path/filepath" + "reflect" "runtime" "strconv" "strings" @@ -1605,7 +1606,7 @@ func TestServer_Production(t *testing.T) { } //nolint:tparallel,paralleltest // This test cannot be run in parallel due to signal handling. 
-func TestServer_Shutdown(t *testing.T) { +func TestServer_InterruptShutdown(t *testing.T) { t.Skip("This test issues an interrupt signal which will propagate to the test runner.") if runtime.GOOS == "windows" { @@ -1638,6 +1639,46 @@ func TestServer_Shutdown(t *testing.T) { require.NoError(t, err) } +func TestServer_GracefulShutdown(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + // Sending interrupt signal isn't supported on Windows! + t.SkipNow() + } + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + root, cfg := clitest.New(t, + "server", + "--in-memory", + "--http-address", ":0", + "--access-url", "http://example.com", + "--provisioner-daemons", "1", + "--cache-dir", t.TempDir(), + ) + var stopFunc context.CancelFunc + root = root.WithTestSignalNotifyContext(t, func(parent context.Context, signals ...os.Signal) (context.Context, context.CancelFunc) { + if !reflect.DeepEqual(cli.StopSignalsNoInterrupt, signals) { + return context.WithCancel(ctx) + } + var ctx context.Context + ctx, stopFunc = context.WithCancel(parent) + return ctx, stopFunc + }) + serverErr := make(chan error, 1) + pty := ptytest.New(t).Attach(root) + go func() { + serverErr <- root.WithContext(ctx).Run() + }() + _ = waitAccessURL(t, cfg) + // It's fair to assume `stopFunc` isn't nil here, because the server + // has started and access URL is propagated. + stopFunc() + pty.ExpectMatch("waiting for provisioner jobs to complete") + err := <-serverErr + require.NoError(t, err) +} + func BenchmarkServerHelp(b *testing.B) { // server --help is a good proxy for measuring the // constant overhead of each command. diff --git a/cli/signal_unix.go b/cli/signal_unix.go index 05d619c0232e4..9cb6f3f899954 100644 --- a/cli/signal_unix.go +++ b/cli/signal_unix.go @@ -7,8 +7,23 @@ import ( "syscall" ) -var InterruptSignals = []os.Signal{ +// StopSignals is the list of signals that are used for handling +// shutdown behavior. 
+var StopSignals = []os.Signal{ os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, } + +// StopSignalsNoInterrupt is the list of signals that are used for handling +// graceful shutdown behavior. +var StopSignalsNoInterrupt = []os.Signal{ + syscall.SIGTERM, + syscall.SIGHUP, +} + +// InterruptSignals is the list of signals that are used for handling +// immediate shutdown behavior. +var InterruptSignals = []os.Signal{ + os.Interrupt, +} diff --git a/cli/signal_windows.go b/cli/signal_windows.go index 3624415a6452f..8d9b8518e615e 100644 --- a/cli/signal_windows.go +++ b/cli/signal_windows.go @@ -6,4 +6,12 @@ import ( "os" ) -var InterruptSignals = []os.Signal{os.Interrupt} +var StopSignals = []os.Signal{ + os.Interrupt, +} + +var StopSignalsNoInterrupt = []os.Signal{} + +var InterruptSignals = []os.Signal{ + os.Interrupt, +} diff --git a/cli/ssh.go b/cli/ssh.go index 21437ee6aea14..023c5307da01a 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -73,7 +73,7 @@ func (r *RootCmd) ssh() *clibase.Cmd { // session can persist for up to 72 hours, since we set a long // timeout on the Agent side of the connection. In particular, // OpenSSH sends SIGHUP to terminate a proxy command. - ctx, stop := inv.SignalNotifyContext(inv.Context(), InterruptSignals...) + ctx, stop := inv.SignalNotifyContext(inv.Context(), StopSignals...) defer stop() ctx, cancel := context.WithCancel(ctx) defer cancel() diff --git a/cli/templatepull_test.go b/cli/templatepull_test.go index ec7beb619606e..1b1d51b0ccd02 100644 --- a/cli/templatepull_test.go +++ b/cli/templatepull_test.go @@ -328,7 +328,7 @@ func TestTemplatePull_ToDir(t *testing.T) { require.NoError(t, inv.Run()) - // Validate behaviour of choosing template name in the absence of an output path argument. + // Validate behavior of choosing template name in the absence of an output path argument. 
destPath := actualDest if destPath == "" { destPath = template.Name diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 85d92a5ef6627..4d315c3e2b058 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -498,7 +498,7 @@ func (c *provisionerdCloser) Close() error { c.closed = true ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - shutdownErr := c.d.Shutdown(ctx) + shutdownErr := c.d.Shutdown(ctx, true) closeErr := c.d.Close() if shutdownErr != nil { return shutdownErr diff --git a/enterprise/cli/provisionerdaemons.go b/enterprise/cli/provisionerdaemons.go index 5943758a7d743..6f356e541d531 100644 --- a/enterprise/cli/provisionerdaemons.go +++ b/enterprise/cli/provisionerdaemons.go @@ -88,8 +88,10 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd { ctx, cancel := context.WithCancel(inv.Context()) defer cancel() - notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, agpl.InterruptSignals...) - defer notifyStop() + stopCtx, stopCancel := inv.SignalNotifyContext(ctx, agpl.StopSignalsNoInterrupt...) + defer stopCancel() + interruptCtx, interruptCancel := inv.SignalNotifyContext(ctx, agpl.InterruptSignals...) + defer interruptCancel() tags, err := agpl.ParseProvisionerTags(rawTags) if err != nil { @@ -212,10 +214,17 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd { Metrics: metrics, }) + waitForProvisionerJobs := false var exitErr error select { - case <-notifyCtx.Done(): - exitErr = notifyCtx.Err() + case <-stopCtx.Done(): + exitErr = stopCtx.Err() + _, _ = fmt.Fprintln(inv.Stdout, cliui.Bold( + "Stop caught, waiting for provisioner jobs to complete and gracefully exiting. Use ctrl+\\ to force quit", + )) + waitForProvisionerJobs = true + case <-interruptCtx.Done(): + exitErr = interruptCtx.Err() _, _ = fmt.Fprintln(inv.Stdout, cliui.Bold( "Interrupt caught, gracefully exiting. 
Use ctrl+\\ to force quit", )) @@ -225,7 +234,7 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd { cliui.Errorf(inv.Stderr, "Unexpected error, shutting down server: %s\n", exitErr) } - err = srv.Shutdown(ctx) + err = srv.Shutdown(ctx, waitForProvisionerJobs) if err != nil { return xerrors.Errorf("shutdown: %w", err) } diff --git a/enterprise/cli/proxyserver.go b/enterprise/cli/proxyserver.go index a31d2fe829b80..68ec04b9666e4 100644 --- a/enterprise/cli/proxyserver.go +++ b/enterprise/cli/proxyserver.go @@ -142,7 +142,7 @@ func (r *RootCmd) proxyServer() *clibase.Cmd { // // To get out of a graceful shutdown, the user can send // SIGQUIT with ctrl+\ or SIGKILL with `kill -9`. - notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, cli.InterruptSignals...) + notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, cli.StopSignals...) defer notifyStop() // Clean up idle connections at the end, e.g. diff --git a/enterprise/coderd/provisionerdaemons_test.go b/enterprise/coderd/provisionerdaemons_test.go index caa65c885077b..c62a91593d80f 100644 --- a/enterprise/coderd/provisionerdaemons_test.go +++ b/enterprise/coderd/provisionerdaemons_test.go @@ -441,7 +441,7 @@ func TestProvisionerDaemonServe(t *testing.T) { build := coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status) - err = pd.Shutdown(ctx) + err = pd.Shutdown(ctx, false) require.NoError(t, err) err = terraformServer.Close() require.NoError(t, err) diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index 52414db4afade..3e49648700f2f 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -474,15 +474,18 @@ func (p *Server) isClosed() bool { } } -// Shutdown triggers a graceful exit of each registered provisioner. -func (p *Server) Shutdown(ctx context.Context) error { +// Shutdown gracefully exits with the option to cancel the active job. 
+// If false, it will wait for the job to complete. +// +//nolint:revive +func (p *Server) Shutdown(ctx context.Context, cancelActiveJob bool) error { p.mutex.Lock() p.opts.Logger.Info(ctx, "attempting graceful shutdown") if !p.shuttingDownB { close(p.shuttingDownCh) p.shuttingDownB = true } - if p.activeJob != nil { + if cancelActiveJob && p.activeJob != nil { p.activeJob.Cancel() } p.mutex.Unlock() diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go index a04196e6b4a65..2031fa6c3939e 100644 --- a/provisionerd/provisionerd_test.go +++ b/provisionerd/provisionerd_test.go @@ -671,7 +671,7 @@ func TestProvisionerd(t *testing.T) { }), }) require.Condition(t, closedWithin(updateChan, testutil.WaitShort)) - err := server.Shutdown(context.Background()) + err := server.Shutdown(context.Background(), true) require.NoError(t, err) require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) require.NoError(t, server.Close()) @@ -762,7 +762,7 @@ func TestProvisionerd(t *testing.T) { require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - require.NoError(t, server.Shutdown(ctx)) + require.NoError(t, server.Shutdown(ctx, true)) require.NoError(t, server.Close()) }) @@ -853,7 +853,7 @@ func TestProvisionerd(t *testing.T) { require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - require.NoError(t, server.Shutdown(ctx)) + require.NoError(t, server.Shutdown(ctx, true)) require.NoError(t, server.Close()) }) @@ -944,7 +944,7 @@ func TestProvisionerd(t *testing.T) { t.Log("completeChan closed") ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - require.NoError(t, server.Shutdown(ctx)) + require.NoError(t, server.Shutdown(ctx, true)) require.NoError(t, server.Close()) }) @@ -1039,7 +1039,7 @@ 
func TestProvisionerd(t *testing.T) { require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - require.NoError(t, server.Shutdown(ctx)) + require.NoError(t, server.Shutdown(ctx, true)) require.NoError(t, server.Close()) assert.Equal(t, ops[len(ops)-1], "CompleteJob") assert.Contains(t, ops[0:len(ops)-1], "Log: Cleaning Up | ") @@ -1076,7 +1076,7 @@ func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, connector prov t.Cleanup(func() { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - _ = server.Shutdown(ctx) + _ = server.Shutdown(ctx, true) _ = server.Close() }) return server
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: