Skip to content

Commit a7ed977

Browse files
authored
chore: prevent db migrations from running on all cli commands (#15980)
1 parent 813270d commit a7ed977

File tree

8 files changed

+115
-33
lines changed

8 files changed

+115
-33
lines changed

cli/resetpassword.go

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,27 @@
33
package cli
44

55
import (
6-
"database/sql"
76
"fmt"
87

98
"golang.org/x/xerrors"
109

10+
"cdr.dev/slog"
11+
"cdr.dev/slog/sloggers/sloghuman"
12+
"github.com/coder/coder/v2/coderd/database/awsiamrds"
13+
"github.com/coder/coder/v2/codersdk"
1114
"github.com/coder/pretty"
1215
"github.com/coder/serpent"
1316

1417
"github.com/coder/coder/v2/cli/cliui"
1518
"github.com/coder/coder/v2/coderd/database"
16-
"github.com/coder/coder/v2/coderd/database/migrations"
1719
"github.com/coder/coder/v2/coderd/userpassword"
1820
)
1921

2022
func (*RootCmd) resetPassword() *serpent.Command {
21-
var postgresURL string
23+
var (
24+
postgresURL string
25+
postgresAuth string
26+
)
2227

2328
root := &serpent.Command{
2429
Use: "reset-password <username>",
@@ -27,20 +32,26 @@ func (*RootCmd) resetPassword() *serpent.Command {
2732
Handler: func(inv *serpent.Invocation) error {
2833
username := inv.Args[0]
2934

30-
sqlDB, err := sql.Open("postgres", postgresURL)
31-
if err != nil {
32-
return xerrors.Errorf("dial postgres: %w", err)
35+
logger := slog.Make(sloghuman.Sink(inv.Stdout))
36+
if ok, _ := inv.ParsedFlags().GetBool("verbose"); ok {
37+
logger = logger.Leveled(slog.LevelDebug)
3338
}
34-
defer sqlDB.Close()
35-
err = sqlDB.Ping()
36-
if err != nil {
37-
return xerrors.Errorf("ping postgres: %w", err)
39+
40+
sqlDriver := "postgres"
41+
if codersdk.PostgresAuth(postgresAuth) == codersdk.PostgresAuthAWSIAMRDS {
42+
var err error
43+
sqlDriver, err = awsiamrds.Register(inv.Context(), sqlDriver)
44+
if err != nil {
45+
return xerrors.Errorf("register aws rds iam auth: %w", err)
46+
}
3847
}
3948

40-
err = migrations.EnsureClean(sqlDB)
49+
sqlDB, err := ConnectToPostgres(inv.Context(), logger, sqlDriver, postgresURL, nil)
4150
if err != nil {
42-
return xerrors.Errorf("database needs migration: %w", err)
51+
return xerrors.Errorf("dial postgres: %w", err)
4352
}
53+
defer sqlDB.Close()
54+
4455
db := database.New(sqlDB)
4556

4657
user, err := db.GetUserByEmailOrUsername(inv.Context(), database.GetUserByEmailOrUsernameParams{
@@ -97,6 +108,14 @@ func (*RootCmd) resetPassword() *serpent.Command {
97108
Env: "CODER_PG_CONNECTION_URL",
98109
Value: serpent.StringOf(&postgresURL),
99110
},
111+
serpent.Option{
112+
Name: "Postgres Connection Auth",
113+
Description: "Type of auth to use when connecting to postgres.",
114+
Flag: "postgres-connection-auth",
115+
Env: "CODER_PG_CONNECTION_AUTH",
116+
Default: "password",
117+
Value: serpent.EnumOf(&postgresAuth, codersdk.PostgresAuthDrivers...),
118+
},
100119
}
101120

102121
return root

cli/server.go

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
697697
options.Database = dbmem.New()
698698
options.Pubsub = pubsub.NewInMemory()
699699
} else {
700-
sqlDB, dbURL, err := getPostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver)
700+
sqlDB, dbURL, err := getAndMigratePostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver)
701701
if err != nil {
702702
return xerrors.Errorf("connect to postgres: %w", err)
703703
}
@@ -2090,9 +2090,18 @@ func IsLocalhost(host string) bool {
20902090
return host == "localhost" || host == "127.0.0.1" || host == "::1"
20912091
}
20922092

2093-
func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string) (sqlDB *sql.DB, err error) {
2093+
// ConnectToPostgres takes in the migration command to run on the database once
2094+
// it connects. To avoid running migrations, pass in `nil` or a no-op function.
2095+
// Regardless of the passed in migration function, if the database is not fully
2096+
// migrated, an error will be returned. This can happen if the database is on a
2097+
// future or past migration version.
2098+
//
2099+
// If no error is returned, the database is fully migrated and up to date.
2100+
func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string, migrate func(db *sql.DB) error) (*sql.DB, error) {
20942101
logger.Debug(ctx, "connecting to postgresql")
20952102

2103+
var err error
2104+
var sqlDB *sql.DB
20962105
// Try to connect for 30 seconds.
20972106
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
20982107
defer cancel()
@@ -2155,9 +2164,16 @@ func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, d
21552164
}
21562165
logger.Debug(ctx, "connected to postgresql", slog.F("version", versionNum))
21572166

2158-
err = migrations.Up(sqlDB)
2167+
if migrate != nil {
2168+
err = migrate(sqlDB)
2169+
if err != nil {
2170+
return nil, xerrors.Errorf("migrate up: %w", err)
2171+
}
2172+
}
2173+
2174+
err = migrations.EnsureClean(sqlDB)
21592175
if err != nil {
2160-
return nil, xerrors.Errorf("migrate up: %w", err)
2176+
return nil, xerrors.Errorf("migrations in database: %w", err)
21612177
}
21622178
// The default is 0 but the request will fail with a 500 if the DB
21632179
// cannot accept new connections, so we try to limit that here.
@@ -2561,7 +2577,7 @@ func signalNotifyContext(ctx context.Context, inv *serpent.Invocation, sig ...os
25612577
return inv.SignalNotifyContext(ctx, sig...)
25622578
}
25632579

2564-
func getPostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string) (*sql.DB, string, error) {
2580+
func getAndMigratePostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string) (*sql.DB, string, error) {
25652581
dbURL, err := escapePostgresURLUserInfo(postgresURL)
25662582
if err != nil {
25672583
return nil, "", xerrors.Errorf("escaping postgres URL: %w", err)
@@ -2574,7 +2590,7 @@ func getPostgresDB(ctx context.Context, logger slog.Logger, postgresURL string,
25742590
}
25752591
}
25762592

2577-
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL)
2593+
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL, migrations.Up)
25782594
if err != nil {
25792595
return nil, "", xerrors.Errorf("connect to postgres: %w", err)
25802596
}

cli/server_createadminuser.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func (r *RootCmd) newCreateAdminUserCommand() *serpent.Command {
7272
}
7373
}
7474

75-
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, newUserDBURL)
75+
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, newUserDBURL, nil)
7676
if err != nil {
7777
return xerrors.Errorf("connect to postgres: %w", err)
7878
}

cli/server_test.go

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@ import (
3838
"tailscale.com/derp/derphttp"
3939
"tailscale.com/types/key"
4040

41+
"cdr.dev/slog/sloggers/slogtest"
4142
"github.com/coder/coder/v2/cli"
4243
"github.com/coder/coder/v2/cli/clitest"
4344
"github.com/coder/coder/v2/cli/config"
4445
"github.com/coder/coder/v2/coderd/coderdtest"
4546
"github.com/coder/coder/v2/coderd/database/dbtestutil"
47+
"github.com/coder/coder/v2/coderd/database/migrations"
4648
"github.com/coder/coder/v2/coderd/httpapi"
4749
"github.com/coder/coder/v2/coderd/telemetry"
4850
"github.com/coder/coder/v2/codersdk"
@@ -1828,20 +1830,51 @@ func TestConnectToPostgres(t *testing.T) {
18281830
if !dbtestutil.WillUsePostgres() {
18291831
t.Skip("this test does not make sense without postgres")
18301832
}
1831-
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
1832-
t.Cleanup(cancel)
18331833

1834-
log := testutil.Logger(t)
1834+
t.Run("Migrate", func(t *testing.T) {
1835+
t.Parallel()
18351836

1836-
dbURL, err := dbtestutil.Open(t)
1837-
require.NoError(t, err)
1837+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
1838+
t.Cleanup(cancel)
18381839

1839-
sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL)
1840-
require.NoError(t, err)
1841-
t.Cleanup(func() {
1842-
_ = sqlDB.Close()
1840+
log := testutil.Logger(t)
1841+
1842+
dbURL, err := dbtestutil.Open(t)
1843+
require.NoError(t, err)
1844+
1845+
sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, migrations.Up)
1846+
require.NoError(t, err)
1847+
t.Cleanup(func() {
1848+
_ = sqlDB.Close()
1849+
})
1850+
require.NoError(t, sqlDB.PingContext(ctx))
1851+
})
1852+
1853+
t.Run("NoMigrate", func(t *testing.T) {
1854+
t.Parallel()
1855+
1856+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
1857+
t.Cleanup(cancel)
1858+
1859+
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
1860+
1861+
dbURL, err := dbtestutil.Open(t)
1862+
require.NoError(t, err)
1863+
1864+
okDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil)
1865+
require.NoError(t, err)
1866+
defer okDB.Close()
1867+
1868+
// Set the migration number forward
1869+
_, err = okDB.Exec(`UPDATE schema_migrations SET version = version + 1`)
1870+
require.NoError(t, err)
1871+
1872+
_, err = cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil)
1873+
require.Error(t, err)
1874+
require.ErrorContains(t, err, "database needs migration")
1875+
1876+
require.NoError(t, okDB.PingContext(ctx))
18431877
})
1844-
require.NoError(t, sqlDB.PingContext(ctx))
18451878
}
18461879

