diff --git a/cli/resetpassword.go b/cli/resetpassword.go index 2aacc8a6e6c44..f77ed81d14db4 100644 --- a/cli/resetpassword.go +++ b/cli/resetpassword.go @@ -3,22 +3,27 @@ package cli import ( - "database/sql" "fmt" "golang.org/x/xerrors" + "cdr.dev/slog" + "cdr.dev/slog/sloggers/sloghuman" + "github.com/coder/coder/v2/coderd/database/awsiamrds" + "github.com/coder/coder/v2/codersdk" "github.com/coder/pretty" "github.com/coder/serpent" "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/migrations" "github.com/coder/coder/v2/coderd/userpassword" ) func (*RootCmd) resetPassword() *serpent.Command { - var postgresURL string + var ( + postgresURL string + postgresAuth string + ) root := &serpent.Command{ Use: "reset-password ", @@ -27,20 +32,26 @@ func (*RootCmd) resetPassword() *serpent.Command { Handler: func(inv *serpent.Invocation) error { username := inv.Args[0] - sqlDB, err := sql.Open("postgres", postgresURL) - if err != nil { - return xerrors.Errorf("dial postgres: %w", err) + logger := slog.Make(sloghuman.Sink(inv.Stdout)) + if ok, _ := inv.ParsedFlags().GetBool("verbose"); ok { + logger = logger.Leveled(slog.LevelDebug) } - defer sqlDB.Close() - err = sqlDB.Ping() - if err != nil { - return xerrors.Errorf("ping postgres: %w", err) + + sqlDriver := "postgres" + if codersdk.PostgresAuth(postgresAuth) == codersdk.PostgresAuthAWSIAMRDS { + var err error + sqlDriver, err = awsiamrds.Register(inv.Context(), sqlDriver) + if err != nil { + return xerrors.Errorf("register aws rds iam auth: %w", err) + } } - err = migrations.EnsureClean(sqlDB) + sqlDB, err := ConnectToPostgres(inv.Context(), logger, sqlDriver, postgresURL, nil) if err != nil { - return xerrors.Errorf("database needs migration: %w", err) + return xerrors.Errorf("dial postgres: %w", err) } + defer sqlDB.Close() + db := database.New(sqlDB) user, err := db.GetUserByEmailOrUsername(inv.Context(), database.GetUserByEmailOrUsernameParams{ @@ -97,6 +108,14 @@ func (*RootCmd) resetPassword() *serpent.Command { Env: "CODER_PG_CONNECTION_URL", Value: serpent.StringOf(&postgresURL), }, + serpent.Option{ + Name: "Postgres Connection Auth", + Description: "Type of auth to use when connecting to postgres.", + Flag: "postgres-connection-auth", + Env: "CODER_PG_CONNECTION_AUTH", + Default: "password", + Value: serpent.EnumOf(&postgresAuth, codersdk.PostgresAuthDrivers...), + }, } return root diff --git a/cli/server.go b/cli/server.go index ff8b2963e0eb4..9bb4cfb0a72f2 100644 --- a/cli/server.go +++ b/cli/server.go @@ -697,7 +697,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. options.Database = dbmem.New() options.Pubsub = pubsub.NewInMemory() } else { - sqlDB, dbURL, err := getPostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver) + sqlDB, dbURL, err := getAndMigratePostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver) if err != nil { return xerrors.Errorf("connect to postgres: %w", err) } @@ -2090,9 +2090,18 @@ func IsLocalhost(host string) bool { return host == "localhost" || host == "127.0.0.1" || host == "::1" } -func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string) (sqlDB *sql.DB, err error) { +// ConnectToPostgres takes in the migration command to run on the database once +// it connects. To avoid running migrations, pass in `nil` or a no-op function. +// Regardless of the passed in migration function, if the database is not fully +// migrated, an error will be returned. This can happen if the database is on a +// future or past migration version. +// +// If no error is returned, the database is fully migrated and up to date. +func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string, migrate func(db *sql.DB) error) (*sql.DB, error) { logger.Debug(ctx, "connecting to postgresql") + var err error + var sqlDB *sql.DB // Try to connect for 30 seconds. ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -2155,9 +2164,16 @@ func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, d } logger.Debug(ctx, "connected to postgresql", slog.F("version", versionNum)) - err = migrations.Up(sqlDB) + if migrate != nil { + err = migrate(sqlDB) + if err != nil { + return nil, xerrors.Errorf("migrate up: %w", err) + } + } + + err = migrations.EnsureClean(sqlDB) if err != nil { - return nil, xerrors.Errorf("migrate up: %w", err) + return nil, xerrors.Errorf("migrations in database: %w", err) } // The default is 0 but the request will fail with a 500 if the DB // cannot accept new connections, so we try to limit that here. @@ -2561,7 +2577,7 @@ func signalNotifyContext(ctx context.Context, inv *serpent.Invocation, sig ...os return inv.SignalNotifyContext(ctx, sig...) } -func getPostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string) (*sql.DB, string, error) { +func getAndMigratePostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string) (*sql.DB, string, error) { dbURL, err := escapePostgresURLUserInfo(postgresURL) if err != nil { return nil, "", xerrors.Errorf("escaping postgres URL: %w", err) @@ -2574,7 +2590,7 @@ func getPostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, } } - sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL) + sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL, migrations.Up) if err != nil { return nil, "", xerrors.Errorf("connect to postgres: %w", err) } diff --git a/cli/server_createadminuser.go b/cli/server_createadminuser.go index 7ef95e7e093e6..ed9c7b9bcc921 100644 --- a/cli/server_createadminuser.go +++ b/cli/server_createadminuser.go @@ -72,7 +72,7 @@ func (r *RootCmd) newCreateAdminUserCommand() *serpent.Command { } } - sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, newUserDBURL) + sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, newUserDBURL, nil) if err != nil { return xerrors.Errorf("connect to postgres: %w", err) } diff --git a/cli/server_test.go b/cli/server_test.go index 9ba963d484548..0dba63e7c2fe3 100644 --- a/cli/server_test.go +++ b/cli/server_test.go @@ -38,11 +38,13 @@ import ( "tailscale.com/derp/derphttp" "tailscale.com/types/key" + "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/cli" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/config" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/migrations" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/telemetry" "github.com/coder/coder/v2/codersdk" @@ -1828,20 +1830,51 @@ func TestConnectToPostgres(t *testing.T) { if !dbtestutil.WillUsePostgres() { t.Skip("this test does not make sense without postgres") } - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - t.Cleanup(cancel) - log := testutil.Logger(t) + t.Run("Migrate", func(t *testing.T) { + t.Parallel() - dbURL, err := dbtestutil.Open(t) - require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + t.Cleanup(cancel) - sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL) - require.NoError(t, err) - t.Cleanup(func() { - _ = sqlDB.Close() + log := testutil.Logger(t) + + dbURL, err := dbtestutil.Open(t) + require.NoError(t, err) + + sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, migrations.Up) + require.NoError(t, err) + t.Cleanup(func() { + _ = sqlDB.Close() + }) + require.NoError(t, sqlDB.PingContext(ctx)) + }) + + t.Run("NoMigrate", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + t.Cleanup(cancel) + + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + + dbURL, err := dbtestutil.Open(t) + require.NoError(t, err) + + okDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil) + require.NoError(t, err) + defer okDB.Close() + + // Set the migration number forward + _, err = okDB.Exec(`UPDATE schema_migrations SET version = version + 1`) + require.NoError(t, err) + + _, err = cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil) + require.Error(t, err) + require.ErrorContains(t, err, "database needs migration") + + require.NoError(t, okDB.PingContext(ctx)) }) - require.NoError(t, sqlDB.PingContext(ctx)) } func TestServer_InvalidDERP(t *testing.T) { diff --git a/cli/testdata/coder_reset-password_--help.golden b/cli/testdata/coder_reset-password_--help.golden index a7d53df12ad90..ccefb412d8fb7 100644 --- a/cli/testdata/coder_reset-password_--help.golden +++ b/cli/testdata/coder_reset-password_--help.golden @@ -6,6 +6,9 @@ USAGE: Directly connect to the database to reset a user's password OPTIONS: + --postgres-connection-auth password|awsiamrds, $CODER_PG_CONNECTION_AUTH (default: password) + Type of auth to use when connecting to postgres. + --postgres-url string, $CODER_PG_CONNECTION_URL URL of a PostgreSQL database to connect to. diff --git a/coderd/database/awsiamrds/awsiamrds_test.go b/coderd/database/awsiamrds/awsiamrds_test.go index 844b85b119850..d52da4aab7bfe 100644 --- a/coderd/database/awsiamrds/awsiamrds_test.go +++ b/coderd/database/awsiamrds/awsiamrds_test.go @@ -9,6 +9,7 @@ import ( "github.com/coder/coder/v2/cli" "github.com/coder/coder/v2/coderd/database/awsiamrds" + "github.com/coder/coder/v2/coderd/database/migrations" "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/testutil" ) @@ -32,7 +33,7 @@ func TestDriver(t *testing.T) { sqlDriver, err := awsiamrds.Register(ctx, "postgres") require.NoError(t, err) - db, err := cli.ConnectToPostgres(ctx, testutil.Logger(t), sqlDriver, url) + db, err := cli.ConnectToPostgres(ctx, testutil.Logger(t), sqlDriver, url, migrations.Up) require.NoError(t, err) defer func() { _ = db.Close() diff --git a/docs/reference/cli/reset-password.md b/docs/reference/cli/reset-password.md index 75e94821cdb31..ada9ad7e7db3e 100644 --- a/docs/reference/cli/reset-password.md +++ b/docs/reference/cli/reset-password.md @@ -19,3 +19,13 @@ coder reset-password [flags] | Environment | $CODER_PG_CONNECTION_URL | URL of a PostgreSQL database to connect to. + +### --postgres-connection-auth + +| | | +|-------------|----------------------------------------| +| Type | password\|awsiamrds | +| Environment | $CODER_PG_CONNECTION_AUTH | +| Default | password | + +Type of auth to use when connecting to postgres. diff --git a/enterprise/cli/server_dbcrypt.go b/enterprise/cli/server_dbcrypt.go index 148303f85402d..72ac6cc6e82b0 100644 --- a/enterprise/cli/server_dbcrypt.go +++ b/enterprise/cli/server_dbcrypt.go @@ -98,7 +98,7 @@ func (*RootCmd) dbcryptRotateCmd() *serpent.Command { } } - sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL) + sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil) if err != nil { return xerrors.Errorf("connect to postgres: %w", err) } @@ -163,7 +163,7 @@ func (*RootCmd) dbcryptDecryptCmd() *serpent.Command { } } - sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL) + sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil) if err != nil { return xerrors.Errorf("connect to postgres: %w", err) } @@ -219,7 +219,7 @@ Are you sure you want to continue?` } } - sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL) + sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil) if err != nil { return xerrors.Errorf("connect to postgres: %w", err) } pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy