diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go
index baccfe66a7fd7..857cdafe94e79 100644
--- a/enterprise/tailnet/pgcoord.go
+++ b/enterprise/tailnet/pgcoord.go
@@ -161,11 +161,12 @@ func newPGCoordInternal(
 		closed: make(chan struct{}),
 	}
 	go func() {
-		// when the main context is canceled, or the coordinator closed, the binder and tunneler
-		// always eventually stop. Once they stop it's safe to cancel the querier context, which
+		// when the main context is canceled, or the coordinator closed, the binder, tunneler, and
+		// handshaker always eventually stop. Once they stop it's safe to cancel the querier context, which
 		// has the effect of deleting the coordinator from the database and ceasing heartbeats.
 		c.binder.workerWG.Wait()
 		c.tunneler.workerWG.Wait()
+		c.handshaker.workerWG.Wait()
 		querierCancel()
 	}()
 	logger.Info(ctx, "starting coordinator")
@@ -231,6 +232,7 @@ func (c *pgCoord) Close() error {
 	c.logger.Info(c.ctx, "closing coordinator")
 	c.cancel()
 	c.closeOnce.Do(func() { close(c.closed) })
+	c.querier.wait()
 	return nil
 }
 
@@ -795,6 +797,8 @@ type querier struct {
 	workQ *workQ[querierWorkKey]
 
+	wg sync.WaitGroup
+
 	heartbeats *heartbeats
 	updates    <-chan hbUpdate
 
@@ -831,6 +835,7 @@ func newQuerier(ctx context.Context,
 	}
 	q.subscribe()
 
+	q.wg.Add(2 + numWorkers)
 	go func() {
 		<-firstHeartbeat
 		go q.handleIncoming()
@@ -842,7 +847,13 @@ func newQuerier(ctx context.Context,
 	return q
 }
 
+func (q *querier) wait() {
+	q.wg.Wait()
+	q.heartbeats.wg.Wait()
+}
+
 func (q *querier) handleIncoming() {
+	defer q.wg.Done()
 	for {
 		select {
 		case <-q.ctx.Done():
@@ -919,6 +930,7 @@ func (q *querier) cleanupConn(c *connIO) {
 }
 
 func (q *querier) worker() {
+	defer q.wg.Done()
 	eb := backoff.NewExponentialBackOff()
 	eb.MaxElapsedTime = 0 // retry indefinitely
 	eb.MaxInterval = dbMaxBackoff
@@ -1204,6 +1216,7 @@ func (q *querier) resyncPeerMappings() {
 }
 
 func (q *querier) handleUpdates() {
+	defer q.wg.Done()
 	for {
 		select {
 		case <-q.ctx.Done():
@@ -1451,6 +1464,8 @@ type heartbeats struct {
 	coordinators map[uuid.UUID]time.Time
 	timer        *time.Timer
 
+	wg sync.WaitGroup
+
 	// overwritten in tests, but otherwise constant
 	cleanupPeriod time.Duration
 }
@@ -1472,6 +1487,7 @@ func newHeartbeats(
 		coordinators:  make(map[uuid.UUID]time.Time),
 		cleanupPeriod: cleanupPeriod,
 	}
+	h.wg.Add(3)
 	go h.subscribe()
 	go h.sendBeats()
 	go h.cleanupLoop()
@@ -1502,6 +1518,7 @@ func (h *heartbeats) filter(mappings []mapping) []mapping {
 }
 
 func (h *heartbeats) subscribe() {
+	defer h.wg.Done()
 	eb := backoff.NewExponentialBackOff()
 	eb.MaxElapsedTime = 0 // retry indefinitely
 	eb.MaxInterval = dbMaxBackoff
@@ -1611,6 +1628,7 @@ func (h *heartbeats) checkExpiry() {
 }
 
 func (h *heartbeats) sendBeats() {
+	defer h.wg.Done()
 	// send an initial heartbeat so that other coordinators can start using our bindings right away.
 	h.sendBeat()
 	close(h.firstHeartbeat) // signal binder it can start writing
@@ -1662,6 +1680,7 @@ func (h *heartbeats) sendDelete() {
 }
 
 func (h *heartbeats) cleanupLoop() {
+	defer h.wg.Done()
 	h.cleanup()
 	tkr := time.NewTicker(h.cleanupPeriod)
 	defer tkr.Stop()
diff --git a/enterprise/tailnet/pgcoord_internal_test.go b/enterprise/tailnet/pgcoord_internal_test.go
index 53fd61d73f066..4607e6fb2ab2f 100644
--- a/enterprise/tailnet/pgcoord_internal_test.go
+++ b/enterprise/tailnet/pgcoord_internal_test.go
@@ -66,6 +66,7 @@ func TestHeartbeats_Cleanup(t *testing.T) {
 		store:         mStore,
 		cleanupPeriod: time.Millisecond,
 	}
+	uut.wg.Add(1)
 	go uut.cleanupLoop()
 
 	for i := 0; i < 6; i++ {
diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go
index 5bd722533dc39..9c363ee700570 100644
--- a/enterprise/tailnet/pgcoord_test.go
+++ b/enterprise/tailnet/pgcoord_test.go
@@ -864,6 +864,53 @@ func TestPGCoordinator_Lost(t *testing.T) {
 	agpltest.LostTest(ctx, t, coordinator)
 }
 
+func TestPGCoordinator_DeleteOnClose(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
+	defer cancel()
+	ctrl := gomock.NewController(t)
+	mStore := dbmock.NewMockStore(ctrl)
+	ps := pubsub.NewInMemory()
+	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
+
+	upsertDone := make(chan struct{})
+	deleteCalled := make(chan struct{})
+	finishDelete := make(chan struct{})
+	mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
+		MinTimes(1).
+		Do(func(_ context.Context, _ uuid.UUID) { close(upsertDone) }).
+		Return(database.TailnetCoordinator{}, nil)
+	mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).
+		Times(1).
+		Do(func(_ context.Context, _ uuid.UUID) {
+			close(deleteCalled)
+			<-finishDelete
+		}).
+		Return(nil)
+
+	// extra calls we don't particularly care about for this test
+	mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
+	mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil)
+	mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil)
+
+	uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
+	require.NoError(t, err)
+	testutil.RequireRecvCtx(ctx, t, upsertDone)
+	closeErr := make(chan error, 1)
+	go func() {
+		closeErr <- uut.Close()
+	}()
+	select {
+	case <-closeErr:
+		t.Fatal("close returned before DeleteCoordinator called")
+	case <-deleteCalled:
+		close(finishDelete)
+		err := testutil.RequireRecvCtx(ctx, t, closeErr)
+		require.NoError(t, err)
+	}
+}
+
 type testConn struct {
 	ws, serverWS net.Conn
 	nodeChan     chan []*agpl.Node
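
For readers skimming the diff, the change follows the standard Go shutdown pattern: every background goroutine is registered with a sync.WaitGroup and defers Done(), so Close() can block until the querier's final database cleanup (the DeleteCoordinator call exercised by the new test) has actually finished. Below is a minimal, self-contained sketch of that pattern, not code from the coder repository; the coordinator type and names here are illustrative stand-ins for pgCoord, querier, and heartbeats.

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// coordinator is an illustrative stand-in: it owns background goroutines and
// must not report "closed" until all of them, including the one doing final
// teardown, have exited.
type coordinator struct {
	ctx    context.Context
	cancel context.CancelFunc
	wg     sync.WaitGroup
}

func newCoordinator() *coordinator {
	ctx, cancel := context.WithCancel(context.Background())
	c := &coordinator{ctx: ctx, cancel: cancel}
	// Register every goroutine before starting it, mirroring the
	// q.wg.Add(2 + numWorkers) and h.wg.Add(3) calls in the diff.
	c.wg.Add(2)
	go c.run("heartbeats")
	go c.run("cleanup")
	return c
}

func (c *coordinator) run(name string) {
	// Each goroutine marks itself done on exit, like the
	// `defer q.wg.Done()` / `defer h.wg.Done()` lines added in the diff.
	defer c.wg.Done()
	<-c.ctx.Done()
	time.Sleep(10 * time.Millisecond) // stand-in for final teardown, e.g. a DELETE query
	fmt.Println(name, "finished teardown")
}

// Close cancels the workers and then blocks until they have all finished,
// so the caller knows teardown has really completed when Close returns.
func (c *coordinator) Close() error {
	c.cancel()
	c.wg.Wait()
	return nil
}

func main() {
	c := newCoordinator()
	_ = c.Close()
	fmt.Println("coordinator fully closed")
}
```

In the actual change the wait is layered rather than flat: Close() calls querier.wait(), which waits on both the querier's and the heartbeats' WaitGroups, but the underlying Add/Done/Wait discipline is the same as in this sketch.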