Skip to content

Commit ded612d

Browse files
authored
fix: use authenticated urls for pubsub (#14261)
1 parent 6914862 commit ded612d

File tree

9 files changed

+290
-14
lines changed

9 files changed

+290
-14
lines changed

coderd/database/awsiamrds/awsiamrds.go

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,21 @@ import (
1010
"github.com/aws/aws-sdk-go-v2/aws"
1111
"github.com/aws/aws-sdk-go-v2/config"
1212
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
13+
"github.com/lib/pq"
1314
"golang.org/x/xerrors"
15+
16+
"github.com/coder/coder/v2/coderd/database"
1417
)
1518

1619
type awsIamRdsDriver struct {
1720
parent driver.Driver
1821
cfg aws.Config
1922
}
2023

21-
var _ driver.Driver = &awsIamRdsDriver{}
24+
var (
25+
_ driver.Driver = &awsIamRdsDriver{}
26+
_ database.ConnectorCreator = &awsIamRdsDriver{}
27+
)
2228

2329
// Register initializes and registers our aws iam rds wrapped database driver.
2430
func Register(ctx context.Context, parentName string) (string, error) {
@@ -65,6 +71,16 @@ func (d *awsIamRdsDriver) Open(name string) (driver.Conn, error) {
6571
return conn, nil
6672
}
6773

74+
// Connector returns a driver.Connector that fetches a new authentication token for each connection.
75+
func (d *awsIamRdsDriver) Connector(name string) (driver.Connector, error) {
76+
connector := &connector{
77+
url: name,
78+
cfg: d.cfg,
79+
}
80+
81+
return connector, nil
82+
}
83+
6884
func getAuthenticatedURL(cfg aws.Config, dbURL string) (string, error) {
6985
nURL, err := url.Parse(dbURL)
7086
if err != nil {
@@ -82,3 +98,37 @@ func getAuthenticatedURL(cfg aws.Config, dbURL string) (string, error) {
8298

8399
return nURL.String(), nil
84100
}
101+
102+
type connector struct {
103+
url string
104+
cfg aws.Config
105+
dialer pq.Dialer
106+
}
107+
108+
var _ database.DialerConnector = &connector{}
109+
110+
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
111+
nURL, err := getAuthenticatedURL(c.cfg, c.url)
112+
if err != nil {
113+
return nil, xerrors.Errorf("assigning authentication token to url: %w", err)
114+
}
115+
116+
nc, err := pq.NewConnector(nURL)
117+
if err != nil {
118+
return nil, xerrors.Errorf("creating new connector: %w", err)
119+
}
120+
121+
if c.dialer != nil {
122+
nc.Dialer(c.dialer)
123+
}
124+
125+
return nc.Connect(ctx)
126+
}
127+
128+
func (*connector) Driver() driver.Driver {
129+
return &pq.Driver{}
130+
}
131+
132+
func (c *connector) Dialer(dialer pq.Dialer) {
133+
c.dialer = dialer
134+
}

coderd/database/awsiamrds/awsiamrds_test.go

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ import (
77

88
"github.com/stretchr/testify/require"
99

10+
"cdr.dev/slog"
1011
"cdr.dev/slog/sloggers/slogtest"
11-
1212
"github.com/coder/coder/v2/cli"
13-
awsrdsiam "github.com/coder/coder/v2/coderd/database/awsiamrds"
13+
"github.com/coder/coder/v2/coderd/database/awsiamrds"
14+
"github.com/coder/coder/v2/coderd/database/pubsub"
1415
"github.com/coder/coder/v2/testutil"
1516
)
1617

@@ -22,13 +23,15 @@ func TestDriver(t *testing.T) {
2223
// export DBAWSIAMRDS_TEST_URL="postgres://user@host:5432/dbname";
2324
url := os.Getenv("DBAWSIAMRDS_TEST_URL")
2425
if url == "" {
26+
t.Log("skipping test; no DBAWSIAMRDS_TEST_URL set")
2527
t.Skip()
2628
}
2729

30+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
2831
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
2932
defer cancel()
3033

31-
sqlDriver, err := awsrdsiam.Register(ctx, "postgres")
34+
sqlDriver, err := awsiamrds.Register(ctx, "postgres")
3235
require.NoError(t, err)
3336

3437
db, err := cli.ConnectToPostgres(ctx, slogtest.Make(t, nil), sqlDriver, url)
@@ -47,4 +50,23 @@ func TestDriver(t *testing.T) {
4750
var one int
4851
require.NoError(t, i.Scan(&one))
4952
require.Equal(t, 1, one)
53+
54+
ps, err := pubsub.New(ctx, logger, db, url)
55+
require.NoError(t, err)
56+
57+
gotChan := make(chan struct{})
58+
subCancel, err := ps.Subscribe("test", func(_ context.Context, _ []byte) {
59+
close(gotChan)
60+
})
61+
defer subCancel()
62+
require.NoError(t, err)
63+
64+
err = ps.Publish("test", []byte("hello"))
65+
require.NoError(t, err)
66+
67+
select {
68+
case <-gotChan:
69+
case <-ctx.Done():
70+
require.Fail(t, "timed out waiting for message")
71+
}
5072
}

coderd/database/connector.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package database
2+
3+
import (
4+
"database/sql/driver"
5+
6+
"github.com/lib/pq"
7+
)
8+
9+
// ConnectorCreator is a driver.Driver that can create a driver.Connector.
10+
type ConnectorCreator interface {
11+
driver.Driver
12+
Connector(name string) (driver.Connector, error)
13+
}
14+
15+
// DialerConnector is a driver.Connector that can set a pq.Dialer.
16+
type DialerConnector interface {
17+
driver.Connector
18+
Dialer(dialer pq.Dialer)
19+
}

coderd/database/dbtestutil/driver.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package dbtestutil
2+
3+
import (
4+
"context"
5+
"database/sql/driver"
6+
7+
"github.com/lib/pq"
8+
"golang.org/x/xerrors"
9+
10+
"github.com/coder/coder/v2/coderd/database"
11+
)
12+
13+
var _ database.DialerConnector = &Connector{}
14+
15+
type Connector struct {
16+
name string
17+
driver *Driver
18+
dialer pq.Dialer
19+
}
20+
21+
func (c *Connector) Connect(_ context.Context) (driver.Conn, error) {
22+
if c.dialer != nil {
23+
conn, err := pq.DialOpen(c.dialer, c.name)
24+
if err != nil {
25+
return nil, xerrors.Errorf("failed to dial open connection: %w", err)
26+
}
27+
28+
c.driver.Connections <- conn
29+
30+
return conn, nil
31+
}
32+
33+
conn, err := pq.Driver{}.Open(c.name)
34+
if err != nil {
35+
return nil, xerrors.Errorf("failed to open connection: %w", err)
36+
}
37+
38+
c.driver.Connections <- conn
39+
40+
return conn, nil
41+
}
42+
43+
func (c *Connector) Driver() driver.Driver {
44+
return c.driver
45+
}
46+
47+
func (c *Connector) Dialer(dialer pq.Dialer) {
48+
c.dialer = dialer
49+
}
50+
51+
type Driver struct {
52+
Connections chan driver.Conn
53+
}
54+
55+
func NewDriver() *Driver {
56+
return &Driver{
57+
Connections: make(chan driver.Conn, 1),
58+
}
59+
}
60+
61+
func (d *Driver) Connector(name string) (driver.Connector, error) {
62+
return &Connector{
63+
name: name,
64+
driver: d,
65+
}, nil
66+
}
67+
68+
func (d *Driver) Open(name string) (driver.Conn, error) {
69+
c, err := d.Connector(name)
70+
if err != nil {
71+
return nil, err
72+
}
73+
74+
return c.Connect(context.Background())
75+
}
76+
77+
func (d *Driver) Close() {
78+
close(d.Connections)
79+
}

coderd/database/pubsub/pubsub.go

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package pubsub
33
import (
44
"context"
55
"database/sql"
6+
"database/sql/driver"
67
"errors"
78
"io"
89
"net"
@@ -15,6 +16,8 @@ import (
1516
"github.com/prometheus/client_golang/prometheus"
1617
"golang.org/x/xerrors"
1718

19+
"github.com/coder/coder/v2/coderd/database"
20+
1821
"cdr.dev/slog"
1922
)
2023

@@ -432,9 +435,35 @@ func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
432435
// pq.defaultDialer uses a zero net.Dialer as well.
433436
d: net.Dialer{},
434437
}
438+
connector driver.Connector
439+
err error
435440
)
441+
442+
// Create a custom connector if the database driver supports it.
443+
connectorCreator, ok := p.db.Driver().(database.ConnectorCreator)
444+
if ok {
445+
connector, err = connectorCreator.Connector(connectURL)
446+
if err != nil {
447+
return xerrors.Errorf("create custom connector: %w", err)
448+
}
449+
} else {
450+
// use the default pq connector otherwise
451+
connector, err = pq.NewConnector(connectURL)
452+
if err != nil {
453+
return xerrors.Errorf("create pq connector: %w", err)
454+
}
455+
}
456+
457+
// Set the dialer if the connector supports it.
458+
dc, ok := connector.(database.DialerConnector)
459+
if !ok {
460+
p.logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing")
461+
} else {
462+
dc.Dialer(dialer)
463+
}
464+
436465
p.pgListener = pqListenerShim{
437-
Listener: pq.NewDialListener(dialer, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) {
466+
Listener: pq.NewConnectorListener(connector, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) {
438467
switch t {
439468
case pq.ListenerEventConnected:
440469
p.logger.Info(ctx, "pubsub connected to postgres")
@@ -583,8 +612,8 @@ func (p *PGPubsub) Collect(metrics chan<- prometheus.Metric) {
583612
}
584613

585614
// New creates a new Pubsub implementation using a PostgreSQL connection.
586-
func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connectURL string) (*PGPubsub, error) {
587-
p := newWithoutListener(logger, database)
615+
func New(startCtx context.Context, logger slog.Logger, db *sql.DB, connectURL string) (*PGPubsub, error) {
616+
p := newWithoutListener(logger, db)
588617
if err := p.startListener(startCtx, connectURL); err != nil {
589618
return nil, err
590619
}
@@ -594,11 +623,11 @@ func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connect
594623
}
595624

596625
// newWithoutListener creates a new PGPubsub without creating the pqListener.
597-
func newWithoutListener(logger slog.Logger, database *sql.DB) *PGPubsub {
626+
func newWithoutListener(logger slog.Logger, db *sql.DB) *PGPubsub {
598627
return &PGPubsub{
599628
logger: logger,
600629
listenDone: make(chan struct{}),
601-
db: database,
630+
db: db,
602631
queues: make(map[string]map[uuid.UUID]*msgQueue),
603632
latencyMeasurer: NewLatencyMeasurer(logger.Named("latency-measurer")),
604633

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy