Commit a0962ba

fix: wait for PGCoordinator to clean up db state (#13351)
Cf. #13192 (comment). We need to wait for PGCoordinator to finish its work before returning from `Close()`, so that we delete database state (best effort -- if this fails, other coordinators will filter the stale entry out based on heartbeats). A minimal sketch of this shutdown pattern follows below.
1 parent e5bb0a7 commit a0962ba
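
The shape of the fix is a cancel-then-wait shutdown: `Close()` cancels the coordinator's context and then blocks on a `sync.WaitGroup` until the goroutine doing the best-effort database delete has finished. Below is a minimal, self-contained sketch of that pattern; the `coordinator` type and `cleanupDatabaseState` helper are hypothetical stand-ins, not the real `pgCoord` code.

package main

import (
	"context"
	"sync"
)

// coordinator is a hypothetical stand-in for pgCoord: Close() cancels the run
// context and then blocks until the background goroutine has finished its
// best-effort cleanup, mirroring the behavior this commit adds.
type coordinator struct {
	cancel context.CancelFunc
	wg     sync.WaitGroup
}

func newCoordinator() *coordinator {
	ctx, cancel := context.WithCancel(context.Background())
	c := &coordinator{cancel: cancel}
	c.wg.Add(1)
	go func() {
		defer c.wg.Done()
		<-ctx.Done()
		// Best-effort cleanup, e.g. deleting this coordinator's row from the
		// database; failures are tolerated because peers also filter stale
		// coordinators out based on heartbeats.
		cleanupDatabaseState()
	}()
	return c
}

func (c *coordinator) Close() error {
	c.cancel()
	c.wg.Wait() // do not return until cleanup has had a chance to run
	return nil
}

func cleanupDatabaseState() {}

func main() {
	c := newCoordinator()
	_ = c.Close()
}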

File tree

3 files changed: +69 -2 lines changed

enterprise/tailnet/pgcoord.go
enterprise/tailnet/pgcoord_internal_test.go
enterprise/tailnet/pgcoord_test.go

enterprise/tailnet/pgcoord.go

Lines changed: 21 additions & 2 deletions
@@ -161,11 +161,12 @@ func newPGCoordInternal(
 		closed:  make(chan struct{}),
 	}
 	go func() {
-		// when the main context is canceled, or the coordinator closed, the binder and tunneler
-		// always eventually stop. Once they stop it's safe to cancel the querier context, which
+		// when the main context is canceled, or the coordinator closed, the binder, tunneler, and
+		// handshaker always eventually stop. Once they stop it's safe to cancel the querier context, which
 		// has the effect of deleting the coordinator from the database and ceasing heartbeats.
 		c.binder.workerWG.Wait()
 		c.tunneler.workerWG.Wait()
+		c.handshaker.workerWG.Wait()
 		querierCancel()
 	}()
 	logger.Info(ctx, "starting coordinator")
@@ -231,6 +232,7 @@ func (c *pgCoord) Close() error {
 	c.logger.Info(c.ctx, "closing coordinator")
 	c.cancel()
 	c.closeOnce.Do(func() { close(c.closed) })
+	c.querier.wait()
 	return nil
 }
 
@@ -795,6 +797,8 @@ type querier struct {
 
 	workQ *workQ[querierWorkKey]
 
+	wg sync.WaitGroup
+
 	heartbeats *heartbeats
 	updates    <-chan hbUpdate
 
@@ -831,6 +835,7 @@ func newQuerier(ctx context.Context,
 	}
 	q.subscribe()
 
+	q.wg.Add(2 + numWorkers)
 	go func() {
 		<-firstHeartbeat
 		go q.handleIncoming()
@@ -842,7 +847,13 @@ func newQuerier(ctx context.Context,
 	return q
 }
 
+func (q *querier) wait() {
+	q.wg.Wait()
+	q.heartbeats.wg.Wait()
+}
+
 func (q *querier) handleIncoming() {
+	defer q.wg.Done()
 	for {
 		select {
 		case <-q.ctx.Done():
@@ -919,6 +930,7 @@ func (q *querier) cleanupConn(c *connIO) {
 }
 
 func (q *querier) worker() {
+	defer q.wg.Done()
 	eb := backoff.NewExponentialBackOff()
 	eb.MaxElapsedTime = 0 // retry indefinitely
 	eb.MaxInterval = dbMaxBackoff
@@ -1204,6 +1216,7 @@ func (q *querier) resyncPeerMappings() {
 }
 
 func (q *querier) handleUpdates() {
+	defer q.wg.Done()
 	for {
 		select {
 		case <-q.ctx.Done():
@@ -1451,6 +1464,8 @@ type heartbeats struct {
 	coordinators map[uuid.UUID]time.Time
 	timer        *time.Timer
 
+	wg sync.WaitGroup
+
 	// overwritten in tests, but otherwise constant
 	cleanupPeriod time.Duration
 }
@@ -1472,6 +1487,7 @@ func newHeartbeats(
 		coordinators:  make(map[uuid.UUID]time.Time),
 		cleanupPeriod: cleanupPeriod,
 	}
+	h.wg.Add(3)
 	go h.subscribe()
 	go h.sendBeats()
 	go h.cleanupLoop()
@@ -1502,6 +1518,7 @@ func (h *heartbeats) filter(mappings []mapping) []mapping {
 }
 
 func (h *heartbeats) subscribe() {
+	defer h.wg.Done()
 	eb := backoff.NewExponentialBackOff()
 	eb.MaxElapsedTime = 0 // retry indefinitely
 	eb.MaxInterval = dbMaxBackoff
@@ -1611,6 +1628,7 @@ func (h *heartbeats) checkExpiry() {
 }
 
 func (h *heartbeats) sendBeats() {
+	defer h.wg.Done()
 	// send an initial heartbeat so that other coordinators can start using our bindings right away.
 	h.sendBeat()
 	close(h.firstHeartbeat) // signal binder it can start writing
@@ -1662,6 +1680,7 @@ func (h *heartbeats) sendDelete() {
 }
 
 func (h *heartbeats) cleanupLoop() {
+	defer h.wg.Done()
 	h.cleanup()
 	tkr := time.NewTicker(h.cleanupPeriod)
 	defer tkr.Stop()
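
Taken together, this diff gives the querier a single wait point that covers both its own goroutines and the heartbeats' goroutines, so `pgCoord.Close()` only needs to call `querier.wait()`. A simplified sketch of that layering follows; the types and fields are reduced to the minimum and are not the real implementations.

package main

import "sync"

// heartbeats and querier are stripped-down stand-ins: the querier waits on its
// own goroutines and then on the heartbeats' goroutines, so a single call to
// querier.wait() covers both.
type heartbeats struct {
	wg sync.WaitGroup
}

type querier struct {
	wg         sync.WaitGroup
	heartbeats *heartbeats
}

func (q *querier) wait() {
	q.wg.Wait()            // handleIncoming, handleUpdates, and the workers
	q.heartbeats.wg.Wait() // subscribe, sendBeats, cleanupLoop
}

func main() {
	h := &heartbeats{}
	q := &querier{heartbeats: h}

	q.wg.Add(1)
	go func() { defer q.wg.Done() }() // stands in for a querier goroutine

	h.wg.Add(1)
	go func() { defer h.wg.Done() }() // stands in for a heartbeats goroutine

	q.wait() // returns only after both goroutines above have exited
}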

enterprise/tailnet/pgcoord_internal_test.go

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ func TestHeartbeats_Cleanup(t *testing.T) {
 		store:         mStore,
 		cleanupPeriod: time.Millisecond,
 	}
+	uut.wg.Add(1)
 	go uut.cleanupLoop()
 
 	for i := 0; i < 6; i++ {
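
This one-line test change is needed because `cleanupLoop` now defers `h.wg.Done()`; without a matching `Add(1)` the `WaitGroup` counter would go negative and panic. A tiny illustration of that constraint (not project code):

package main

import "sync"

// Calling Done() on a WaitGroup whose counter is already zero panics with
// "sync: negative WaitGroup counter", so Add(1) must happen before starting
// a goroutine that defers Done() -- the same ordering the test now follows.
func main() {
	var wg sync.WaitGroup

	wg.Add(1) // must precede the goroutine's deferred Done()
	go func() {
		defer wg.Done() // mirrors cleanupLoop's new `defer h.wg.Done()`
	}()
	wg.Wait()
}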

enterprise/tailnet/pgcoord_test.go

Lines changed: 47 additions & 0 deletions
@@ -864,6 +864,53 @@ func TestPGCoordinator_Lost(t *testing.T) {
 	agpltest.LostTest(ctx, t, coordinator)
 }
 
+func TestPGCoordinator_DeleteOnClose(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
+	defer cancel()
+	ctrl := gomock.NewController(t)
+	mStore := dbmock.NewMockStore(ctrl)
+	ps := pubsub.NewInMemory()
+	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
+
+	upsertDone := make(chan struct{})
+	deleteCalled := make(chan struct{})
+	finishDelete := make(chan struct{})
+	mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
+		MinTimes(1).
+		Do(func(_ context.Context, _ uuid.UUID) { close(upsertDone) }).
+		Return(database.TailnetCoordinator{}, nil)
+	mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).
+		Times(1).
+		Do(func(_ context.Context, _ uuid.UUID) {
+			close(deleteCalled)
+			<-finishDelete
+		}).
+		Return(nil)
+
+	// extra calls we don't particularly care about for this test
+	mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
+	mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil)
+	mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil)
+
+	uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
+	require.NoError(t, err)
+	testutil.RequireRecvCtx(ctx, t, upsertDone)
+	closeErr := make(chan error, 1)
+	go func() {
+		closeErr <- uut.Close()
+	}()
+	select {
+	case <-closeErr:
+		t.Fatal("close returned before DeleteCoordinator called")
+	case <-deleteCalled:
+		close(finishDelete)
+		err := testutil.RequireRecvCtx(ctx, t, closeErr)
+		require.NoError(t, err)
+	}
+}
+
 type testConn struct {
 	ws, serverWS net.Conn
 	nodeChan     chan []*agpl.Node