18471880
func TestServer_InvalidDERP(t *testing.T) {

cli/testdata/coder_reset-password_--help.golden

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ USAGE:
66
Directly connect to the database to reset a user's password
77

88
OPTIONS:
9+
--postgres-connection-auth password|awsiamrds, $CODER_PG_CONNECTION_AUTH (default: password)
10+
Type of auth to use when connecting to postgres.
11+
912
--postgres-url string, $CODER_PG_CONNECTION_URL
1013
URL of a PostgreSQL database to connect to.
1114

coderd/database/awsiamrds/awsiamrds_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"github.com/coder/coder/v2/cli"
1111
"github.com/coder/coder/v2/coderd/database/awsiamrds"
12+
"github.com/coder/coder/v2/coderd/database/migrations"
1213
"github.com/coder/coder/v2/coderd/database/pubsub"
1314
"github.com/coder/coder/v2/testutil"
1415
)
@@ -32,7 +33,7 @@ func TestDriver(t *testing.T) {
3233
sqlDriver, err := awsiamrds.Register(ctx, "postgres")
3334
require.NoError(t, err)
3435

35-
db, err := cli.ConnectToPostgres(ctx, testutil.Logger(t), sqlDriver, url)
36+
db, err := cli.ConnectToPostgres(ctx, testutil.Logger(t), sqlDriver, url, migrations.Up)
3637
require.NoError(t, err)
3738
defer func() {
3839
_ = db.Close()

docs/reference/cli/reset-password.md

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

enterprise/cli/server_dbcrypt.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ func (*RootCmd) dbcryptRotateCmd() *serpent.Command {
9898
}
9999
}
100100

101-
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL)
101+
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
102102
if err != nil {
103103
return xerrors.Errorf("connect to postgres: %w", err)
104104
}
@@ -163,7 +163,7 @@ func (*RootCmd) dbcryptDecryptCmd() *serpent.Command {
163163
}
164164
}
165165

166-
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL)
166+
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
167167
if err != nil {
168168
return xerrors.Errorf("connect to postgres: %w", err)
169169
}
@@ -219,7 +219,7 @@ Are you sure you want to continue?`
219219
}
220220
}
221221

222-
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL)
222+
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
223223
if err != nil {
224224
return xerrors.Errorf("connect to postgres: %w", err)
225225
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy