Skip to content

Commit d7e8627

Browse files
authored
chore: add resume token controller (#15346)
Implements a controller for the Tailnet API resume token RPC by refactoring it out of `workspacesdk`. Chore re: #14729
1 parent d4131ba commit d7e8627

File tree

2 files changed

+312
-0
lines changed

2 files changed

+312
-0
lines changed

tailnet/controllers.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"io"
7+
"math"
78
"strings"
89
"sync"
910
"time"
@@ -16,6 +17,7 @@ import (
1617

1718
"cdr.dev/slog"
1819
"github.com/coder/coder/v2/tailnet/proto"
20+
"github.com/coder/quartz"
1921
)
2022

2123
// A Controller connects to the tailnet control plane, and then uses the control protocols to
@@ -523,3 +525,142 @@ func sendTelemetry(
523525
}
524526
return false
525527
}
528+
529+
// basicResumeTokenController caches the most recent resume token response and
// owns at most one active refresher at a time.
type basicResumeTokenController struct {
	logger slog.Logger

	sync.Mutex
	// token is the most recently received refresh response; nil until the
	// first successful refresh.
	token *proto.RefreshResumeTokenResponse
	// refresher is the currently active refresher, if any. Used to ignore
	// late writes from refreshers that have been superseded by New().
	refresher *basicResumeTokenRefresher

	// for testing
	clock quartz.Clock
}
540+
func (b *basicResumeTokenController) New(client ResumeTokenClient) CloserWaiter {
541+
b.Lock()
542+
defer b.Unlock()
543+
if b.refresher != nil {
544+
cErr := b.refresher.Close(context.Background())
545+
if cErr != nil {
546+
b.logger.Debug(context.Background(), "closed previous refresher", slog.Error(cErr))
547+
}
548+
}
549+
b.refresher = newBasicResumeTokenRefresher(b.logger, b.clock, b, client)
550+
return b.refresher
551+
}
552+
553+
func (b *basicResumeTokenController) Token() (string, bool) {
554+
b.Lock()
555+
defer b.Unlock()
556+
if b.token == nil {
557+
return "", false
558+
}
559+
if b.token.ExpiresAt.AsTime().Before(b.clock.Now()) {
560+
return "", false
561+
}
562+
return b.token.Token, true
563+
}
564+
565+
func NewBasicResumeTokenController(logger slog.Logger, clock quartz.Clock) ResumeTokenController {
566+
return &basicResumeTokenController{
567+
logger: logger,
568+
clock: clock,
569+
}
570+
}
571+
572+
// basicResumeTokenRefresher periodically calls RefreshResumeToken on its
// client and stores the result on its controller. It implements CloserWaiter.
type basicResumeTokenRefresher struct {
	logger slog.Logger
	// ctx/cancel bound the lifetime of refresh RPCs; cancel is invoked by
	// Close.
	ctx    context.Context
	cancel context.CancelFunc
	ctrl   *basicResumeTokenController
	client ResumeTokenClient
	// errCh has capacity 1 and delivers the first terminal result (nil on
	// clean close) to Wait.
	errCh chan error

	sync.Mutex
	closed bool
	// timer schedules the next refresh; guarded by the mutex together with
	// closed so Close and refresh don't race on it.
	timer *quartz.Timer
}
585+
// Close stops the refresher: it cancels any in-flight refresh RPC, stops the
// timer, and signals Wait. It is idempotent. The context argument is unused
// because Close never blocks.
func (r *basicResumeTokenRefresher) Close(_ context.Context) error {
	// Cancel before taking the lock so an in-progress refresh RPC is
	// interrupted promptly rather than waiting on lock acquisition.
	r.cancel()
	r.Lock()
	defer r.Unlock()
	if r.closed {
		return nil
	}
	r.closed = true
	r.timer.Stop()
	// Non-blocking send: errCh is buffered (capacity 1) and only the first
	// terminal result is delivered.
	select {
	case r.errCh <- nil:
	default: // already have an error
	}
	return nil
}
601+
// Wait returns a channel that receives the refresher's terminal result: nil
// on a clean Close, or the error that stopped refreshing.
func (r *basicResumeTokenRefresher) Wait() <-chan error {
	return r.errCh
}
605+
// never is a duration so large the timer will effectively never fire; the
// refresh timer is created with it and only explicit Reset calls schedule
// real work.
const never time.Duration = math.MaxInt64
607+
// newBasicResumeTokenRefresher creates and starts a refresher. The first
// refresh is kicked off immediately on its own goroutine; subsequent
// refreshes are scheduled via the timer.
func newBasicResumeTokenRefresher(
	logger slog.Logger, clock quartz.Clock,
	ctrl *basicResumeTokenController, client ResumeTokenClient,
) *basicResumeTokenRefresher {
	r := &basicResumeTokenRefresher{
		logger: logger,
		ctrl:   ctrl,
		client: client,
		// capacity 1 so the terminal result can be sent without blocking.
		errCh: make(chan error, 1),
	}
	r.ctx, r.cancel = context.WithCancel(context.Background())
	// The timer must exist before refresh runs, because refresh calls
	// r.timer.Reset; it is armed with `never` so it does not fire on its own.
	r.timer = clock.AfterFunc(never, r.refresh)
	go r.refresh()
	return r
}
623+
// refresh performs a single RefreshResumeToken RPC, stores the result on the
// controller (if this refresher is still current), and re-arms the timer for
// the next refresh. It runs once at startup and then on each timer firing.
func (r *basicResumeTokenRefresher) refresh() {
	if r.ctx.Err() != nil {
		return // context done, no need to refresh
	}
	res, err := r.client.RefreshResumeToken(r.ctx, &proto.RefreshResumeTokenRequest{})
	if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
		// these can only come from being closed, no need to log
		select {
		case r.errCh <- nil:
		default: // already have an error
		}
		return
	}
	if err != nil {
		r.logger.Error(r.ctx, "error refreshing coordinator resume token", slog.Error(err))
		// Non-blocking send: only the first terminal error is delivered.
		select {
		case r.errCh <- err:
		default: // already have an error
		}
		return
	}
	r.logger.Debug(r.ctx, "refreshed coordinator resume token",
		slog.F("expires_at", res.GetExpiresAt()),
		slog.F("refresh_in", res.GetRefreshIn()),
	)
	// Publish the token on the controller unless a newer refresher has
	// replaced this one via New().
	r.ctrl.Lock()
	if r.ctrl.refresher == r { // don't overwrite if we're not the current refresher
		r.ctrl.token = res
	} else {
		r.logger.Debug(context.Background(), "not writing token because we have a new client")
	}
	r.ctrl.Unlock()
	dur := res.RefreshIn.AsDuration()
	if dur <= 0 {
		// A sensible delay to refresh again.
		dur = 30 * time.Minute
	}
	// Re-arm the timer only if we haven't been closed; Close stops the timer
	// under this same lock, so checking closed here avoids resurrecting it.
	r.Lock()
	defer r.Unlock()
	if r.closed {
		return
	}
	r.timer.Reset(dur, "basicResumeTokenRefresher", "refresh")
}

tailnet/controllers_test.go

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"github.com/stretchr/testify/require"
1414
"go.uber.org/mock/gomock"
1515
"golang.org/x/xerrors"
16+
"google.golang.org/protobuf/types/known/durationpb"
17+
"google.golang.org/protobuf/types/known/timestamppb"
1618
"storj.io/drpc"
1719
"storj.io/drpc/drpcerr"
1820
"tailscale.com/tailcfg"
@@ -24,6 +26,7 @@ import (
2426
"github.com/coder/coder/v2/tailnet/proto"
2527
"github.com/coder/coder/v2/tailnet/tailnettest"
2628
"github.com/coder/coder/v2/testutil"
29+
"github.com/coder/quartz"
2730
)
2831

2932
func TestInMemoryCoordination(t *testing.T) {
@@ -507,3 +510,171 @@ type fakeTelemetryCall struct {
507510
req *proto.TelemetryRequest
508511
errCh chan error
509512
}
513+
514+
// TestBasicResumeTokenController_Mainline exercises the happy path: fetch an
// initial token, schedule the next refresh from RefreshIn, refresh on
// schedule, keep the cached token after Close, and invalidate it at expiry.
func TestBasicResumeTokenController_Mainline(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	fr := newFakeResumeTokenClient(ctx)
	mClock := quartz.NewMock(t)
	// Trap timer resets so the test can synchronize with the refresher's
	// scheduling of the next refresh.
	trp := mClock.Trap().TimerReset("basicResumeTokenRefresher", "refresh")
	defer trp.Close()

	uut := tailnet.NewBasicResumeTokenController(logger, mClock)
	// No token is available before the first refresh completes.
	_, ok := uut.Token()
	require.False(t, ok)

	// Start the refresher asynchronously so the test can service the initial
	// refresh RPC before collecting the CloserWaiter.
	cwCh := make(chan tailnet.CloserWaiter, 1)
	go func() {
		cwCh <- uut.New(fr)
	}()
	call := testutil.RequireRecvCtx(ctx, t, fr.calls)
	testutil.RequireSendCtx(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
		Token:     "test token 1",
		RefreshIn: durationpb.New(100 * time.Second),
		ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
	})
	trp.MustWait(ctx).Release() // initial refresh done
	token, ok := uut.Token()
	require.True(t, ok)
	require.Equal(t, "test token 1", token)
	cw := testutil.RequireRecvCtx(ctx, t, cwCh)

	// Advance to the scheduled refresh and answer with a second token that
	// asks for a shorter (50s) refresh interval.
	w := mClock.Advance(100 * time.Second)
	call = testutil.RequireRecvCtx(ctx, t, fr.calls)
	testutil.RequireSendCtx(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
		Token:     "test token 2",
		RefreshIn: durationpb.New(50 * time.Second),
		ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
	})
	resetCall := trp.MustWait(ctx)
	require.Equal(t, resetCall.Duration, 50*time.Second)
	resetCall.Release()
	w.MustWait(ctx)
	token, ok = uut.Token()
	require.True(t, ok)
	require.Equal(t, "test token 2", token)

	// Closing the refresher stops refreshes and resolves Wait with nil.
	err := cw.Close(ctx)
	require.NoError(t, err)
	err = testutil.RequireRecvCtx(ctx, t, cw.Wait())
	require.NoError(t, err)

	// The cached token remains valid after Close, until it expires.
	token, ok = uut.Token()
	require.True(t, ok)
	require.Equal(t, "test token 2", token)

	// Past the expiry time, Token reports invalid.
	mClock.Advance(201 * time.Second).MustWait(ctx)
	_, ok = uut.Token()
	require.False(t, ok)
}
572+
// TestBasicResumeTokenController_NewWhileRefreshing verifies that when New()
// is called while an earlier refresher's RPC is still in flight, the old
// refresher is closed and its late result does not overwrite the new
// refresher's token.
func TestBasicResumeTokenController_NewWhileRefreshing(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	mClock := quartz.NewMock(t)
	trp := mClock.Trap().TimerReset("basicResumeTokenRefresher", "refresh")
	defer trp.Close()

	uut := tailnet.NewBasicResumeTokenController(logger, mClock)
	_, ok := uut.Token()
	require.False(t, ok)

	// Start the first refresher and capture its in-flight call, but do not
	// answer it yet.
	fr1 := newFakeResumeTokenClient(ctx)
	cwCh1 := make(chan tailnet.CloserWaiter, 1)
	go func() {
		cwCh1 <- uut.New(fr1)
	}()
	call1 := testutil.RequireRecvCtx(ctx, t, fr1.calls)

	// Start a second refresher while the first RPC is still outstanding.
	fr2 := newFakeResumeTokenClient(ctx)
	cwCh2 := make(chan tailnet.CloserWaiter, 1)
	go func() {
		cwCh2 <- uut.New(fr2)
	}()
	call2 := testutil.RequireRecvCtx(ctx, t, fr2.calls)

	testutil.RequireSendCtx(ctx, t, call2.resp, &proto.RefreshResumeTokenResponse{
		Token:     "test token 2.0",
		RefreshIn: durationpb.New(102 * time.Second),
		ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
	})

	cw2 := testutil.RequireRecvCtx(ctx, t, cwCh2) // this ensures Close was called on 1

	// Answer the first (now superseded) refresher's RPC; its token must be
	// discarded.
	testutil.RequireSendCtx(ctx, t, call1.resp, &proto.RefreshResumeTokenResponse{
		Token:     "test token 1",
		RefreshIn: durationpb.New(101 * time.Second),
		ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
	})

	trp.MustWait(ctx).Release()

	// The current token is the second refresher's, not the stale first one.
	token, ok := uut.Token()
	require.True(t, ok)
	require.Equal(t, "test token 2.0", token)

	// refresher 1 should already be closed.
	cw1 := testutil.RequireRecvCtx(ctx, t, cwCh1)
	err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
	require.NoError(t, err)

	// The second refresher keeps refreshing on its own schedule.
	w := mClock.Advance(102 * time.Second)
	call := testutil.RequireRecvCtx(ctx, t, fr2.calls)
	testutil.RequireSendCtx(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
		Token:     "test token 2.1",
		RefreshIn: durationpb.New(50 * time.Second),
		ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
	})
	resetCall := trp.MustWait(ctx)
	require.Equal(t, resetCall.Duration, 50*time.Second)
	resetCall.Release()
	w.MustWait(ctx)
	token, ok = uut.Token()
	require.True(t, ok)
	require.Equal(t, "test token 2.1", token)

	err = cw2.Close(ctx)
	require.NoError(t, err)
	err = testutil.RequireRecvCtx(ctx, t, cw2.Wait())
	require.NoError(t, err)
}
644+
func newFakeResumeTokenClient(ctx context.Context) *fakeResumeTokenClient {
645+
return &fakeResumeTokenClient{
646+
ctx: ctx,
647+
calls: make(chan *fakeResumeTokenCall),
648+
}
649+
}
650+
651+
// fakeResumeTokenClient is a test double for ResumeTokenClient: each RPC is
// surfaced on calls for the test to answer, bounded by ctx.
type fakeResumeTokenClient struct {
	ctx   context.Context
	calls chan *fakeResumeTokenCall
}
656+
func (f *fakeResumeTokenClient) RefreshResumeToken(_ context.Context, _ *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error) {
657+
call := &fakeResumeTokenCall{
658+
resp: make(chan *proto.RefreshResumeTokenResponse),
659+
errCh: make(chan error),
660+
}
661+
select {
662+
case <-f.ctx.Done():
663+
return nil, f.ctx.Err()
664+
case f.calls <- call:
665+
// OK
666+
}
667+
select {
668+
case <-f.ctx.Done():
669+
return nil, f.ctx.Err()
670+
case err := <-call.errCh:
671+
return nil, err
672+
case resp := <-call.resp:
673+
return resp, nil
674+
}
675+
}
676+
677+
// fakeResumeTokenCall represents one in-flight fake RPC; the test completes
// it by sending on exactly one of resp or errCh.
type fakeResumeTokenCall struct {
	resp  chan *proto.RefreshResumeTokenResponse
	errCh chan error
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy