Skip to content

Commit 384873a

Browse files
authored
feat: add wsproxy implementation for key fetching (#14917)
1 parent 5315656 commit 384873a

File tree

2 files changed

+709
-0
lines changed

2 files changed

+709
-0
lines changed

enterprise/wsproxy/keycache.go

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
package wsproxy
2+
3+
import (
4+
"context"
5+
"sync"
6+
"time"
7+
8+
"golang.org/x/xerrors"
9+
10+
"cdr.dev/slog"
11+
12+
"github.com/coder/coder/v2/coderd/cryptokeys"
13+
"github.com/coder/coder/v2/codersdk"
14+
"github.com/coder/quartz"
15+
)
16+
17+
const (
18+
// latestSequence is a special sequence number that represents the latest key.
19+
latestSequence = -1
20+
// refreshInterval is the interval at which the key cache will refresh.
21+
refreshInterval = time.Minute * 10
22+
)
23+
24+
type Fetcher interface {
25+
Fetch(ctx context.Context) ([]codersdk.CryptoKey, error)
26+
}
27+
28+
type CryptoKeyCache struct {
29+
Clock quartz.Clock
30+
refreshCtx context.Context
31+
refreshCancel context.CancelFunc
32+
fetcher Fetcher
33+
logger slog.Logger
34+
35+
mu sync.Mutex
36+
keys map[int32]codersdk.CryptoKey
37+
lastFetch time.Time
38+
refresher *quartz.Timer
39+
fetching bool
40+
closed bool
41+
cond *sync.Cond
42+
}
43+
44+
func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client Fetcher, opts ...func(*CryptoKeyCache)) (*CryptoKeyCache, error) {
45+
cache := &CryptoKeyCache{
46+
Clock: quartz.NewReal(),
47+
logger: log,
48+
fetcher: client,
49+
}
50+
51+
for _, opt := range opts {
52+
opt(cache)
53+
}
54+
55+
cache.cond = sync.NewCond(&cache.mu)
56+
cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx)
57+
cache.refresher = cache.Clock.AfterFunc(refreshInterval, cache.refresh)
58+
59+
keys, err := cache.cryptoKeys(ctx)
60+
if err != nil {
61+
cache.refreshCancel()
62+
return nil, xerrors.Errorf("initial fetch: %w", err)
63+
}
64+
cache.keys = keys
65+
66+
return cache, nil
67+
}
68+
69+
func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error) {
70+
return k.cryptoKey(ctx, latestSequence)
71+
}
72+
73+
func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) {
74+
return k.cryptoKey(ctx, sequence)
75+
}
76+
77+
func (k *CryptoKeyCache) cryptoKey(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) {
78+
k.mu.Lock()
79+
defer k.mu.Unlock()
80+
81+
if k.closed {
82+
return codersdk.CryptoKey{}, cryptokeys.ErrClosed
83+
}
84+
85+
var key codersdk.CryptoKey
86+
var ok bool
87+
for key, ok = k.key(sequence); !ok && k.fetching && !k.closed; {
88+
k.cond.Wait()
89+
}
90+
91+
if k.closed {
92+
return codersdk.CryptoKey{}, cryptokeys.ErrClosed
93+
}
94+
95+
if ok {
96+
return checkKey(key, sequence, k.Clock.Now())
97+
}
98+
99+
k.fetching = true
100+
k.mu.Unlock()
101+
102+
keys, err := k.cryptoKeys(ctx)
103+
if err != nil {
104+
return codersdk.CryptoKey{}, xerrors.Errorf("get keys: %w", err)
105+
}
106+
107+
k.mu.Lock()
108+
k.lastFetch = k.Clock.Now()
109+
k.refresher.Reset(refreshInterval)
110+
k.keys = keys
111+
k.fetching = false
112+
k.cond.Broadcast()
113+
114+
key, ok = k.key(sequence)
115+
if !ok {
116+
return codersdk.CryptoKey{}, cryptokeys.ErrKeyNotFound
117+
}
118+
119+
return checkKey(key, sequence, k.Clock.Now())
120+
}
121+
122+
func (k *CryptoKeyCache) key(sequence int32) (codersdk.CryptoKey, bool) {
123+
if sequence == latestSequence {
124+
return k.keys[latestSequence], k.keys[latestSequence].CanSign(k.Clock.Now())
125+
}
126+
127+
key, ok := k.keys[sequence]
128+
return key, ok
129+
}
130+
131+
func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (codersdk.CryptoKey, error) {
132+
if sequence == latestSequence {
133+
if !key.CanSign(now) {
134+
return codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid
135+
}
136+
return key, nil
137+
}
138+
139+
if !key.CanVerify(now) {
140+
return codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid
141+
}
142+
143+
return key, nil
144+
}
145+
146+
// refresh fetches the keys from the control plane and updates the cache.
147+
func (k *CryptoKeyCache) refresh() {
148+
now := k.Clock.Now("CryptoKeyCache", "refresh")
149+
k.mu.Lock()
150+
151+
if k.closed {
152+
k.mu.Unlock()
153+
return
154+
}
155+
156+
// If something's already fetching, we don't need to do anything.
157+
if k.fetching {
158+
k.mu.Unlock()
159+
return
160+
}
161+
162+
// There's a window we must account for where the timer fires while a fetch
163+
// is ongoing but prior to the timer getting reset. In this case we want to
164+
// avoid double fetching.
165+
if now.Sub(k.lastFetch) < refreshInterval {
166+
k.mu.Unlock()
167+
return
168+
}
169+
170+
k.fetching = true
171+
172+
k.mu.Unlock()
173+
keys, err := k.cryptoKeys(k.refreshCtx)
174+
if err != nil {
175+
k.logger.Error(k.refreshCtx, "fetch crypto keys", slog.Error(err))
176+
return
177+
}
178+
179+
k.mu.Lock()
180+
defer k.mu.Unlock()
181+
182+
k.lastFetch = k.Clock.Now()
183+
k.refresher.Reset(refreshInterval)
184+
k.keys = keys
185+
k.fetching = false
186+
k.cond.Broadcast()
187+
}
188+
189+
// cryptoKeys queries the control plane for the crypto keys.
190+
// Outside of initialization, this should only be called by fetch.
191+
func (k *CryptoKeyCache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) {
192+
keys, err := k.fetcher.Fetch(ctx)
193+
if err != nil {
194+
return nil, xerrors.Errorf("crypto keys: %w", err)
195+
}
196+
cache := toKeyMap(keys, k.Clock.Now())
197+
return cache, nil
198+
}
199+
200+
func toKeyMap(keys []codersdk.CryptoKey, now time.Time) map[int32]codersdk.CryptoKey {
201+
m := make(map[int32]codersdk.CryptoKey)
202+
var latest codersdk.CryptoKey
203+
for _, key := range keys {
204+
m[key.Sequence] = key
205+
if key.Sequence > latest.Sequence && key.CanSign(now) {
206+
m[latestSequence] = key
207+
}
208+
}
209+
return m
210+
}
211+
212+
func (k *CryptoKeyCache) Close() {
213+
k.mu.Lock()
214+
defer k.mu.Unlock()
215+
216+
if k.closed {
217+
return
218+
}
219+
220+
k.closed = true
221+
k.refreshCancel()
222+
k.refresher.Stop()
223+
k.cond.Broadcast()
224+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy