diff --git a/.vscode/settings.json b/.vscode/settings.json index 8b92ff2228df0..9771a27a0de3e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -19,6 +19,7 @@ "derphttp", "derpmap", "devel", + "dflags", "drpc", "drpcconn", "drpcmux", @@ -86,8 +87,10 @@ "ptytest", "quickstart", "reconfig", + "replicasync", "retrier", "rpty", + "SCIM", "sdkproto", "sdktrace", "Signup", diff --git a/agent/agent.go b/agent/agent.go index 6d0a9a952f44b..f7c5598b7b710 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -170,6 +170,7 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) { if a.isClosed() { return } + a.logger.Debug(ctx, "running tailnet with derpmap", slog.F("derpmap", derpMap)) if a.network != nil { a.network.SetDERPMap(derpMap) return diff --git a/agent/agent_test.go b/agent/agent_test.go index 06a33598b755f..e10eee7f111a0 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -465,7 +465,7 @@ func TestAgent(t *testing.T) { conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0) require.Eventually(t, func() bool { - _, err := conn.Ping() + _, err := conn.Ping(context.Background()) return err == nil }, testutil.WaitMedium, testutil.IntervalFast) conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String()) @@ -483,9 +483,7 @@ func TestAgent(t *testing.T) { t.Run("Speedtest", func(t *testing.T) { t.Parallel() - if testing.Short() { - t.Skip("The minimum duration for a speedtest is hardcoded in Tailscale to 5s!") - } + t.Skip("This test is relatively flaky because of Tailscale's speedtest code...") derpMap := tailnettest.RunDERPAndSTUN(t) conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{ DERPMap: derpMap, diff --git a/cli/agent_test.go b/cli/agent_test.go index dd0cb1d789349..f487ebfc005ed 100644 --- a/cli/agent_test.go +++ b/cli/agent_test.go @@ -7,8 +7,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "cdr.dev/slog" - "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/provisioner/echo" @@ -67,11 +65,11 @@ func TestWorkspaceAgent(t *testing.T) { if assert.NotEmpty(t, workspace.LatestBuild.Resources) && assert.NotEmpty(t, resources[0].Agents) { assert.NotEmpty(t, resources[0].Agents[0].Version) } - dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID) + dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) require.NoError(t, err) defer dialer.Close() require.Eventually(t, func() bool { - _, err := dialer.Ping() + _, err := dialer.Ping(ctx) return err == nil }, testutil.WaitMedium, testutil.IntervalFast) cancelFunc() @@ -128,11 +126,11 @@ func TestWorkspaceAgent(t *testing.T) { if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) { assert.NotEmpty(t, resources[0].Agents[0].Version) } - dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID) + dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) require.NoError(t, err) defer dialer.Close() require.Eventually(t, func() bool { - _, err := dialer.Ping() + _, err := dialer.Ping(ctx) return err == nil }, testutil.WaitMedium, testutil.IntervalFast) cancelFunc() @@ -189,11 +187,11 @@ func TestWorkspaceAgent(t *testing.T) { if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) { assert.NotEmpty(t, resources[0].Agents[0].Version) } - dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{},
resources[0].Agents[0].ID) + dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) require.NoError(t, err) defer dialer.Close() require.Eventually(t, func() bool { - _, err := dialer.Ping() + _, err := dialer.Ping(ctx) return err == nil }, testutil.WaitMedium, testutil.IntervalFast) cancelFunc() diff --git a/cli/config/file.go b/cli/config/file.go index a98237afed22b..388ce0881f304 100644 --- a/cli/config/file.go +++ b/cli/config/file.go @@ -13,6 +13,11 @@ func (r Root) Session() File { return File(filepath.Join(string(r), "session")) } +// ReplicaID is a unique identifier for the Coder server. +func (r Root) ReplicaID() File { + return File(filepath.Join(string(r), "replica_id")) +} + func (r Root) URL() File { return File(filepath.Join(string(r), "url")) } diff --git a/cli/configssh_test.go b/cli/configssh_test.go index 3e1512a0c3471..4553cbe431221 100644 --- a/cli/configssh_test.go +++ b/cli/configssh_test.go @@ -19,7 +19,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" @@ -115,7 +114,7 @@ func TestConfigSSH(t *testing.T) { _ = agentCloser.Close() }() resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) - agentConn, err := client.DialWorkspaceAgentTailnet(context.Background(), slog.Logger{}, resources[0].Agents[0].ID) + agentConn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil) require.NoError(t, err) defer agentConn.Close() diff --git a/cli/deployment/flags.go b/cli/deployment/flags.go index df18e950270f2..714365cc8e897 100644 --- a/cli/deployment/flags.go +++ b/cli/deployment/flags.go @@ -85,6 +85,13 @@ func Flags() *codersdk.DeploymentFlags { Description: "Addresses for STUN servers to establish P2P connections. Set empty to disable P2P connections.", Default: []string{"stun.l.google.com:19302"}, }, + DerpServerRelayAddress: &codersdk.StringFlag{ + Name: "DERP Server Relay Address", + Flag: "derp-server-relay-address", + EnvVar: "CODER_DERP_SERVER_RELAY_ADDRESS", + Description: "An HTTP address that is accessible by other replicas to relay DERP traffic. 
Required for high availability.", + Enterprise: true, + }, DerpConfigURL: &codersdk.StringFlag{ Name: "DERP Config URL", Flag: "derp-config-url", diff --git a/cli/portforward.go b/cli/portforward.go index 476809d601558..5a6f4391dd897 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -16,7 +16,6 @@ import ( "github.com/spf13/cobra" "golang.org/x/xerrors" - "cdr.dev/slog" "github.com/coder/coder/agent" "github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/cliui" @@ -96,7 +95,7 @@ func portForward() *cobra.Command { return xerrors.Errorf("await agent: %w", err) } - conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID) + conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil) if err != nil { return err } @@ -156,7 +155,7 @@ func portForward() *cobra.Command { case <-ticker.C: } - _, err = conn.Ping() + _, err = conn.Ping(ctx) if err != nil { continue } diff --git a/cli/root.go b/cli/root.go index e7104e64284eb..91d4551916cc0 100644 --- a/cli/root.go +++ b/cli/root.go @@ -4,6 +4,7 @@ import ( "context" "flag" "fmt" + "io" "net/http" "net/url" "os" @@ -100,8 +101,9 @@ func Core() []*cobra.Command { } func AGPL() []*cobra.Command { - all := append(Core(), Server(deployment.Flags(), func(_ context.Context, o *coderd.Options) (*coderd.API, error) { - return coderd.New(o), nil + all := append(Core(), Server(deployment.Flags(), func(_ context.Context, o *coderd.Options) (*coderd.API, io.Closer, error) { + api := coderd.New(o) + return api, api, nil })) return all } diff --git a/cli/server.go b/cli/server.go index 9d828abbea606..c2dbeac07e8ab 100644 --- a/cli/server.go +++ b/cli/server.go @@ -69,7 +69,7 @@ import ( ) // nolint:gocyclo -func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *coderd.Options) (*coderd.API, error)) *cobra.Command { +func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *coderd.Options) (*coderd.API, io.Closer, error)) *cobra.Command { root := &cobra.Command{ Use: "server", Short: "Start a Coder server", @@ -167,9 +167,10 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code } defer listener.Close() + var tlsConfig *tls.Config if dflags.TLSEnable.Value { - listener, err = configureServerTLS( - listener, dflags.TLSMinVersion.Value, + tlsConfig, err = configureTLS( + dflags.TLSMinVersion.Value, dflags.TLSClientAuth.Value, dflags.TLSCertFiles.Value, dflags.TLSKeyFiles.Value, @@ -178,6 +179,7 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code if err != nil { return xerrors.Errorf("configure tls: %w", err) } + listener = tls.NewListener(listener, tlsConfig) } tcpAddr, valid := listener.Addr().(*net.TCPAddr) @@ -328,6 +330,9 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code Experimental: ExperimentalEnabled(cmd), DeploymentFlags: dflags, } + if tlsConfig != nil { + options.TLSCertificates = tlsConfig.Certificates + } if dflags.OAuth2GithubClientSecret.Value != "" { options.GithubOAuth2Config, err = configureGithubOAuth2(accessURLParsed, @@ -471,11 +476,14 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code ), dflags.PromAddress.Value, "prometheus")() } - coderAPI, err := newAPI(ctx, options) + // We use a separate closer so the Enterprise API + // can have its own close functions. This is cleaner + // than abstracting the Coder API itself.
+ coderAPI, closer, err := newAPI(ctx, options) if err != nil { return err } - defer coderAPI.Close() + defer closer.Close() client := codersdk.New(localURL) if dflags.TLSEnable.Value { @@ -893,7 +901,7 @@ func loadCertificates(tlsCertFiles, tlsKeyFiles []string) ([]tls.Certificate, er return certs, nil } -func configureServerTLS(listener net.Listener, tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string) (net.Listener, error) { +func configureTLS(tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string) (*tls.Config, error) { tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, } @@ -929,6 +937,7 @@ func configureServerTLS(listener net.Listener, tlsMinVersion, tlsClientAuth stri if err != nil { return nil, xerrors.Errorf("load certificates: %w", err) } + tlsConfig.Certificates = certs tlsConfig.GetCertificate = func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { // If there's only one certificate, return it. if len(certs) == 1 { @@ -963,7 +972,7 @@ func configureServerTLS(listener net.Listener, tlsMinVersion, tlsClientAuth stri tlsConfig.ClientCAs = caPool } - return tls.NewListener(listener, tlsConfig), nil + return tlsConfig, nil } func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, allowSignups bool, allowOrgs []string, rawTeams []string, enterpriseBaseURL string) (*coderd.GithubOAuth2Config, error) { diff --git a/cli/speedtest.go b/cli/speedtest.go index 357048f63ea34..f6c06641ec26f 100644 --- a/cli/speedtest.go +++ b/cli/speedtest.go @@ -55,7 +55,9 @@ func speedtest() *cobra.Command { if cliflag.IsSetBool(cmd, varVerbose) { logger = logger.Leveled(slog.LevelDebug) } - conn, err := client.DialWorkspaceAgentTailnet(ctx, logger, workspaceAgent.ID) + conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, &codersdk.DialWorkspaceAgentOptions{ + Logger: logger, + }) if err != nil { return err } @@ -68,7 +70,7 @@ func speedtest() *cobra.Command { return ctx.Err() case <-ticker.C: } - dur, err := conn.Ping() + dur, err := conn.Ping(ctx) if err != nil { continue } diff --git a/cli/ssh.go b/cli/ssh.go index ef8538764e3ac..b4d4f6420da78 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -20,8 +20,6 @@ import ( "golang.org/x/term" "golang.org/x/xerrors" - "cdr.dev/slog" - "github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/autobuild/notify" @@ -86,7 +84,7 @@ func ssh() *cobra.Command { return xerrors.Errorf("await agent: %w", err) } - conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID) + conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil) if err != nil { return err } diff --git a/coderd/activitybump_test.go b/coderd/activitybump_test.go index cd43b774d5dea..e498b98fa0c80 100644 --- a/coderd/activitybump_test.go +++ b/coderd/activitybump_test.go @@ -72,7 +72,7 @@ func TestWorkspaceActivityBump(t *testing.T) { "deadline %v never updated", firstDeadline, ) - require.WithinDuration(t, database.Now().Add(time.Hour), workspace.LatestBuild.Deadline.Time, time.Second) + require.WithinDuration(t, database.Now().Add(time.Hour), workspace.LatestBuild.Deadline.Time, 3*time.Second) } } @@ -82,7 +82,9 @@ func TestWorkspaceActivityBump(t *testing.T) { client, workspace, assertBumped := setupActivityTest(t) resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) - conn, err := client.DialWorkspaceAgentTailnet(ctx, slogtest.Make(t, nil), resources[0].Agents[0].ID) + conn, err := 
client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{ + Logger: slogtest.Make(t, nil), + }) require.NoError(t, err) defer conn.Close() diff --git a/coderd/coderd.go b/coderd/coderd.go index 992ae6c7f5ca5..cf8a20d3734cd 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1,6 +1,7 @@ package coderd import ( + "crypto/tls" "crypto/x509" "io" "net/http" @@ -82,7 +83,10 @@ type Options struct { TracerProvider trace.TracerProvider AutoImportTemplates []AutoImportTemplate - TailnetCoordinator *tailnet.Coordinator + // TLSCertificates is used to mesh DERP servers securely. + TLSCertificates []tls.Certificate + TailnetCoordinator tailnet.Coordinator + DERPServer *derp.Server DERPMap *tailcfg.DERPMap MetricsCacheRefreshInterval time.Duration @@ -130,6 +134,9 @@ func New(options *Options) *API { if options.TailnetCoordinator == nil { options.TailnetCoordinator = tailnet.NewCoordinator() } + if options.DERPServer == nil { + options.DERPServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp"))) + } if options.Auditor == nil { options.Auditor = audit.NewNop() } @@ -168,7 +175,7 @@ func New(options *Options) *API { api.Auditor.Store(&options.Auditor) api.WorkspaceQuotaEnforcer.Store(&options.WorkspaceQuotaEnforcer) api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0) - api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger)) + api.TailnetCoordinator.Store(&options.TailnetCoordinator) oauthConfigs := &httpmw.OAuth2Configs{ Github: options.GithubOAuth2Config, OIDC: options.OIDCConfig, @@ -246,7 +253,7 @@ func New(options *Options) *API { r.Route("/%40{user}/{workspace_and_agent}/apps/{workspaceapp}", apps) r.Route("/@{user}/{workspace_and_agent}/apps/{workspaceapp}", apps) r.Route("/derp", func(r chi.Router) { - r.Get("/", derphttp.Handler(api.derpServer).ServeHTTP) + r.Get("/", derphttp.Handler(api.DERPServer).ServeHTTP) // This is used when UDP is blocked, and latency must be checked via HTTP(s). 
r.Get("/latency-check", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -550,6 +557,7 @@ type API struct { Auditor atomic.Pointer[audit.Auditor] WorkspaceClientCoordinateOverride atomic.Pointer[func(rw http.ResponseWriter) bool] WorkspaceQuotaEnforcer atomic.Pointer[workspacequota.Enforcer] + TailnetCoordinator atomic.Pointer[tailnet.Coordinator] HTTPAuth *HTTPAuthorizer // APIHandler serves "/api/v2" @@ -557,7 +565,6 @@ type API struct { // RootHandler serves "/" RootHandler chi.Router - derpServer *derp.Server metricsCache *metricscache.Cache siteHandler http.Handler websocketWaitMutex sync.Mutex @@ -572,7 +579,10 @@ func (api *API) Close() error { api.websocketWaitMutex.Unlock() api.metricsCache.Close() - + coordinator := api.TailnetCoordinator.Load() + if coordinator != nil { + _ = (*coordinator).Close() + } return api.workspaceAgentCache.Close() } diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index f8695deb04df5..5cf307d842e90 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -7,6 +7,7 @@ import ( "crypto/rand" "crypto/rsa" "crypto/sha256" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/base64" @@ -23,6 +24,7 @@ import ( "regexp" "strconv" "strings" + "sync" "testing" "time" @@ -37,8 +39,10 @@ import ( "golang.org/x/xerrors" "google.golang.org/api/idtoken" "google.golang.org/api/option" + "tailscale.com/derp" "tailscale.com/net/stun/stuntest" "tailscale.com/tailcfg" + "tailscale.com/types/key" "tailscale.com/types/nettype" "cdr.dev/slog" @@ -60,6 +64,7 @@ import ( "github.com/coder/coder/provisionerd" "github.com/coder/coder/provisionersdk" "github.com/coder/coder/provisionersdk/proto" + "github.com/coder/coder/tailnet" "github.com/coder/coder/testutil" ) @@ -77,12 +82,19 @@ type Options struct { AutobuildTicker <-chan time.Time AutobuildStats chan<- executor.Stats Auditor audit.Auditor + TLSCertificates []tls.Certificate // IncludeProvisionerDaemon when true means to start an in-memory provisionerD IncludeProvisionerDaemon bool MetricsCacheRefreshInterval time.Duration AgentStatsRefreshInterval time.Duration DeploymentFlags *codersdk.DeploymentFlags + + // Overriding the database is heavily discouraged. + // It should only be used in cases where multiple Coder + // test instances are running against the same database. + Database database.Store + Pubsub database.Pubsub } // New constructs a codersdk client connected to an in-memory API instance. 
@@ -116,7 +128,7 @@ func newWithCloser(t *testing.T, options *Options) (*codersdk.Client, io.Closer) return client, closer } -func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.CancelFunc, *coderd.Options) { +func NewOptions(t *testing.T, options *Options) (func(http.Handler), context.CancelFunc, *coderd.Options) { if options == nil { options = &Options{} } @@ -137,23 +149,40 @@ func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.Cance close(options.AutobuildStats) }) } - - db, pubsub := dbtestutil.NewDB(t) + if options.Database == nil { + options.Database, options.Pubsub = dbtestutil.NewDB(t) + } ctx, cancelFunc := context.WithCancel(context.Background()) lifecycleExecutor := executor.New( ctx, - db, + options.Database, slogtest.Make(t, nil).Named("autobuild.executor").Leveled(slog.LevelDebug), options.AutobuildTicker, ).WithStatsChannel(options.AutobuildStats) lifecycleExecutor.Run() - srv := httptest.NewUnstartedServer(nil) + var mutex sync.RWMutex + var handler http.Handler + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mutex.RLock() + defer mutex.RUnlock() + if handler != nil { + handler.ServeHTTP(w, r) + } + })) srv.Config.BaseContext = func(_ net.Listener) context.Context { return ctx } - srv.Start() + if options.TLSCertificates != nil { + srv.TLS = &tls.Config{ + Certificates: options.TLSCertificates, + MinVersion: tls.VersionTLS12, + } + srv.StartTLS() + } else { + srv.Start() + } t.Cleanup(srv.Close) tcpAddr, ok := srv.Listener.Addr().(*net.TCPAddr) @@ -169,6 +198,9 @@ func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.Cance stunAddr, stunCleanup := stuntest.ServeWithPacketListener(t, nettype.Std{}) t.Cleanup(stunCleanup) + derpServer := derp.NewServer(key.NewNode(), tailnet.Logger(slogtest.Make(t, nil).Named("derp"))) + derpServer.SetMeshKey("test-key") + // match default with cli default if options.SSHKeygenAlgorithm == "" { options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519 @@ -181,53 +213,59 @@ func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.Cance require.NoError(t, err) } - return srv, cancelFunc, &coderd.Options{ - AgentConnectionUpdateFrequency: 150 * time.Millisecond, - // Force a long disconnection timeout to ensure - // agents are not marked as disconnected during slow tests. 
- AgentInactiveDisconnectTimeout: testutil.WaitShort, - AccessURL: serverURL, - AppHostname: options.AppHostname, - AppHostnameRegex: appHostnameRegex, - Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), - CacheDir: t.TempDir(), - Database: db, - Pubsub: pubsub, - - Auditor: options.Auditor, - AWSCertificates: options.AWSCertificates, - AzureCertificates: options.AzureCertificates, - GithubOAuth2Config: options.GithubOAuth2Config, - OIDCConfig: options.OIDCConfig, - GoogleTokenValidator: options.GoogleTokenValidator, - SSHKeygenAlgorithm: options.SSHKeygenAlgorithm, - APIRateLimit: options.APIRateLimit, - Authorizer: options.Authorizer, - Telemetry: telemetry.NewNoop(), - DERPMap: &tailcfg.DERPMap{ - Regions: map[int]*tailcfg.DERPRegion{ - 1: { - EmbeddedRelay: true, - RegionID: 1, - RegionCode: "coder", - RegionName: "Coder", - Nodes: []*tailcfg.DERPNode{{ - Name: "1a", - RegionID: 1, - IPv4: "127.0.0.1", - DERPPort: derpPort, - STUNPort: stunAddr.Port, - InsecureForTests: true, - ForceHTTP: true, - }}, + return func(h http.Handler) { + mutex.Lock() + defer mutex.Unlock() + handler = h + }, cancelFunc, &coderd.Options{ + AgentConnectionUpdateFrequency: 150 * time.Millisecond, + // Force a long disconnection timeout to ensure + // agents are not marked as disconnected during slow tests. + AgentInactiveDisconnectTimeout: testutil.WaitShort, + AccessURL: serverURL, + AppHostname: options.AppHostname, + AppHostnameRegex: appHostnameRegex, + Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), + CacheDir: t.TempDir(), + Database: options.Database, + Pubsub: options.Pubsub, + + Auditor: options.Auditor, + AWSCertificates: options.AWSCertificates, + AzureCertificates: options.AzureCertificates, + GithubOAuth2Config: options.GithubOAuth2Config, + OIDCConfig: options.OIDCConfig, + GoogleTokenValidator: options.GoogleTokenValidator, + SSHKeygenAlgorithm: options.SSHKeygenAlgorithm, + DERPServer: derpServer, + APIRateLimit: options.APIRateLimit, + Authorizer: options.Authorizer, + Telemetry: telemetry.NewNoop(), + TLSCertificates: options.TLSCertificates, + DERPMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + EmbeddedRelay: true, + RegionID: 1, + RegionCode: "coder", + RegionName: "Coder", + Nodes: []*tailcfg.DERPNode{{ + Name: "1a", + RegionID: 1, + IPv4: "127.0.0.1", + DERPPort: derpPort, + STUNPort: stunAddr.Port, + InsecureForTests: true, + ForceHTTP: options.TLSCertificates == nil, + }}, + }, }, }, - }, - AutoImportTemplates: options.AutoImportTemplates, - MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval, - AgentStatsRefreshInterval: options.AgentStatsRefreshInterval, - DeploymentFlags: options.DeploymentFlags, - } + AutoImportTemplates: options.AutoImportTemplates, + MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval, + AgentStatsRefreshInterval: options.AgentStatsRefreshInterval, + DeploymentFlags: options.DeploymentFlags, + } } // NewWithAPI constructs an in-memory API instance and returns a client to talk to it. @@ -237,10 +275,10 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c if options == nil { options = &Options{} } - srv, cancelFunc, newOptions := NewOptions(t, options) + setHandler, cancelFunc, newOptions := NewOptions(t, options) // We set the handler after server creation for the access URL. 
coderAPI := coderd.New(newOptions) - srv.Config.Handler = coderAPI.RootHandler + setHandler(coderAPI.RootHandler) var provisionerCloser io.Closer = nopcloser{} if options.IncludeProvisionerDaemon { provisionerCloser = NewProvisionerDaemon(t, coderAPI) } @@ -459,7 +497,7 @@ func AwaitTemplateVersionJob(t *testing.T, client *codersdk.Client, version uuid var err error templateVersion, err = client.TemplateVersion(context.Background(), version) return assert.NoError(t, err) && templateVersion.Job.CompletedAt != nil - }, testutil.WaitShort, testutil.IntervalFast) + }, testutil.WaitMedium, testutil.IntervalFast) return templateVersion } diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 63239bdf4dfd3..65043d2412302 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -107,11 +107,17 @@ type data struct { workspaceApps []database.WorkspaceApp workspaces []database.Workspace licenses []database.License + replicas []database.Replica deploymentID string + derpMeshKey string lastLicenseID int32 } +func (*fakeQuerier) Ping(_ context.Context) (time.Duration, error) { + return 0, nil +} + // InTx doesn't rollback data properly for in-memory yet. func (q *fakeQuerier) InTx(fn func(database.Store) error) error { q.mutex.Lock() @@ -2931,6 +2937,21 @@ func (q *fakeQuerier) GetDeploymentID(_ context.Context) (string, error) { return q.deploymentID, nil } +func (q *fakeQuerier) InsertDERPMeshKey(_ context.Context, id string) error { + q.mutex.Lock() + defer q.mutex.Unlock() + + q.derpMeshKey = id + return nil +} + +func (q *fakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + return q.derpMeshKey, nil +} + func (q *fakeQuerier) InsertLicense( _ context.Context, arg database.InsertLicenseParams, ) (database.License, error) { @@ -3196,3 +3217,72 @@ func (q *fakeQuerier) DeleteGroupByID(_ context.Context, id uuid.UUID) error { return sql.ErrNoRows } + +func (q *fakeQuerier) DeleteReplicasUpdatedBefore(_ context.Context, before time.Time) error { + q.mutex.Lock() + defer q.mutex.Unlock() + + // Iterate in reverse so removing an element doesn't skip or misindex + // the elements that haven't been visited yet. + for i := len(q.replicas) - 1; i >= 0; i-- { + if q.replicas[i].UpdatedAt.Before(before) { + q.replicas = append(q.replicas[:i], q.replicas[i+1:]...)
+ } + } + + return nil +} + +func (q *fakeQuerier) InsertReplica(_ context.Context, arg database.InsertReplicaParams) (database.Replica, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + + replica := database.Replica{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + StartedAt: arg.StartedAt, + UpdatedAt: arg.UpdatedAt, + Hostname: arg.Hostname, + RegionID: arg.RegionID, + RelayAddress: arg.RelayAddress, + Version: arg.Version, + DatabaseLatency: arg.DatabaseLatency, + } + q.replicas = append(q.replicas, replica) + return replica, nil +} + +func (q *fakeQuerier) UpdateReplica(_ context.Context, arg database.UpdateReplicaParams) (database.Replica, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + + for index, replica := range q.replicas { + if replica.ID != arg.ID { + continue + } + replica.Hostname = arg.Hostname + replica.StartedAt = arg.StartedAt + replica.StoppedAt = arg.StoppedAt + replica.UpdatedAt = arg.UpdatedAt + replica.RelayAddress = arg.RelayAddress + replica.RegionID = arg.RegionID + replica.Version = arg.Version + replica.Error = arg.Error + replica.DatabaseLatency = arg.DatabaseLatency + q.replicas[index] = replica + return replica, nil + } + return database.Replica{}, sql.ErrNoRows +} + +func (q *fakeQuerier) GetReplicasUpdatedAfter(_ context.Context, updatedAt time.Time) ([]database.Replica, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + replicas := make([]database.Replica, 0) + for _, replica := range q.replicas { + if replica.UpdatedAt.After(updatedAt) && !replica.StoppedAt.Valid { + replicas = append(replicas, replica) + } + } + return replicas, nil +} diff --git a/coderd/database/db.go b/coderd/database/db.go index 4cbbdb399f193..020000888f8eb 100644 --- a/coderd/database/db.go +++ b/coderd/database/db.go @@ -12,6 +12,7 @@ import ( "context" "database/sql" "errors" + "time" "github.com/jmoiron/sqlx" "golang.org/x/xerrors" @@ -24,6 +25,7 @@ type Store interface { // customQuerier contains custom queries that are not generated. customQuerier + Ping(ctx context.Context) (time.Duration, error) InTx(func(Store) error) error } @@ -58,6 +60,13 @@ type sqlQuerier struct { db DBTX } +// Ping returns the time it takes to ping the database. +func (q *sqlQuerier) Ping(ctx context.Context) (time.Duration, error) { + start := time.Now() + err := q.sdb.PingContext(ctx) + return time.Since(start), err +} + // InTx performs database operations inside a transaction. 
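The new Store.Ping above gives a replica a cheap way to measure its database round-trip, and the replicas schema below records that figure in microseconds. A hedged sketch of how a heartbeat could combine the two; updateReplicaLatency is a hypothetical helper, while database.Store, database.Now, and UpdateReplicaParams are as defined elsewhere in this diff:

import (
	"context"

	"golang.org/x/xerrors"

	"github.com/coder/coder/coderd/database"
)

func updateReplicaLatency(ctx context.Context, db database.Store, replica database.Replica) (database.Replica, error) {
	// One PingContext round-trip to the database.
	latency, err := db.Ping(ctx)
	if err != nil {
		return database.Replica{}, xerrors.Errorf("ping database: %w", err)
	}
	// database_latency is an int32 column holding microseconds.
	return db.UpdateReplica(ctx, database.UpdateReplicaParams{
		ID:              replica.ID,
		UpdatedAt:       database.Now(),
		StartedAt:       replica.StartedAt,
		StoppedAt:       replica.StoppedAt,
		RelayAddress:    replica.RelayAddress,
		RegionID:        replica.RegionID,
		Hostname:        replica.Hostname,
		Version:         replica.Version,
		Error:           replica.Error,
		DatabaseLatency: int32(latency.Microseconds()),
	})
}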
func (q *sqlQuerier) InTx(function func(Store) error) error { if _, ok := q.db.(*sqlx.Tx); ok { diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index de2d352a6a073..b946a1130e0c8 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -256,7 +256,8 @@ CREATE TABLE provisioner_daemons ( created_at timestamp with time zone NOT NULL, updated_at timestamp with time zone, name character varying(64) NOT NULL, - provisioners provisioner_type[] NOT NULL + provisioners provisioner_type[] NOT NULL, + replica_id uuid ); CREATE TABLE provisioner_job_logs ( @@ -287,6 +288,20 @@ CREATE TABLE provisioner_jobs ( file_id uuid NOT NULL ); +CREATE TABLE replicas ( + id uuid NOT NULL, + created_at timestamp with time zone NOT NULL, + started_at timestamp with time zone NOT NULL, + stopped_at timestamp with time zone, + updated_at timestamp with time zone NOT NULL, + hostname text NOT NULL, + region_id integer NOT NULL, + relay_address text NOT NULL, + database_latency integer NOT NULL, + version text NOT NULL, + error text DEFAULT ''::text NOT NULL +); + CREATE TABLE site_configs ( key character varying(256) NOT NULL, value character varying(8192) NOT NULL diff --git a/coderd/database/migrations/000061_replicas.down.sql b/coderd/database/migrations/000061_replicas.down.sql new file mode 100644 index 0000000000000..4cca6615d4213 --- /dev/null +++ b/coderd/database/migrations/000061_replicas.down.sql @@ -0,0 +1,2 @@ +DROP TABLE replicas; +ALTER TABLE provisioner_daemons DROP COLUMN replica_id; diff --git a/coderd/database/migrations/000061_replicas.up.sql b/coderd/database/migrations/000061_replicas.up.sql new file mode 100644 index 0000000000000..1400662e30582 --- /dev/null +++ b/coderd/database/migrations/000061_replicas.up.sql @@ -0,0 +1,29 @@ +CREATE TABLE IF NOT EXISTS replicas ( + -- A unique identifier for the replica that is stored on disk. + -- For persistent replicas, this will be reused. + -- For ephemeral replicas, this will be a new UUID for each one. + id uuid NOT NULL, + -- The time the replica was created. + created_at timestamp with time zone NOT NULL, + -- The time the replica was started. + started_at timestamp with time zone NOT NULL, + -- The time the replica was stopped (null while it is running). + stopped_at timestamp with time zone, + -- Updated periodically to ensure the replica is still alive. + updated_at timestamp with time zone NOT NULL, + -- Hostname is the hostname of the replica. + hostname text NOT NULL, + -- Region is the region the replica is in. + -- We only DERP mesh to the same region ID of a running replica. + region_id integer NOT NULL, + -- An address that should be accessible to other replicas. + relay_address text NOT NULL, + -- The latency of the replica to the database in microseconds. + database_latency int NOT NULL, + -- Version is the Coder version of the replica. + version text NOT NULL, + error text NOT NULL DEFAULT '' +); + +-- Associates a provisioner daemon with a replica.
+ALTER TABLE provisioner_daemons ADD COLUMN replica_id uuid; diff --git a/coderd/database/models.go b/coderd/database/models.go index e30615244e299..53e074984ac11 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -508,6 +508,7 @@ type ProvisionerDaemon struct { UpdatedAt sql.NullTime `db:"updated_at" json:"updated_at"` Name string `db:"name" json:"name"` Provisioners []ProvisionerType `db:"provisioners" json:"provisioners"` + ReplicaID uuid.NullUUID `db:"replica_id" json:"replica_id"` } type ProvisionerJob struct { @@ -538,6 +539,20 @@ type ProvisionerJobLog struct { Output string `db:"output" json:"output"` } +type Replica struct { + ID uuid.UUID `db:"id" json:"id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + StartedAt time.Time `db:"started_at" json:"started_at"` + StoppedAt sql.NullTime `db:"stopped_at" json:"stopped_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + Hostname string `db:"hostname" json:"hostname"` + RegionID int32 `db:"region_id" json:"region_id"` + RelayAddress string `db:"relay_address" json:"relay_address"` + DatabaseLatency int32 `db:"database_latency" json:"database_latency"` + Version string `db:"version" json:"version"` + Error string `db:"error" json:"error"` +} + type SiteConfig struct { Key string `db:"key" json:"key"` Value string `db:"value" json:"value"` diff --git a/coderd/database/pubsub_memory.go b/coderd/database/pubsub_memory.go index 148d2f57b129f..de5a940414d6c 100644 --- a/coderd/database/pubsub_memory.go +++ b/coderd/database/pubsub_memory.go @@ -47,8 +47,9 @@ func (m *memoryPubsub) Publish(event string, message []byte) error { return nil } for _, listener := range listeners { - listener(context.Background(), message) + go listener(context.Background(), message) } + return nil } diff --git a/coderd/database/querier.go b/coderd/database/querier.go index ad26413873e04..393ab81fdd347 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -26,6 +26,7 @@ type sqlcQuerier interface { DeleteLicense(ctx context.Context, id int32) (int32, error) DeleteOldAgentStats(ctx context.Context) error DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error + DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error GetAPIKeyByID(ctx context.Context, id string) (APIKey, error) GetAPIKeysByLoginType(ctx context.Context, loginType LoginType) ([]APIKey, error) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error) @@ -38,6 +39,7 @@ type sqlcQuerier interface { // This function returns roles for authorization purposes. Implied member roles // are included. 
GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error) + GetDERPMeshKey(ctx context.Context) (string, error) GetDeploymentID(ctx context.Context) (string, error) GetFileByHashAndCreator(ctx context.Context, arg GetFileByHashAndCreatorParams) (File, error) GetFileByID(ctx context.Context, id uuid.UUID) (File, error) @@ -67,6 +69,7 @@ type sqlcQuerier interface { GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJob, error) GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]ProvisionerJob, error) GetProvisionerLogsByIDBetween(ctx context.Context, arg GetProvisionerLogsByIDBetweenParams) ([]ProvisionerJobLog, error) + GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error) GetTemplateAverageBuildTime(ctx context.Context, arg GetTemplateAverageBuildTimeParams) (GetTemplateAverageBuildTimeRow, error) GetTemplateByID(ctx context.Context, id uuid.UUID) (Template, error) GetTemplateByOrganizationAndName(ctx context.Context, arg GetTemplateByOrganizationAndNameParams) (Template, error) @@ -123,6 +126,7 @@ type sqlcQuerier interface { // every member of the org. InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error) InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error) + InsertDERPMeshKey(ctx context.Context, value string) error InsertDeploymentID(ctx context.Context, value string) error InsertFile(ctx context.Context, arg InsertFileParams) (File, error) InsertGitSSHKey(ctx context.Context, arg InsertGitSSHKeyParams) (GitSSHKey, error) @@ -136,6 +140,7 @@ type sqlcQuerier interface { InsertProvisionerDaemon(ctx context.Context, arg InsertProvisionerDaemonParams) (ProvisionerDaemon, error) InsertProvisionerJob(ctx context.Context, arg InsertProvisionerJobParams) (ProvisionerJob, error) InsertProvisionerJobLogs(ctx context.Context, arg InsertProvisionerJobLogsParams) ([]ProvisionerJobLog, error) + InsertReplica(ctx context.Context, arg InsertReplicaParams) (Replica, error) InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error) InsertTemplateVersion(ctx context.Context, arg InsertTemplateVersionParams) (TemplateVersion, error) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) @@ -156,6 +161,7 @@ type sqlcQuerier interface { UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error UpdateProvisionerJobWithCancelByID(ctx context.Context, arg UpdateProvisionerJobWithCancelByIDParams) error UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error + UpdateReplica(ctx context.Context, arg UpdateReplicaParams) (Replica, error) UpdateTemplateActiveVersionByID(ctx context.Context, arg UpdateTemplateActiveVersionByIDParams) error UpdateTemplateDeletedByID(ctx context.Context, arg UpdateTemplateDeletedByIDParams) error UpdateTemplateMetaByID(ctx context.Context, arg UpdateTemplateMetaByIDParams) (Template, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 41eb029b59a83..3621050bc0096 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2031,7 +2031,7 @@ func (q *sqlQuerier) ParameterValues(ctx context.Context, arg ParameterValuesPar const getProvisionerDaemonByID = `-- name: GetProvisionerDaemonByID :one SELECT - id, created_at, updated_at, name, provisioners + id, created_at, updated_at, name, provisioners, replica_id FROM 
provisioner_daemons WHERE @@ -2047,13 +2047,14 @@ func (q *sqlQuerier) GetProvisionerDaemonByID(ctx context.Context, id uuid.UUID) &i.UpdatedAt, &i.Name, pq.Array(&i.Provisioners), + &i.ReplicaID, ) return i, err } const getProvisionerDaemons = `-- name: GetProvisionerDaemons :many SELECT - id, created_at, updated_at, name, provisioners + id, created_at, updated_at, name, provisioners, replica_id FROM provisioner_daemons ` @@ -2073,6 +2074,7 @@ func (q *sqlQuerier) GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDa &i.UpdatedAt, &i.Name, pq.Array(&i.Provisioners), + &i.ReplicaID, ); err != nil { return nil, err } @@ -2096,7 +2098,7 @@ INSERT INTO provisioners ) VALUES - ($1, $2, $3, $4) RETURNING id, created_at, updated_at, name, provisioners + ($1, $2, $3, $4) RETURNING id, created_at, updated_at, name, provisioners, replica_id ` type InsertProvisionerDaemonParams struct { @@ -2120,6 +2122,7 @@ func (q *sqlQuerier) InsertProvisionerDaemon(ctx context.Context, arg InsertProv &i.UpdatedAt, &i.Name, pq.Array(&i.Provisioners), + &i.ReplicaID, ) return i, err } @@ -2577,6 +2580,177 @@ func (q *sqlQuerier) UpdateProvisionerJobWithCompleteByID(ctx context.Context, a return err } +const deleteReplicasUpdatedBefore = `-- name: DeleteReplicasUpdatedBefore :exec +DELETE FROM replicas WHERE updated_at < $1 +` + +func (q *sqlQuerier) DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error { + _, err := q.db.ExecContext(ctx, deleteReplicasUpdatedBefore, updatedAt) + return err +} + +const getReplicasUpdatedAfter = `-- name: GetReplicasUpdatedAfter :many +SELECT id, created_at, started_at, stopped_at, updated_at, hostname, region_id, relay_address, database_latency, version, error FROM replicas WHERE updated_at > $1 AND stopped_at IS NULL +` + +func (q *sqlQuerier) GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error) { + rows, err := q.db.QueryContext(ctx, getReplicasUpdatedAfter, updatedAt) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Replica + for rows.Next() { + var i Replica + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.StartedAt, + &i.StoppedAt, + &i.UpdatedAt, + &i.Hostname, + &i.RegionID, + &i.RelayAddress, + &i.DatabaseLatency, + &i.Version, + &i.Error, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const insertReplica = `-- name: InsertReplica :one +INSERT INTO replicas ( + id, + created_at, + started_at, + updated_at, + hostname, + region_id, + relay_address, + version, + database_latency +) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, created_at, started_at, stopped_at, updated_at, hostname, region_id, relay_address, database_latency, version, error +` + +type InsertReplicaParams struct { + ID uuid.UUID `db:"id" json:"id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + StartedAt time.Time `db:"started_at" json:"started_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + Hostname string `db:"hostname" json:"hostname"` + RegionID int32 `db:"region_id" json:"region_id"` + RelayAddress string `db:"relay_address" json:"relay_address"` + Version string `db:"version" json:"version"` + DatabaseLatency int32 `db:"database_latency" json:"database_latency"` +} + +func (q *sqlQuerier) InsertReplica(ctx context.Context, arg InsertReplicaParams) (Replica, error) { + row := q.db.QueryRowContext(ctx, 
insertReplica, + arg.ID, + arg.CreatedAt, + arg.StartedAt, + arg.UpdatedAt, + arg.Hostname, + arg.RegionID, + arg.RelayAddress, + arg.Version, + arg.DatabaseLatency, + ) + var i Replica + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.StartedAt, + &i.StoppedAt, + &i.UpdatedAt, + &i.Hostname, + &i.RegionID, + &i.RelayAddress, + &i.DatabaseLatency, + &i.Version, + &i.Error, + ) + return i, err +} + +const updateReplica = `-- name: UpdateReplica :one +UPDATE replicas SET + updated_at = $2, + started_at = $3, + stopped_at = $4, + relay_address = $5, + region_id = $6, + hostname = $7, + version = $8, + error = $9, + database_latency = $10 +WHERE id = $1 RETURNING id, created_at, started_at, stopped_at, updated_at, hostname, region_id, relay_address, database_latency, version, error +` + +type UpdateReplicaParams struct { + ID uuid.UUID `db:"id" json:"id"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + StartedAt time.Time `db:"started_at" json:"started_at"` + StoppedAt sql.NullTime `db:"stopped_at" json:"stopped_at"` + RelayAddress string `db:"relay_address" json:"relay_address"` + RegionID int32 `db:"region_id" json:"region_id"` + Hostname string `db:"hostname" json:"hostname"` + Version string `db:"version" json:"version"` + Error string `db:"error" json:"error"` + DatabaseLatency int32 `db:"database_latency" json:"database_latency"` +} + +func (q *sqlQuerier) UpdateReplica(ctx context.Context, arg UpdateReplicaParams) (Replica, error) { + row := q.db.QueryRowContext(ctx, updateReplica, + arg.ID, + arg.UpdatedAt, + arg.StartedAt, + arg.StoppedAt, + arg.RelayAddress, + arg.RegionID, + arg.Hostname, + arg.Version, + arg.Error, + arg.DatabaseLatency, + ) + var i Replica + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.StartedAt, + &i.StoppedAt, + &i.UpdatedAt, + &i.Hostname, + &i.RegionID, + &i.RelayAddress, + &i.DatabaseLatency, + &i.Version, + &i.Error, + ) + return i, err +} + +const getDERPMeshKey = `-- name: GetDERPMeshKey :one +SELECT value FROM site_configs WHERE key = 'derp_mesh_key' +` + +func (q *sqlQuerier) GetDERPMeshKey(ctx context.Context) (string, error) { + row := q.db.QueryRowContext(ctx, getDERPMeshKey) + var value string + err := row.Scan(&value) + return value, err +} + const getDeploymentID = `-- name: GetDeploymentID :one SELECT value FROM site_configs WHERE key = 'deployment_id' ` @@ -2588,6 +2762,15 @@ func (q *sqlQuerier) GetDeploymentID(ctx context.Context) (string, error) { return value, err } +const insertDERPMeshKey = `-- name: InsertDERPMeshKey :exec +INSERT INTO site_configs (key, value) VALUES ('derp_mesh_key', $1) +` + +func (q *sqlQuerier) InsertDERPMeshKey(ctx context.Context, value string) error { + _, err := q.db.ExecContext(ctx, insertDERPMeshKey, value) + return err +} + const insertDeploymentID = `-- name: InsertDeploymentID :exec INSERT INTO site_configs (key, value) VALUES ('deployment_id', $1) ` diff --git a/coderd/database/queries/replicas.sql b/coderd/database/queries/replicas.sql new file mode 100644 index 0000000000000..e87c1f46432f2 --- /dev/null +++ b/coderd/database/queries/replicas.sql @@ -0,0 +1,31 @@ +-- name: GetReplicasUpdatedAfter :many +SELECT * FROM replicas WHERE updated_at > $1 AND stopped_at IS NULL; + +-- name: InsertReplica :one +INSERT INTO replicas ( + id, + created_at, + started_at, + updated_at, + hostname, + region_id, + relay_address, + version, + database_latency +) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING *; + +-- name: UpdateReplica :one +UPDATE replicas SET + updated_at = $2, + started_at = $3, + 
stopped_at = $4, + relay_address = $5, + region_id = $6, + hostname = $7, + version = $8, + error = $9, + database_latency = $10 +WHERE id = $1 RETURNING *; + +-- name: DeleteReplicasUpdatedBefore :exec +DELETE FROM replicas WHERE updated_at < $1; diff --git a/coderd/database/queries/siteconfig.sql b/coderd/database/queries/siteconfig.sql index 9d3936e23886d..b975d2f68cc3c 100644 --- a/coderd/database/queries/siteconfig.sql +++ b/coderd/database/queries/siteconfig.sql @@ -3,3 +3,9 @@ INSERT INTO site_configs (key, value) VALUES ('deployment_id', $1); -- name: GetDeploymentID :one SELECT value FROM site_configs WHERE key = 'deployment_id'; + +-- name: InsertDERPMeshKey :exec +INSERT INTO site_configs (key, value) VALUES ('derp_mesh_key', $1); + +-- name: GetDERPMeshKey :one +SELECT value FROM site_configs WHERE key = 'derp_mesh_key'; diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index 294b013e00280..04f050f0c5218 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -270,7 +270,7 @@ func (api *API) provisionerJobResources(rw http.ResponseWriter, r *http.Request, } } - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading job agent.", diff --git a/coderd/rbac/object.go b/coderd/rbac/object.go index 5492e4397d5f7..1a8861c984ce9 100644 --- a/coderd/rbac/object.go +++ b/coderd/rbac/object.go @@ -146,6 +146,10 @@ var ( ResourceDeploymentFlags = Object{ Type: "deployment_flags", } + + ResourceReplicas = Object{ + Type: "replicas", + } ) // Object is used to create objects for authz checks when you have none in diff --git a/coderd/templates_test.go b/coderd/templates_test.go index 637ced633cdc1..f6aacba8a5547 100644 --- a/coderd/templates_test.go +++ b/coderd/templates_test.go @@ -627,7 +627,9 @@ func TestTemplateMetrics(t *testing.T) { require.NoError(t, err) assert.Zero(t, workspaces[0].LastUsedAt) - conn, err := client.DialWorkspaceAgentTailnet(ctx, slogtest.Make(t, nil).Named("tailnet"), resources[0].Agents[0].ID) + conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{ + Logger: slogtest.Make(t, nil).Named("tailnet"), + }) require.NoError(t, err) defer func() { _ = conn.Close() diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 295beff0d2b7e..fb7f765cc7519 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -49,7 +49,7 @@ func (api *API) workspaceAgent(rw http.ResponseWriter, r *http.Request) { }) return } - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -78,7 +78,7 @@ func (api *API) workspaceAgentApps(rw http.ResponseWriter, r *http.Request) { func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() workspaceAgent := httpmw.WorkspaceAgent(r) - apiAgent, err := 
convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -98,7 +98,7 @@ func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) func (api *API) postWorkspaceAgentVersion(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() workspaceAgent := httpmw.WorkspaceAgent(r) - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -152,7 +152,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { httpapi.ResourceNotFound(rw) return } - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -229,7 +229,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req return } - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -376,8 +376,9 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (* }) conn.SetNodeCallback(sendNodes) go func() { - err := api.TailnetCoordinator.ServeClient(serverConn, uuid.New(), agentID) + err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID) if err != nil { + api.Logger.Warn(r.Context(), "tailnet coordinator client error", slog.Error(err)) _ = conn.Close() } }() @@ -514,8 +515,9 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request closeChan := make(chan struct{}) go func() { defer close(closeChan) - err := api.TailnetCoordinator.ServeAgent(wsNetConn, workspaceAgent.ID) + err := (*api.TailnetCoordinator.Load()).ServeAgent(wsNetConn, workspaceAgent.ID) if err != nil { + api.Logger.Warn(ctx, "tailnet coordinator agent error", slog.Error(err)) _ = conn.Close(websocket.StatusInternalError, err.Error()) return } @@ -583,7 +585,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R go httpapi.Heartbeat(ctx, conn) defer conn.Close(websocket.StatusNormalClosure, "") - err = api.TailnetCoordinator.ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID) + err = (*api.TailnetCoordinator.Load()).ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID) if err != nil { _ = 
conn.Close(websocket.StatusInternalError, err.Error()) return @@ -611,7 +613,7 @@ func convertApps(dbApps []database.WorkspaceApp) []codersdk.WorkspaceApp { return apps } -func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator *tailnet.Coordinator, dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentInactiveDisconnectTimeout time.Duration) (codersdk.WorkspaceAgent, error) { +func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordinator, dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentInactiveDisconnectTimeout time.Duration) (codersdk.WorkspaceAgent, error) { var envs map[string]string if dbAgent.EnvironmentVariables.Valid { err := json.Unmarshal(dbAgent.EnvironmentVariables.RawMessage, &envs) diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 6bd569dde9f71..e8dd772095736 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -123,13 +123,13 @@ func TestWorkspaceAgentListen(t *testing.T) { defer cancel() resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) - conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID) + conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) require.NoError(t, err) defer func() { _ = conn.Close() }() require.Eventually(t, func() bool { - _, err := conn.Ping() + _, err := conn.Ping(ctx) return err == nil }, testutil.WaitLong, testutil.IntervalFast) }) @@ -253,7 +253,9 @@ func TestWorkspaceAgentTailnet(t *testing.T) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() - conn, err := client.DialWorkspaceAgentTailnet(ctx, slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), resources[0].Agents[0].ID) + conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{ + Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), + }) require.NoError(t, err) defer conn.Close() sshClient, err := conn.SSHClient() diff --git a/coderd/workspacebuilds.go b/coderd/workspacebuilds.go index ed136f372b3d5..dc89f576b5484 100644 --- a/coderd/workspacebuilds.go +++ b/coderd/workspacebuilds.go @@ -861,7 +861,7 @@ func (api *API) convertWorkspaceBuild( apiAgents := make([]codersdk.WorkspaceAgent, 0) for _, agent := range agents { apps := appsByAgentID[agent.ID] - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, agent, convertApps(apps), api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), agent, convertApps(apps), api.AgentInactiveDisconnectTimeout) if err != nil { return codersdk.WorkspaceBuild{}, xerrors.Errorf("converting workspace agent: %w", err) } diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go index 003d3cddb8b7a..d4345ce9d5f05 100644 --- a/coderd/wsconncache/wsconncache_test.go +++ b/coderd/wsconncache/wsconncache_test.go @@ -128,7 +128,9 @@ func TestCache(t *testing.T) { return } defer release() - proxy.Transport = conn.HTTPTransport() + transport := conn.HTTPTransport() + defer transport.CloseIdleConnections() + proxy.Transport = transport res := httptest.NewRecorder() proxy.ServeHTTP(res, req) resp := res.Result() diff --git a/codersdk/agentconn.go b/codersdk/agentconn.go index b11c440ce3a65..ddfb9541a186a 100644 --- a/codersdk/agentconn.go +++ b/codersdk/agentconn.go @@ -132,10 +132,10 @@ type AgentConn struct { CloseFunc func() } -func (c 
*AgentConn) Ping() (time.Duration, error) { +func (c *AgentConn) Ping(ctx context.Context) (time.Duration, error) { errCh := make(chan error, 1) durCh := make(chan time.Duration, 1) - c.Conn.Ping(TailnetIP, tailcfg.PingDisco, func(pr *ipnstate.PingResult) { + go c.Conn.Ping(TailnetIP, tailcfg.PingDisco, func(pr *ipnstate.PingResult) { if pr.Err != "" { errCh <- xerrors.New(pr.Err) return @@ -145,6 +145,8 @@ func (c *AgentConn) Ping() (time.Duration, error) { select { case err := <-errCh: return 0, err + case <-ctx.Done(): + return 0, ctx.Err() case dur := <-durCh: return dur, nil } diff --git a/codersdk/features.go b/codersdk/features.go index 291b5575a7e6b..862411de62872 100644 --- a/codersdk/features.go +++ b/codersdk/features.go @@ -15,12 +15,13 @@ const ( ) const ( - FeatureUserLimit = "user_limit" - FeatureAuditLog = "audit_log" - FeatureBrowserOnly = "browser_only" - FeatureSCIM = "scim" - FeatureWorkspaceQuota = "workspace_quota" - FeatureTemplateRBAC = "template_rbac" + FeatureUserLimit = "user_limit" + FeatureAuditLog = "audit_log" + FeatureBrowserOnly = "browser_only" + FeatureSCIM = "scim" + FeatureWorkspaceQuota = "workspace_quota" + FeatureTemplateRBAC = "template_rbac" + FeatureHighAvailability = "high_availability" ) var FeatureNames = []string{ @@ -30,6 +31,7 @@ var FeatureNames = []string{ FeatureSCIM, FeatureWorkspaceQuota, FeatureTemplateRBAC, + FeatureHighAvailability, } type Feature struct { @@ -42,6 +44,7 @@ type Feature struct { type Entitlements struct { Features map[string]Feature `json:"features"` Warnings []string `json:"warnings"` + Errors []string `json:"errors"` HasLicense bool `json:"has_license"` Experimental bool `json:"experimental"` Trial bool `json:"trial"` diff --git a/codersdk/flags.go b/codersdk/flags.go index 92f02941a57f8..09ca65b1ea813 100644 --- a/codersdk/flags.go +++ b/codersdk/flags.go @@ -19,6 +19,7 @@ type DeploymentFlags struct { DerpServerRegionCode *StringFlag `json:"derp_server_region_code" typescript:",notnull"` DerpServerRegionName *StringFlag `json:"derp_server_region_name" typescript:",notnull"` DerpServerSTUNAddresses *StringArrayFlag `json:"derp_server_stun_address" typescript:",notnull"` + DerpServerRelayAddress *StringFlag `json:"derp_server_relay_address" typescript:",notnull"` DerpConfigURL *StringFlag `json:"derp_config_url" typescript:",notnull"` DerpConfigPath *StringFlag `json:"derp_config_path" typescript:",notnull"` PromEnabled *BoolFlag `json:"prom_enabled" typescript:",notnull"` diff --git a/codersdk/replicas.go b/codersdk/replicas.go new file mode 100644 index 0000000000000..e74af021ee9a3 --- /dev/null +++ b/codersdk/replicas.go @@ -0,0 +1,44 @@ +package codersdk + +import ( + "context" + "encoding/json" + "net/http" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" +) + +type Replica struct { + // ID is the unique identifier for the replica. + ID uuid.UUID `json:"id"` + // Hostname is the hostname of the replica. + Hostname string `json:"hostname"` + // CreatedAt is when the replica was first seen. + CreatedAt time.Time `json:"created_at"` + // RelayAddress is the accessible address to relay DERP connections. + RelayAddress string `json:"relay_address"` + // RegionID is the region of the replica. + RegionID int32 `json:"region_id"` + // Error is the error the replica most recently reported, if any. + Error string `json:"error"` + // DatabaseLatency is the latency in microseconds to the database. + DatabaseLatency int32 `json:"database_latency"` +} + +// Replicas fetches the list of replicas.
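A usage sketch for the client method that follows. The helper below is illustrative only (printReplicaHealth and its imports are assumptions, not part of this diff), but every field it prints comes from the Replica struct above, with DatabaseLatency in microseconds to match the column:

import (
	"context"
	"fmt"

	"golang.org/x/xerrors"

	"github.com/coder/coder/codersdk"
)

func printReplicaHealth(ctx context.Context, client *codersdk.Client) error {
	replicas, err := client.Replicas(ctx)
	if err != nil {
		return xerrors.Errorf("list replicas: %w", err)
	}
	for _, replica := range replicas {
		fmt.Printf("%s: region=%d relay=%s db=%dus error=%q\n",
			replica.Hostname, replica.RegionID, replica.RelayAddress,
			replica.DatabaseLatency, replica.Error)
	}
	return nil
}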
+func (c *Client) Replicas(ctx context.Context) ([]Replica, error) {
+	res, err := c.Request(ctx, http.MethodGet, "/api/v2/replicas", nil)
+	if err != nil {
+		return nil, xerrors.Errorf("execute request: %w", err)
+	}
+	defer res.Body.Close()
+
+	if res.StatusCode != http.StatusOK {
+		return nil, readBodyAsError(res)
+	}
+
+	var replicas []Replica
+	return replicas, json.NewDecoder(res.Body).Decode(&replicas)
+}
diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go
index 253e8713fdb4f..c86944ae2b629 100644
--- a/codersdk/workspaceagents.go
+++ b/codersdk/workspaceagents.go
@@ -21,7 +21,6 @@ import (
 	"tailscale.com/tailcfg"
 
 	"cdr.dev/slog"
-	"github.com/coder/coder/tailnet"
 	"github.com/coder/retry"
 )
@@ -316,7 +315,8 @@ func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, err
 		Value: c.SessionToken,
 	}})
 	httpClient := &http.Client{
-		Jar: jar,
+		Jar:       jar,
+		Transport: c.HTTPClient.Transport,
 	}
 	// nolint:bodyclose
 	conn, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
@@ -332,7 +332,17 @@ func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, err
 	return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
 }
 
-func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (*AgentConn, error) {
+// @typescript-ignore DialWorkspaceAgentOptions
+type DialWorkspaceAgentOptions struct {
+	Logger slog.Logger
+	// BlockEndpoints forces the connection to relay through DERP by
+	// blocking direct peer-to-peer endpoints.
+	BlockEndpoints bool
+}
+
+func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *DialWorkspaceAgentOptions) (*AgentConn, error) {
+	if options == nil {
+		options = &DialWorkspaceAgentOptions{}
+	}
 	res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/connection", agentID), nil)
 	if err != nil {
 		return nil, err
@@ -349,9 +359,10 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg
 	ip := tailnet.IP()
 	conn, err := tailnet.NewConn(&tailnet.Options{
-		Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},
-		DERPMap:   connInfo.DERPMap,
-		Logger:    logger,
+		Addresses:      []netip.Prefix{netip.PrefixFrom(ip, 128)},
+		DERPMap:        connInfo.DERPMap,
+		Logger:         options.Logger,
+		BlockEndpoints: options.BlockEndpoints,
 	})
 	if err != nil {
 		return nil, xerrors.Errorf("create tailnet: %w", err)
 	}
@@ -370,7 +381,8 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg
 		Value: c.SessionToken,
 	}})
 	httpClient := &http.Client{
-		Jar: jar,
+		Jar:       jar,
+		Transport: c.HTTPClient.Transport,
 	}
 	ctx, cancelFunc := context.WithCancel(ctx)
 	closed := make(chan struct{})
@@ -379,7 +391,7 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg
 		defer close(closed)
 		isFirst := true
 		for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
-			logger.Debug(ctx, "connecting")
+			options.Logger.Debug(ctx, "connecting")
 			// nolint:bodyclose
 			ws, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
 				HTTPClient: httpClient,
@@ -398,21 +410,21 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg
 			if errors.Is(err, context.Canceled) {
 				return
 			}
-			logger.Debug(ctx, "failed to dial", slog.Error(err))
+			options.Logger.Debug(ctx, "failed to dial", slog.Error(err))
 			continue
 		}
 		sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error {
 			return conn.UpdateNodes(node)
 		})
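+		// ServeCoordinator returns a send function for publishing our local
+		// node, while the handler above applies peer updates that the
+		// coordinator streams back to the local tailnet connection.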
 		conn.SetNodeCallback(sendNode)
-		logger.Debug(ctx, "serving coordinator")
+		options.Logger.Debug(ctx, "serving coordinator")
 		err = <-errChan
 		if errors.Is(err, context.Canceled) {
 			_ = ws.Close(websocket.StatusGoingAway, "")
 			return
 		}
 		if err != nil {
-			logger.Debug(ctx, "error serving coordinator", slog.Error(err))
+			options.Logger.Debug(ctx, "error serving coordinator", slog.Error(err))
 			_ = ws.Close(websocket.StatusGoingAway, "")
 			continue
 		}
diff --git a/enterprise/cli/features_test.go b/enterprise/cli/features_test.go
index 215809c1736c9..78b94a6509526 100644
--- a/enterprise/cli/features_test.go
+++ b/enterprise/cli/features_test.go
@@ -57,7 +57,7 @@ func TestFeaturesList(t *testing.T) {
 		var entitlements codersdk.Entitlements
 		err := json.Unmarshal(buf.Bytes(), &entitlements)
 		require.NoError(t, err, "unmarshal JSON output")
-		assert.Len(t, entitlements.Features, 6)
+		assert.Len(t, entitlements.Features, 7)
 		assert.Empty(t, entitlements.Warnings)
 		assert.Equal(t, codersdk.EntitlementNotEntitled,
 			entitlements.Features[codersdk.FeatureUserLimit].Entitlement)
@@ -71,6 +71,8 @@ func TestFeaturesList(t *testing.T) {
 			entitlements.Features[codersdk.FeatureTemplateRBAC].Entitlement)
 		assert.Equal(t, codersdk.EntitlementNotEntitled,
 			entitlements.Features[codersdk.FeatureSCIM].Entitlement)
+		assert.Equal(t, codersdk.EntitlementNotEntitled,
+			entitlements.Features[codersdk.FeatureHighAvailability].Entitlement)
 		assert.False(t, entitlements.HasLicense)
 		assert.False(t, entitlements.Experimental)
 	})
diff --git a/enterprise/cli/server.go b/enterprise/cli/server.go
index 62af6f2888373..a65b8e8faa6e0 100644
--- a/enterprise/cli/server.go
+++ b/enterprise/cli/server.go
@@ -2,11 +2,20 @@ package cli
 
 import (
 	"context"
+	"database/sql"
+	"errors"
+	"io"
+	"net/url"
 
 	"github.com/spf13/cobra"
+	"golang.org/x/xerrors"
+	"tailscale.com/derp"
+	"tailscale.com/types/key"
 
 	"github.com/coder/coder/cli/deployment"
+	"github.com/coder/coder/cryptorand"
 	"github.com/coder/coder/enterprise/coderd"
+	"github.com/coder/coder/tailnet"
 
 	agpl "github.com/coder/coder/cli"
 	agplcoderd "github.com/coder/coder/coderd"
@@ -14,23 +23,49 @@ import (
 
 func server() *cobra.Command {
 	dflags := deployment.Flags()
-	cmd := agpl.Server(dflags, func(ctx context.Context, options *agplcoderd.Options) (*agplcoderd.API, error) {
+	cmd := agpl.Server(dflags, func(ctx context.Context, options *agplcoderd.Options) (*agplcoderd.API, io.Closer, error) {
+		if dflags.DerpServerRelayAddress.Value != "" {
+			_, err := url.Parse(dflags.DerpServerRelayAddress.Value)
+			if err != nil {
+				return nil, nil, xerrors.Errorf("derp-server-relay-address must be a valid HTTP URL: %w", err)
+			}
+		}
+
+		options.DERPServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp")))
+		meshKey, err := options.Database.GetDERPMeshKey(ctx)
+		if err != nil {
+			if !errors.Is(err, sql.ErrNoRows) {
+				return nil, nil, xerrors.Errorf("get mesh key: %w", err)
+			}
+			meshKey, err = cryptorand.String(32)
+			if err != nil {
+				return nil, nil, xerrors.Errorf("generate mesh key: %w", err)
+			}
+			err = options.Database.InsertDERPMeshKey(ctx, meshKey)
+			if err != nil {
+				return nil, nil, xerrors.Errorf("insert mesh key: %w", err)
+			}
+		}
+		options.DERPServer.SetMeshKey(meshKey)
+
 		o := &coderd.Options{
-			AuditLogging:       dflags.AuditLogging.Value,
-			BrowserOnly:        dflags.BrowserOnly.Value,
-			SCIMAPIKey:         []byte(dflags.SCIMAuthHeader.Value),
-			UserWorkspaceQuota: dflags.UserWorkspaceQuota.Value,
-			RBACEnabled:        true,
-			Options:            options,
+			AuditLogging:           dflags.AuditLogging.Value,
+			BrowserOnly:            dflags.BrowserOnly.Value,
+			SCIMAPIKey:             []byte(dflags.SCIMAuthHeader.Value),
+			UserWorkspaceQuota:     dflags.UserWorkspaceQuota.Value,
+			RBAC:                   true,
+			DERPServerRelayAddress: dflags.DerpServerRelayAddress.Value,
+			DERPServerRegionID:     dflags.DerpServerRegionID.Value,
+
+			Options: options,
 		}
 		api, err := coderd.New(ctx, o)
 		if err != nil {
-			return nil, err
+			return nil, nil, err
 		}
-		return api.AGPL, nil
+		return api.AGPL, api, nil
 	})
 	deployment.AttachFlags(cmd.Flags(), dflags, true)
-
 	return cmd
 }
diff --git a/enterprise/coderd/authorize_test.go b/enterprise/coderd/authorize_test.go
index 72cc4c5f3861b..9195387632a67 100644
--- a/enterprise/coderd/authorize_test.go
+++ b/enterprise/coderd/authorize_test.go
@@ -28,7 +28,7 @@ func TestCheckACLPermissions(t *testing.T) {
 	// Create adminClient, member, and org adminClient
 	adminUser := coderdtest.CreateFirstUser(t, adminClient)
 	_ = coderdenttest.AddLicense(t, adminClient, coderdenttest.LicenseOptions{
-		TemplateRBACEnabled: true,
+		TemplateRBAC: true,
 	})
 
 	memberClient := coderdtest.CreateAnotherUser(t, adminClient, adminUser.OrganizationID)
diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go
index 2c341dd13a224..1250e6ae129da 100644
--- a/enterprise/coderd/coderd.go
+++ b/enterprise/coderd/coderd.go
@@ -3,6 +3,8 @@ package coderd
 import (
 	"context"
 	"crypto/ed25519"
+	"crypto/tls"
+	"crypto/x509"
 	"net/http"
 	"sync"
 	"time"
@@ -23,6 +25,10 @@ import (
 	"github.com/coder/coder/enterprise/audit"
 	"github.com/coder/coder/enterprise/audit/backends"
 	"github.com/coder/coder/enterprise/coderd/license"
+	"github.com/coder/coder/enterprise/derpmesh"
+	"github.com/coder/coder/enterprise/replicasync"
+	"github.com/coder/coder/enterprise/tailnet"
+	agpltailnet "github.com/coder/coder/tailnet"
 )
 
 // New constructs an Enterprise coderd API instance.
@@ -47,6 +53,7 @@ func New(ctx context.Context, options *Options) (*API, error) {
 		Options:                options,
 		cancelEntitlementsLoop: cancelFunc,
 	}
+
 	oauthConfigs := &httpmw.OAuth2Configs{
 		Github: options.GithubOAuth2Config,
 		OIDC:   options.OIDCConfig,
@@ -59,6 +66,10 @@ func New(ctx context.Context, options *Options) (*API, error) {
 	api.AGPL.APIHandler.Group(func(r chi.Router) {
 		r.Get("/entitlements", api.serveEntitlements)
+		r.Route("/replicas", func(r chi.Router) {
+			r.Use(apiKeyMiddleware)
+			r.Get("/", api.replicas)
+		})
 		r.Route("/licenses", func(r chi.Router) {
 			r.Use(apiKeyMiddleware)
 			r.Post("/", api.postLicense)
@@ -117,7 +128,40 @@ func New(ctx context.Context, options *Options) (*API, error) {
 		})
 	}
 
-	err := api.updateEntitlements(ctx)
+	meshRootCA := x509.NewCertPool()
+	for _, certificate := range options.TLSCertificates {
+		for _, certificatePart := range certificate.Certificate {
+			certificate, err := x509.ParseCertificate(certificatePart)
+			if err != nil {
+				return nil, xerrors.Errorf("parse certificate %s: %w", certificate.Subject.CommonName, err)
+			}
+			meshRootCA.AddCert(certificate)
+		}
+	}
+	// This TLS configuration spoofs access from the access URL hostname
+	// assuming that the certificates provided will cover that hostname.
+	//
+	// Replica sync and DERP meshing require accessing replicas via their
+	// internal IP addresses, and if TLS is configured we use the same
+	// certificates.
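+	// Note that meshRootCA above is seeded with our own certificates, so
+	// replicas can verify one another even with self-signed certificates.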
+	meshTLSConfig := &tls.Config{
+		MinVersion:   tls.VersionTLS12,
+		Certificates: options.TLSCertificates,
+		RootCAs:      meshRootCA,
+		ServerName:   options.AccessURL.Hostname(),
+	}
+	var err error
+	api.replicaManager, err = replicasync.New(ctx, options.Logger, options.Database, options.Pubsub, &replicasync.Options{
+		RelayAddress: options.DERPServerRelayAddress,
+		RegionID:     int32(options.DERPServerRegionID),
+		TLSConfig:    meshTLSConfig,
+	})
+	if err != nil {
+		return nil, xerrors.Errorf("initialize replica: %w", err)
+	}
+	api.derpMesh = derpmesh.New(options.Logger.Named("derpmesh"), api.DERPServer, meshTLSConfig)
+
+	err = api.updateEntitlements(ctx)
 	if err != nil {
 		return nil, xerrors.Errorf("update entitlements: %w", err)
 	}
@@ -129,13 +173,17 @@
 type Options struct {
 	*coderd.Options
 
-	RBACEnabled  bool
+	RBAC         bool
 	AuditLogging bool
 	// Whether to block non-browser connections.
 	BrowserOnly        bool
 	SCIMAPIKey         []byte
 	UserWorkspaceQuota int
 
+	// Used for high availability.
+	DERPServerRelayAddress string
+	DERPServerRegionID     int
+
 	EntitlementsUpdateInterval time.Duration
 	Keys                       map[string]ed25519.PublicKey
 }
@@ -144,6 +192,11 @@ type API struct {
 	AGPL *coderd.API
 	*Options
 
+	// Detects multiple Coder replicas running at the same time.
+	replicaManager *replicasync.Manager
+	// Meshes DERP connections from multiple replicas.
+	derpMesh *derpmesh.Mesh
+
 	cancelEntitlementsLoop func()
 	entitlementsMu         sync.RWMutex
 	entitlements           codersdk.Entitlements
@@ -151,6 +204,8 @@ func (api *API) Close() error {
 	api.cancelEntitlementsLoop()
+	_ = api.replicaManager.Close()
+	_ = api.derpMesh.Close()
 	return api.AGPL.Close()
 }
 
@@ -158,12 +213,13 @@ func (api *API) updateEntitlements(ctx context.Context) error {
 	api.entitlementsMu.Lock()
 	defer api.entitlementsMu.Unlock()
 
-	entitlements, err := license.Entitlements(ctx, api.Database, api.Logger, api.Keys, map[string]bool{
-		codersdk.FeatureAuditLog:       api.AuditLogging,
-		codersdk.FeatureBrowserOnly:    api.BrowserOnly,
-		codersdk.FeatureSCIM:           len(api.SCIMAPIKey) != 0,
-		codersdk.FeatureWorkspaceQuota: api.UserWorkspaceQuota != 0,
-		codersdk.FeatureTemplateRBAC:   api.RBACEnabled,
+	entitlements, err := license.Entitlements(ctx, api.Database, api.Logger, len(api.replicaManager.All()), api.Keys, map[string]bool{
+		codersdk.FeatureAuditLog:         api.AuditLogging,
+		codersdk.FeatureBrowserOnly:      api.BrowserOnly,
+		codersdk.FeatureSCIM:             len(api.SCIMAPIKey) != 0,
+		codersdk.FeatureWorkspaceQuota:   api.UserWorkspaceQuota != 0,
+		codersdk.FeatureHighAvailability: api.DERPServerRelayAddress != "",
+		codersdk.FeatureTemplateRBAC:     api.RBAC,
 	})
 	if err != nil {
 		return err
@@ -209,6 +265,46 @@ func (api *API) updateEntitlements(ctx context.Context) error {
 		api.AGPL.WorkspaceQuotaEnforcer.Store(&enforcer)
 	}
 
+	if changed, enabled := featureChanged(codersdk.FeatureHighAvailability); changed {
+		coordinator := agpltailnet.NewCoordinator()
+		if enabled {
+			haCoordinator, err := tailnet.NewCoordinator(api.Logger, api.Pubsub)
+			if err != nil {
+				api.Logger.Error(ctx, "unable to set up high availability coordinator", slog.Error(err))
+				// If we try to set up the HA coordinator and it fails,
+				// nothing is actually changing.
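+				// Leaving changed false skips the coordinator swap further
+				// down, so the existing coordinator stays in place.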
+				changed = false
+			} else {
+				coordinator = haCoordinator
+			}
+
+			api.replicaManager.SetCallback(func() {
+				addresses := make([]string, 0)
+				for _, replica := range api.replicaManager.Regional() {
+					addresses = append(addresses, replica.RelayAddress)
+				}
+				api.derpMesh.SetAddresses(addresses, false)
+				_ = api.updateEntitlements(ctx)
+			})
+		} else {
+			api.derpMesh.SetAddresses([]string{}, false)
+			api.replicaManager.SetCallback(func() {
+				// If the number of replicas changes, so should our entitlements.
+				// This is to display a warning in the UI if the user is unlicensed.
+				_ = api.updateEntitlements(ctx)
+			})
+		}
+
+		// Recheck changed in case the HA coordinator failed to set up.
+		if changed {
+			oldCoordinator := *api.AGPL.TailnetCoordinator.Swap(&coordinator)
+			err := oldCoordinator.Close()
+			if err != nil {
+				api.Logger.Error(ctx, "close old tailnet coordinator", slog.Error(err))
+			}
+		}
+	}
+
 	api.entitlements = entitlements
 
 	return nil
diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go
index 050cad5f9b87d..7b51845ff3986 100644
--- a/enterprise/coderd/coderd_test.go
+++ b/enterprise/coderd/coderd_test.go
@@ -41,9 +41,9 @@ func TestEntitlements(t *testing.T) {
 		})
 		_ = coderdtest.CreateFirstUser(t, client)
 		coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			UserLimit:           100,
-			AuditLog:            true,
-			TemplateRBACEnabled: true,
+			UserLimit:    100,
+			AuditLog:     true,
+			TemplateRBAC: true,
 		})
 		res, err := client.Entitlements(context.Background())
 		require.NoError(t, err)
@@ -85,7 +85,7 @@ func TestEntitlements(t *testing.T) {
 		assert.False(t, res.HasLicense)
 		al = res.Features[codersdk.FeatureAuditLog]
 		assert.Equal(t, codersdk.EntitlementNotEntitled, al.Entitlement)
-		assert.True(t, al.Enabled)
+		assert.False(t, al.Enabled)
 	})
 	t.Run("Pubsub", func(t *testing.T) {
 		t.Parallel()
diff --git a/enterprise/coderd/coderdenttest/coderdenttest.go b/enterprise/coderd/coderdenttest/coderdenttest.go
index 75760b3d4f2eb..a8595b5bc6ede 100644
--- a/enterprise/coderd/coderdenttest/coderdenttest.go
+++ b/enterprise/coderd/coderdenttest/coderdenttest.go
@@ -4,7 +4,9 @@ import (
 	"context"
 	"crypto/ed25519"
 	"crypto/rand"
+	"crypto/tls"
 	"io"
+	"net/http"
 	"testing"
 	"time"
@@ -60,19 +62,21 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
 	if options.Options == nil {
 		options.Options = &coderdtest.Options{}
 	}
-	srv, cancelFunc, oop := coderdtest.NewOptions(t, options.Options)
+	setHandler, cancelFunc, oop := coderdtest.NewOptions(t, options.Options)
 	coderAPI, err := coderd.New(context.Background(), &coderd.Options{
-		RBACEnabled:                true,
+		RBAC:                       true,
 		AuditLogging:               options.AuditLogging,
 		BrowserOnly:                options.BrowserOnly,
 		SCIMAPIKey:                 options.SCIMAPIKey,
+		DERPServerRelayAddress:     oop.AccessURL.String(),
+		DERPServerRegionID:         oop.DERPMap.RegionIDs()[0],
 		UserWorkspaceQuota:         options.UserWorkspaceQuota,
 		Options:                    oop,
 		EntitlementsUpdateInterval: options.EntitlementsUpdateInterval,
 		Keys:                       Keys,
 	})
 	assert.NoError(t, err)
-	srv.Config.Handler = coderAPI.AGPL.RootHandler
+	setHandler(coderAPI.AGPL.RootHandler)
 	var provisionerCloser io.Closer = nopcloser{}
 	if options.IncludeProvisionerDaemon {
 		provisionerCloser = coderdtest.NewProvisionerDaemon(t, coderAPI.AGPL)
@@ -83,22 +87,32 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
 		_ = provisionerCloser.Close()
 		_ = coderAPI.Close()
 	})
-	return codersdk.New(coderAPI.AccessURL), provisionerCloser, coderAPI
+	client := codersdk.New(coderAPI.AccessURL)
+	client.HTTPClient = &http.Client{
+		Transport: &http.Transport{
+			TLSClientConfig: &tls.Config{
+				//nolint:gosec
+				InsecureSkipVerify: true,
+			},
+		},
+	}
+	return client, provisionerCloser, coderAPI
 }
 
 type LicenseOptions struct {
-	AccountType         string
-	AccountID           string
-	Trial               bool
-	AllFeatures         bool
-	GraceAt             time.Time
-	ExpiresAt           time.Time
-	UserLimit           int64
-	AuditLog            bool
-	BrowserOnly         bool
-	SCIM                bool
-	WorkspaceQuota      bool
-	TemplateRBACEnabled bool
+	AccountType      string
+	AccountID        string
+	Trial            bool
+	AllFeatures      bool
+	GraceAt          time.Time
+	ExpiresAt        time.Time
+	UserLimit        int64
+	AuditLog         bool
+	BrowserOnly      bool
+	SCIM             bool
+	WorkspaceQuota   bool
+	TemplateRBAC     bool
+	HighAvailability bool
 }
 
 // AddLicense generates a new license with the options provided and inserts it.
@@ -134,9 +148,13 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string {
 	if options.WorkspaceQuota {
 		workspaceQuota = 1
 	}
+	highAvailability := int64(0)
+	if options.HighAvailability {
+		highAvailability = 1
+	}
 
 	rbacEnabled := int64(0)
-	if options.TemplateRBACEnabled {
+	if options.TemplateRBAC {
 		rbacEnabled = 1
 	}
 
@@ -154,12 +172,13 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string {
 		Version:     license.CurrentVersion,
 		AllFeatures: options.AllFeatures,
 		Features: license.Features{
-			UserLimit:      options.UserLimit,
-			AuditLog:       auditLog,
-			BrowserOnly:    browserOnly,
-			SCIM:           scim,
-			WorkspaceQuota: workspaceQuota,
-			TemplateRBAC:   rbacEnabled,
+			UserLimit:        options.UserLimit,
+			AuditLog:         auditLog,
+			BrowserOnly:      browserOnly,
+			SCIM:             scim,
+			WorkspaceQuota:   workspaceQuota,
+			HighAvailability: highAvailability,
+			TemplateRBAC:     rbacEnabled,
 		},
 	}
 	tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, c)
diff --git a/enterprise/coderd/coderdenttest/coderdenttest_test.go b/enterprise/coderd/coderdenttest/coderdenttest_test.go
index d526f6927bc00..e8ad88cd02805 100644
--- a/enterprise/coderd/coderdenttest/coderdenttest_test.go
+++ b/enterprise/coderd/coderdenttest/coderdenttest_test.go
@@ -33,7 +33,7 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
 	ctx, _ := testutil.Context(t)
 	admin := coderdtest.CreateFirstUser(t, client)
 	license := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-		TemplateRBACEnabled: true,
+		TemplateRBAC: true,
 	})
 	group, err := client.CreateGroup(ctx, admin.OrganizationID, codersdk.CreateGroupRequest{
 		Name: "testgroup",
@@ -58,6 +58,10 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
 		AssertAction: rbac.ActionRead,
 		AssertObject: rbac.ResourceLicense,
 	}
+	assertRoute["GET:/api/v2/replicas"] = coderdtest.RouteCheck{
+		AssertAction: rbac.ActionRead,
+		AssertObject: rbac.ResourceReplicas,
+	}
 	assertRoute["DELETE:/api/v2/licenses/{id}"] = coderdtest.RouteCheck{
 		AssertAction: rbac.ActionDelete,
 		AssertObject: rbac.ResourceLicense,
diff --git a/enterprise/coderd/groups_test.go b/enterprise/coderd/groups_test.go
index 2661da6bcc29f..eae51b0dfdc3f 100644
--- a/enterprise/coderd/groups_test.go
+++ b/enterprise/coderd/groups_test.go
@@ -24,7 +24,7 @@ func TestCreateGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@@ -43,7 +43,7 @@ func TestCreateGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		_, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@@ -67,7 +67,7 @@ func TestCreateGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		_, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@@ -90,7 +90,7 @@ func TestPatchGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@@ -112,7 +112,7 @@ func TestPatchGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
 		_, user3 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -138,7 +138,7 @@ func TestPatchGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
 		_, user3 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -173,7 +173,7 @@ func TestPatchGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@@ -197,7 +197,7 @@ func TestPatchGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@@ -221,7 +221,7 @@ func TestPatchGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
 		ctx, _ := testutil.Context(t)
@@ -247,7 +247,7 @@ func TestPatchGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@@ -276,7 +276,7 @@ func TestGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@@ -296,7 +296,7 @@ func TestGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
 		_, user3 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -326,7 +326,7 @@ func TestGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		client1, _ := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -347,7 +347,7 @@ func TestGroup(t *testing.T) {
 		client := coderdenttest.New(t, nil)
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		_, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -380,7 +380,7 @@ func TestGroup(t *testing.T) {
 		client := coderdenttest.New(t, nil)
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		_, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -421,7 +421,7 @@ func TestGroups(t *testing.T) {
 		client := coderdenttest.New(t, nil)
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
 		_, user3 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -467,7 +467,7 @@ func TestDeleteGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		group1, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@@ -492,7 +492,7 @@ func TestDeleteGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		err := client.DeleteGroup(ctx, user.OrganizationID)
diff --git a/enterprise/coderd/license/license.go b/enterprise/coderd/license/license.go
index ce9e5d1d590b1..c5bb689db65a9 100644
--- a/enterprise/coderd/license/license.go
+++ b/enterprise/coderd/license/license.go
@@ -17,12 +17,20 @@ import (
 )
 
 // Entitlements processes licenses to return whether features are enabled or not.
-func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, keys map[string]ed25519.PublicKey, enablements map[string]bool) (codersdk.Entitlements, error) {
+func Entitlements(
+	ctx context.Context,
+	db database.Store,
+	logger slog.Logger,
+	replicaCount int,
+	keys map[string]ed25519.PublicKey,
+	enablements map[string]bool,
+) (codersdk.Entitlements, error) {
 	now := time.Now()
 	// Default all entitlements to be disabled.
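+	// Note: replicaCount feeds the high availability checks further down;
+	// more than one replica without an entitled license yields an error.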
 	entitlements := codersdk.Entitlements{
 		Features: map[string]codersdk.Feature{},
 		Warnings: []string{},
+		Errors:   []string{},
 	}
 	for _, featureName := range codersdk.FeatureNames {
 		entitlements.Features[featureName] = codersdk.Feature{
@@ -96,6 +104,12 @@ func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, ke
 				Enabled:     enablements[codersdk.FeatureWorkspaceQuota],
 			}
 		}
+		if claims.Features.HighAvailability > 0 {
+			entitlements.Features[codersdk.FeatureHighAvailability] = codersdk.Feature{
+				Entitlement: entitlement,
+				Enabled:     enablements[codersdk.FeatureHighAvailability],
+			}
+		}
 		if claims.Features.TemplateRBAC > 0 {
 			entitlements.Features[codersdk.FeatureTemplateRBAC] = codersdk.Feature{
 				Entitlement: entitlement,
@@ -132,6 +146,10 @@ func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, ke
 		if featureName == codersdk.FeatureUserLimit {
 			continue
 		}
+		// High availability has its own warnings based on replica count!
+		if featureName == codersdk.FeatureHighAvailability {
+			continue
+		}
 		feature := entitlements.Features[featureName]
 		if !feature.Enabled {
 			continue
@@ -141,9 +159,6 @@
 		case codersdk.EntitlementNotEntitled:
 			entitlements.Warnings = append(entitlements.Warnings,
 				fmt.Sprintf("%s is enabled but your license is not entitled to this feature.", niceName))
-			// Disable the feature and add a warning...
-			feature.Enabled = false
-			entitlements.Features[featureName] = feature
 		case codersdk.EntitlementGracePeriod:
 			entitlements.Warnings = append(entitlements.Warnings,
 				fmt.Sprintf("%s is enabled but your license for this feature is expired.", niceName))
@@ -152,6 +167,32 @@
 		}
 	}
 
+	if replicaCount > 1 {
+		feature := entitlements.Features[codersdk.FeatureHighAvailability]
+
+		switch feature.Entitlement {
+		case codersdk.EntitlementNotEntitled:
+			if entitlements.HasLicense {
+				entitlements.Errors = append(entitlements.Errors,
+					"You have multiple replicas but your license is not entitled to high availability. You will be unable to connect to workspaces.")
+			} else {
+				entitlements.Errors = append(entitlements.Errors,
+					"You have multiple replicas but high availability is an Enterprise feature. You will be unable to connect to workspaces.")
+			}
+		case codersdk.EntitlementGracePeriod:
+			entitlements.Warnings = append(entitlements.Warnings,
+				"You have multiple replicas but your license for high availability is expired. Reduce to one replica or workspace connections will stop working.")
+		}
+	}
+
+	for _, featureName := range codersdk.FeatureNames {
+		feature := entitlements.Features[featureName]
+		if feature.Entitlement == codersdk.EntitlementNotEntitled {
+			feature.Enabled = false
+			entitlements.Features[featureName] = feature
+		}
+	}
+
 	return entitlements, nil
 }
 
@@ -171,12 +212,13 @@ var (
 )
 
 type Features struct {
-	UserLimit      int64 `json:"user_limit"`
-	AuditLog       int64 `json:"audit_log"`
-	BrowserOnly    int64 `json:"browser_only"`
-	SCIM           int64 `json:"scim"`
-	WorkspaceQuota int64 `json:"workspace_quota"`
-	TemplateRBAC   int64 `json:"template_rbac"`
+	UserLimit        int64 `json:"user_limit"`
+	AuditLog         int64 `json:"audit_log"`
+	BrowserOnly      int64 `json:"browser_only"`
+	SCIM             int64 `json:"scim"`
+	WorkspaceQuota   int64 `json:"workspace_quota"`
+	TemplateRBAC     int64 `json:"template_rbac"`
+	HighAvailability int64 `json:"high_availability"`
 }
 
 type Claims struct {
diff --git a/enterprise/coderd/license/license_test.go b/enterprise/coderd/license/license_test.go
index 8f15c5c009cdf..6def291e3e24c 100644
--- a/enterprise/coderd/license/license_test.go
+++ b/enterprise/coderd/license/license_test.go
@@ -20,17 +20,18 @@ import (
 func TestEntitlements(t *testing.T) {
 	t.Parallel()
 	all := map[string]bool{
-		codersdk.FeatureAuditLog:       true,
-		codersdk.FeatureBrowserOnly:    true,
-		codersdk.FeatureSCIM:           true,
-		codersdk.FeatureWorkspaceQuota: true,
-		codersdk.FeatureTemplateRBAC:   true,
+		codersdk.FeatureAuditLog:         true,
+		codersdk.FeatureBrowserOnly:      true,
+		codersdk.FeatureSCIM:             true,
+		codersdk.FeatureWorkspaceQuota:   true,
+		codersdk.FeatureHighAvailability: true,
+		codersdk.FeatureTemplateRBAC:     true,
 	}
 
 	t.Run("Defaults", func(t *testing.T) {
 		t.Parallel()
 		db := databasefake.New()
-		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
+		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, all)
 		require.NoError(t, err)
 		require.False(t, entitlements.HasLicense)
 		require.False(t, entitlements.Trial)
@@ -46,7 +47,7 @@
 			JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{}),
 			Exp: time.Now().Add(time.Hour),
 		})
-		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
+		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{})
 		require.NoError(t, err)
 		require.True(t, entitlements.HasLicense)
 		require.False(t, entitlements.Trial)
@@ -60,16 +61,17 @@
 		db := databasefake.New()
 		db.InsertLicense(context.Background(), database.InsertLicenseParams{
 			JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
-				UserLimit:           100,
-				AuditLog:            true,
-				BrowserOnly:         true,
-				SCIM:                true,
-				WorkspaceQuota:      true,
-				TemplateRBACEnabled: true,
+				UserLimit:        100,
+				AuditLog:         true,
+				BrowserOnly:      true,
+				SCIM:             true,
+				WorkspaceQuota:   true,
+				HighAvailability: true,
+				TemplateRBAC:     true,
 			}),
 			Exp: time.Now().Add(time.Hour),
 		})
-		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
+		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{})
 		require.NoError(t, err)
 		require.True(t, entitlements.HasLicense)
 		require.False(t, entitlements.Trial)
@@ -82,18 +84,19 @@ func TestEntitlements(t *testing.T) {
 		db := databasefake.New()
 		db.InsertLicense(context.Background(), database.InsertLicenseParams{
 			JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
-				UserLimit:           100,
-				AuditLog:            true,
-				BrowserOnly:         true,
-				SCIM:                true,
-				WorkspaceQuota:      true,
-				TemplateRBACEnabled: true,
-				GraceAt:             time.Now().Add(-time.Hour),
-				ExpiresAt:           time.Now().Add(time.Hour),
+				UserLimit:        100,
+				AuditLog:         true,
+				BrowserOnly:      true,
+				SCIM:             true,
+				WorkspaceQuota:   true,
+				HighAvailability: true,
+				TemplateRBAC:     true,
+				GraceAt:          time.Now().Add(-time.Hour),
+				ExpiresAt:        time.Now().Add(time.Hour),
 			}),
 			Exp: time.Now().Add(time.Hour),
 		})
-		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, all)
+		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, all)
 		require.NoError(t, err)
 		require.True(t, entitlements.HasLicense)
 		require.False(t, entitlements.Trial)
@@ -101,6 +104,9 @@ func TestEntitlements(t *testing.T) {
 		if featureName == codersdk.FeatureUserLimit {
 			continue
 		}
+		if featureName == codersdk.FeatureHighAvailability {
+			continue
+		}
 		niceName := strings.Title(strings.ReplaceAll(featureName, "_", " "))
 		require.Equal(t, codersdk.EntitlementGracePeriod, entitlements.Features[featureName].Entitlement)
 		require.Contains(t, entitlements.Warnings, fmt.Sprintf("%s is enabled but your license for this feature is expired.", niceName))
@@ -113,7 +119,7 @@ func TestEntitlements(t *testing.T) {
 			JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{}),
 			Exp: time.Now().Add(time.Hour),
 		})
-		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, all)
+		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, all)
 		require.NoError(t, err)
 		require.True(t, entitlements.HasLicense)
 		require.False(t, entitlements.Trial)
@@ -121,6 +127,9 @@ func TestEntitlements(t *testing.T) {
 		if featureName == codersdk.FeatureUserLimit {
 			continue
 		}
+		if featureName == codersdk.FeatureHighAvailability {
+			continue
+		}
 		niceName := strings.Title(strings.ReplaceAll(featureName, "_", " "))
 		// Ensures features that are not entitled are properly disabled.
require.False(t, entitlements.Features[featureName].Enabled) @@ -139,7 +148,7 @@ func TestEntitlements(t *testing.T) { }), Exp: time.Now().Add(time.Hour), }) - entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{}) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{}) require.NoError(t, err) require.True(t, entitlements.HasLicense) require.Contains(t, entitlements.Warnings, "Your deployment has 2 active users but is only licensed for 1.") @@ -161,7 +170,7 @@ func TestEntitlements(t *testing.T) { }), Exp: time.Now().Add(time.Hour), }) - entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{}) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{}) require.NoError(t, err) require.True(t, entitlements.HasLicense) require.Empty(t, entitlements.Warnings) @@ -184,7 +193,7 @@ func TestEntitlements(t *testing.T) { }), }) - entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{}) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{}) require.NoError(t, err) require.True(t, entitlements.HasLicense) require.False(t, entitlements.Trial) @@ -199,7 +208,7 @@ func TestEntitlements(t *testing.T) { AllFeatures: true, }), }) - entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, all) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, all) require.NoError(t, err) require.True(t, entitlements.HasLicense) require.False(t, entitlements.Trial) @@ -211,4 +220,52 @@ func TestEntitlements(t *testing.T) { require.Equal(t, codersdk.EntitlementEntitled, entitlements.Features[featureName].Entitlement) } }) + + t.Run("MultipleReplicasNoLicense", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 2, coderdenttest.Keys, all) + require.NoError(t, err) + require.False(t, entitlements.HasLicense) + require.Len(t, entitlements.Errors, 1) + require.Equal(t, "You have multiple replicas but high availability is an Enterprise feature. You will be unable to connect to workspaces.", entitlements.Errors[0]) + }) + + t.Run("MultipleReplicasNotEntitled", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + db.InsertLicense(context.Background(), database.InsertLicenseParams{ + Exp: time.Now().Add(time.Hour), + JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ + AuditLog: true, + }), + }) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 2, coderdenttest.Keys, map[string]bool{ + codersdk.FeatureHighAvailability: true, + }) + require.NoError(t, err) + require.True(t, entitlements.HasLicense) + require.Len(t, entitlements.Errors, 1) + require.Equal(t, "You have multiple replicas but your license is not entitled to high availability. 
You will be unable to connect to workspaces.", entitlements.Errors[0]) + }) + + t.Run("MultipleReplicasGrace", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + db.InsertLicense(context.Background(), database.InsertLicenseParams{ + JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ + HighAvailability: true, + GraceAt: time.Now().Add(-time.Hour), + ExpiresAt: time.Now().Add(time.Hour), + }), + Exp: time.Now().Add(time.Hour), + }) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 2, coderdenttest.Keys, map[string]bool{ + codersdk.FeatureHighAvailability: true, + }) + require.NoError(t, err) + require.True(t, entitlements.HasLicense) + require.Len(t, entitlements.Warnings, 1) + require.Equal(t, "You have multiple replicas but your license for high availability is expired. Reduce to one replica or workspace connections will stop working.", entitlements.Warnings[0]) + }) } diff --git a/enterprise/coderd/licenses_test.go b/enterprise/coderd/licenses_test.go index f7c1c639997cb..aa4dddf1fd5f1 100644 --- a/enterprise/coderd/licenses_test.go +++ b/enterprise/coderd/licenses_test.go @@ -78,21 +78,21 @@ func TestGetLicense(t *testing.T) { defer cancel() coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - AccountID: "testing", - AuditLog: true, - SCIM: true, - BrowserOnly: true, - TemplateRBACEnabled: true, + AccountID: "testing", + AuditLog: true, + SCIM: true, + BrowserOnly: true, + TemplateRBAC: true, }) coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - AccountID: "testing2", - AuditLog: true, - SCIM: true, - BrowserOnly: true, - Trial: true, - UserLimit: 200, - TemplateRBACEnabled: false, + AccountID: "testing2", + AuditLog: true, + SCIM: true, + BrowserOnly: true, + Trial: true, + UserLimit: 200, + TemplateRBAC: false, }) licenses, err := client.Licenses(ctx) @@ -101,23 +101,25 @@ func TestGetLicense(t *testing.T) { assert.Equal(t, int32(1), licenses[0].ID) assert.Equal(t, "testing", licenses[0].Claims["account_id"]) assert.Equal(t, map[string]interface{}{ - codersdk.FeatureUserLimit: json.Number("0"), - codersdk.FeatureAuditLog: json.Number("1"), - codersdk.FeatureSCIM: json.Number("1"), - codersdk.FeatureBrowserOnly: json.Number("1"), - codersdk.FeatureWorkspaceQuota: json.Number("0"), - codersdk.FeatureTemplateRBAC: json.Number("1"), + codersdk.FeatureUserLimit: json.Number("0"), + codersdk.FeatureAuditLog: json.Number("1"), + codersdk.FeatureSCIM: json.Number("1"), + codersdk.FeatureBrowserOnly: json.Number("1"), + codersdk.FeatureWorkspaceQuota: json.Number("0"), + codersdk.FeatureHighAvailability: json.Number("0"), + codersdk.FeatureTemplateRBAC: json.Number("1"), }, licenses[0].Claims["features"]) assert.Equal(t, int32(2), licenses[1].ID) assert.Equal(t, "testing2", licenses[1].Claims["account_id"]) assert.Equal(t, true, licenses[1].Claims["trial"]) assert.Equal(t, map[string]interface{}{ - codersdk.FeatureUserLimit: json.Number("200"), - codersdk.FeatureAuditLog: json.Number("1"), - codersdk.FeatureSCIM: json.Number("1"), - codersdk.FeatureBrowserOnly: json.Number("1"), - codersdk.FeatureWorkspaceQuota: json.Number("0"), - codersdk.FeatureTemplateRBAC: json.Number("0"), + codersdk.FeatureUserLimit: json.Number("200"), + codersdk.FeatureAuditLog: json.Number("1"), + codersdk.FeatureSCIM: json.Number("1"), + codersdk.FeatureBrowserOnly: json.Number("1"), + codersdk.FeatureWorkspaceQuota: json.Number("0"), + codersdk.FeatureHighAvailability: json.Number("0"), + codersdk.FeatureTemplateRBAC: 
json.Number("0"), }, licenses[1].Claims["features"]) }) } diff --git a/enterprise/coderd/replicas.go b/enterprise/coderd/replicas.go new file mode 100644 index 0000000000000..906597f257f04 --- /dev/null +++ b/enterprise/coderd/replicas.go @@ -0,0 +1,37 @@ +package coderd + +import ( + "net/http" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/codersdk" +) + +// replicas returns the number of replicas that are active in Coder. +func (api *API) replicas(rw http.ResponseWriter, r *http.Request) { + if !api.AGPL.Authorize(r, rbac.ActionRead, rbac.ResourceReplicas) { + httpapi.ResourceNotFound(rw) + return + } + + replicas := api.replicaManager.All() + res := make([]codersdk.Replica, 0, len(replicas)) + for _, replica := range replicas { + res = append(res, convertReplica(replica)) + } + httpapi.Write(r.Context(), rw, http.StatusOK, res) +} + +func convertReplica(replica database.Replica) codersdk.Replica { + return codersdk.Replica{ + ID: replica.ID, + Hostname: replica.Hostname, + CreatedAt: replica.CreatedAt, + RelayAddress: replica.RelayAddress, + RegionID: replica.RegionID, + Error: replica.Error, + DatabaseLatency: replica.DatabaseLatency, + } +} diff --git a/enterprise/coderd/replicas_test.go b/enterprise/coderd/replicas_test.go new file mode 100644 index 0000000000000..7a3e130cf7770 --- /dev/null +++ b/enterprise/coderd/replicas_test.go @@ -0,0 +1,138 @@ +package coderd_test + +import ( + "context" + "crypto/tls" + "testing" + + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/coderd/database/dbtestutil" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/coderd/coderdenttest" + "github.com/coder/coder/testutil" +) + +func TestReplicas(t *testing.T) { + t.Parallel() + t.Run("ErrorWithoutLicense", func(t *testing.T) { + t.Parallel() + db, pubsub := dbtestutil.NewDB(t) + firstClient := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + IncludeProvisionerDaemon: true, + Database: db, + Pubsub: pubsub, + }, + }) + _ = coderdtest.CreateFirstUser(t, firstClient) + secondClient, _, secondAPI := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Database: db, + Pubsub: pubsub, + }, + }) + secondClient.SessionToken = firstClient.SessionToken + ents, err := secondClient.Entitlements(context.Background()) + require.NoError(t, err) + require.Len(t, ents.Errors, 1) + _ = secondAPI.Close() + + ents, err = firstClient.Entitlements(context.Background()) + require.NoError(t, err) + require.Len(t, ents.Warnings, 0) + }) + t.Run("ConnectAcrossMultiple", func(t *testing.T) { + t.Parallel() + db, pubsub := dbtestutil.NewDB(t) + firstClient := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + IncludeProvisionerDaemon: true, + Database: db, + Pubsub: pubsub, + }, + }) + firstUser := coderdtest.CreateFirstUser(t, firstClient) + coderdenttest.AddLicense(t, firstClient, coderdenttest.LicenseOptions{ + HighAvailability: true, + }) + + secondClient := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Database: db, + Pubsub: pubsub, + }, + }) + secondClient.SessionToken = firstClient.SessionToken + replicas, err := secondClient.Replicas(context.Background()) + require.NoError(t, err) + require.Len(t, replicas, 2) + + _, agent := setupWorkspaceAgent(t, 
firstClient, firstUser, 0) + conn, err := secondClient.DialWorkspaceAgent(context.Background(), agent.ID, &codersdk.DialWorkspaceAgentOptions{ + BlockEndpoints: true, + Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), + }) + require.NoError(t, err) + require.Eventually(t, func() bool { + ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancelFunc() + _, err = conn.Ping(ctx) + return err == nil + }, testutil.WaitLong, testutil.IntervalFast) + _ = conn.Close() + }) + t.Run("ConnectAcrossMultipleTLS", func(t *testing.T) { + t.Parallel() + db, pubsub := dbtestutil.NewDB(t) + certificates := []tls.Certificate{testutil.GenerateTLSCertificate(t, "localhost")} + firstClient := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + IncludeProvisionerDaemon: true, + Database: db, + Pubsub: pubsub, + TLSCertificates: certificates, + }, + }) + firstUser := coderdtest.CreateFirstUser(t, firstClient) + coderdenttest.AddLicense(t, firstClient, coderdenttest.LicenseOptions{ + HighAvailability: true, + }) + + secondClient := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Database: db, + Pubsub: pubsub, + TLSCertificates: certificates, + }, + }) + secondClient.SessionToken = firstClient.SessionToken + replicas, err := secondClient.Replicas(context.Background()) + require.NoError(t, err) + require.Len(t, replicas, 2) + + _, agent := setupWorkspaceAgent(t, firstClient, firstUser, 0) + conn, err := secondClient.DialWorkspaceAgent(context.Background(), agent.ID, &codersdk.DialWorkspaceAgentOptions{ + BlockEndpoints: true, + Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), + }) + require.NoError(t, err) + require.Eventually(t, func() bool { + ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.IntervalSlow) + defer cancelFunc() + _, err = conn.Ping(ctx) + return err == nil + }, testutil.WaitLong, testutil.IntervalFast) + _ = conn.Close() + replicas, err = secondClient.Replicas(context.Background()) + require.NoError(t, err) + require.Len(t, replicas, 2) + for _, replica := range replicas { + require.Empty(t, replica.Error) + } + }) +} diff --git a/enterprise/coderd/templates_test.go b/enterprise/coderd/templates_test.go index fe6dd6f687f8c..87aa5a4ca83d8 100644 --- a/enterprise/coderd/templates_test.go +++ b/enterprise/coderd/templates_test.go @@ -23,7 +23,7 @@ func TestTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -64,7 +64,7 @@ func TestTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -88,7 +88,7 @@ func TestTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) client1, _ := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -138,7 +138,7 @@ func TestTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := 
coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -176,7 +176,7 @@ func TestTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -214,7 +214,7 @@ func TestTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) @@ -262,7 +262,7 @@ func TestTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) client1, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -318,7 +318,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -361,7 +361,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -422,7 +422,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) @@ -447,7 +447,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) @@ -472,7 +472,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -498,7 +498,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) client2, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -533,7 +533,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = 
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) client2, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -575,7 +575,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) @@ -597,7 +597,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) client1, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -662,7 +662,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) client1, _ := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) diff --git a/enterprise/coderd/workspaceagents_test.go b/enterprise/coderd/workspaceagents_test.go index 9fe3cfeaa3064..18285bcb94317 100644 --- a/enterprise/coderd/workspaceagents_test.go +++ b/enterprise/coderd/workspaceagents_test.go @@ -2,6 +2,7 @@ package coderd_test import ( "context" + "crypto/tls" "fmt" "net/http" "testing" @@ -9,7 +10,6 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" - "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" "github.com/coder/coder/coderd/coderdtest" @@ -42,7 +42,7 @@ func TestBlockNonBrowser(t *testing.T) { BrowserOnly: true, }) _, agent := setupWorkspaceAgent(t, client, user, 0) - _, err := client.DialWorkspaceAgentTailnet(context.Background(), slog.Logger{}, agent.ID) + _, err := client.DialWorkspaceAgent(context.Background(), agent.ID, nil) var apiErr *codersdk.Error require.ErrorAs(t, err, &apiErr) require.Equal(t, http.StatusConflict, apiErr.StatusCode()) @@ -59,7 +59,7 @@ func TestBlockNonBrowser(t *testing.T) { BrowserOnly: false, }) _, agent := setupWorkspaceAgent(t, client, user, 0) - conn, err := client.DialWorkspaceAgentTailnet(context.Background(), slog.Logger{}, agent.ID) + conn, err := client.DialWorkspaceAgent(context.Background(), agent.ID, nil) require.NoError(t, err) _ = conn.Close() }) @@ -109,6 +109,14 @@ func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.Cr workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) agentClient := codersdk.New(client.URL) + agentClient.HTTPClient = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + //nolint:gosec + InsecureSkipVerify: true, + }, + }, + } agentClient.SessionToken = authToken agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, diff --git a/enterprise/coderd/workspaces_test.go b/enterprise/coderd/workspaces_test.go index 33984e970d2af..824b3febb191c 100644 --- a/enterprise/coderd/workspaces_test.go +++ b/enterprise/coderd/workspaces_test.go @@ -26,7 +26,7 @@ func TestCreateWorkspace(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = 
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) diff --git a/enterprise/derpmesh/derpmesh.go b/enterprise/derpmesh/derpmesh.go new file mode 100644 index 0000000000000..3982542167073 --- /dev/null +++ b/enterprise/derpmesh/derpmesh.go @@ -0,0 +1,165 @@ +package derpmesh + +import ( + "context" + "crypto/tls" + "net" + "net/url" + "sync" + + "golang.org/x/xerrors" + "tailscale.com/derp" + "tailscale.com/derp/derphttp" + "tailscale.com/types/key" + + "github.com/coder/coder/tailnet" + + "cdr.dev/slog" +) + +// New constructs a new mesh for DERP servers. +func New(logger slog.Logger, server *derp.Server, tlsConfig *tls.Config) *Mesh { + return &Mesh{ + logger: logger, + server: server, + tlsConfig: tlsConfig, + ctx: context.Background(), + closed: make(chan struct{}), + active: make(map[string]context.CancelFunc), + } +} + +type Mesh struct { + logger slog.Logger + server *derp.Server + ctx context.Context + tlsConfig *tls.Config + + mutex sync.Mutex + closed chan struct{} + active map[string]context.CancelFunc +} + +// SetAddresses performs a diff of the incoming addresses and adds +// or removes DERP clients from the mesh. +// +// Connect is only used for testing to ensure DERPs are meshed before +// exchanging messages. +// nolint:revive +func (m *Mesh) SetAddresses(addresses []string, connect bool) { + total := make(map[string]struct{}, 0) + for _, address := range addresses { + addressURL, err := url.Parse(address) + if err != nil { + m.logger.Error(m.ctx, "invalid address", slog.F("address", err), slog.Error(err)) + continue + } + derpURL, err := addressURL.Parse("/derp") + if err != nil { + m.logger.Error(m.ctx, "parse derp", slog.F("address", err), slog.Error(err)) + continue + } + address = derpURL.String() + + total[address] = struct{}{} + added, err := m.addAddress(address, connect) + if err != nil { + m.logger.Error(m.ctx, "failed to add address", slog.F("address", address), slog.Error(err)) + continue + } + if added { + m.logger.Debug(m.ctx, "added mesh address", slog.F("address", address)) + } + } + + m.mutex.Lock() + for address := range m.active { + _, found := total[address] + if found { + continue + } + removed := m.removeAddress(address) + if removed { + m.logger.Debug(m.ctx, "removed mesh address", slog.F("address", address)) + } + } + m.mutex.Unlock() +} + +// addAddress begins meshing with a new address. It returns false if the address is already being meshed with. +// It's expected that this is a full HTTP address with a path. +// e.g. 
http://127.0.0.1:8080/derp + // nolint:revive +func (m *Mesh) addAddress(address string, connect bool) (bool, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + if m.isClosed() { + return false, nil + } + _, isActive := m.active[address] + if isActive { + return false, nil + } + client, err := derphttp.NewClient(m.server.PrivateKey(), address, tailnet.Logger(m.logger.Named("client"))) + if err != nil { + return false, xerrors.Errorf("create derp client: %w", err) + } + client.TLSConfig = m.tlsConfig + client.MeshKey = m.server.MeshKey() + client.SetURLDialer(func(ctx context.Context, network, addr string) (net.Conn, error) { + var dialer net.Dialer + return dialer.DialContext(ctx, network, addr) + }) + if connect { + _ = client.Connect(m.ctx) + } + ctx, cancelFunc := context.WithCancel(m.ctx) + closed := make(chan struct{}) + closeFunc := func() { + cancelFunc() + _ = client.Close() + <-closed + } + m.active[address] = closeFunc + go func() { + defer close(closed) + client.RunWatchConnectionLoop(ctx, m.server.PublicKey(), tailnet.Logger(m.logger.Named("loop")), func(np key.NodePublic) { + m.server.AddPacketForwarder(np, client) + }, func(np key.NodePublic) { + m.server.RemovePacketForwarder(np, client) + }) + }() + return true, nil +} + +// removeAddress stops meshing with a given address. +func (m *Mesh) removeAddress(address string) bool { + cancelFunc, isActive := m.active[address] + if isActive { + cancelFunc() + delete(m.active, address) + } + return isActive +} + +// Close ends all active meshes with the DERP server. +func (m *Mesh) Close() error { + m.mutex.Lock() + defer m.mutex.Unlock() + if m.isClosed() { + return nil + } + close(m.closed) + for _, cancelFunc := range m.active { + cancelFunc() + } + return nil +} + +func (m *Mesh) isClosed() bool { + select { + case <-m.closed: + return true + default: + } + return false +} diff --git a/enterprise/derpmesh/derpmesh_test.go b/enterprise/derpmesh/derpmesh_test.go new file mode 100644 index 0000000000000..7fad141238442 --- /dev/null +++ b/enterprise/derpmesh/derpmesh_test.go @@ -0,0 +1,219 @@ +package derpmesh_test + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "io" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "tailscale.com/derp" + "tailscale.com/derp/derphttp" + "tailscale.com/types/key" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/enterprise/derpmesh" + "github.com/coder/coder/tailnet" + "github.com/coder/coder/testutil" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestDERPMesh(t *testing.T) { + t.Parallel() + commonName := "something.org" + rawCert := testutil.GenerateTLSCertificate(t, commonName) + certificate, err := x509.ParseCertificate(rawCert.Certificate[0]) + require.NoError(t, err) + pool := x509.NewCertPool() + pool.AddCert(certificate) + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + ServerName: commonName, + RootCAs: pool, + Certificates: []tls.Certificate{rawCert}, + } + + t.Run("ExchangeMessages", func(t *testing.T) { + // This tests messages passing through multiple DERP servers. 
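+ // Two servers are meshed with each other, and a packet sent by a client
+ // of one server must arrive at a client connected to the other, proving
+ // the mesh forwards packets across servers.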
+ t.Parallel() + firstServer, firstServerURL := startDERP(t, tlsConfig) + defer firstServer.Close() + secondServer, secondServerURL := startDERP(t, tlsConfig) + firstMesh := derpmesh.New(slogtest.Make(t, nil).Named("first").Leveled(slog.LevelDebug), firstServer, tlsConfig) + firstMesh.SetAddresses([]string{secondServerURL}, true) + secondMesh := derpmesh.New(slogtest.Make(t, nil).Named("second").Leveled(slog.LevelDebug), secondServer, tlsConfig) + secondMesh.SetAddresses([]string{firstServerURL}, true) + defer firstMesh.Close() + defer secondMesh.Close() + + first := key.NewNode() + second := key.NewNode() + firstClient, err := derphttp.NewClient(first, secondServerURL, tailnet.Logger(slogtest.Make(t, nil))) + require.NoError(t, err) + firstClient.TLSConfig = tlsConfig + secondClient, err := derphttp.NewClient(second, firstServerURL, tailnet.Logger(slogtest.Make(t, nil))) + require.NoError(t, err) + secondClient.TLSConfig = tlsConfig + err = secondClient.Connect(context.Background()) + require.NoError(t, err) + + closed := make(chan struct{}) + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + sent := []byte("hello world") + go func() { + defer close(closed) + ticker := time.NewTicker(50 * time.Millisecond) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + } + err = firstClient.Send(second.Public(), sent) + require.NoError(t, err) + } + }() + + got := recvData(t, secondClient) + require.Equal(t, sent, got) + cancelFunc() + <-closed + }) + t.Run("RemoveAddress", func(t *testing.T) { + // This tests that messages still pass after an address is removed from the mesh. + t.Parallel() + server, serverURL := startDERP(t, tlsConfig) + mesh := derpmesh.New(slogtest.Make(t, nil).Named("first").Leveled(slog.LevelDebug), server, tlsConfig) + mesh.SetAddresses([]string{"http://fake.com"}, false) + // This should trigger a removal... 
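+ // SetAddresses diffs against the set of active addresses, so the empty
+ // list below marks the previously added address for removal.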
+ mesh.SetAddresses([]string{}, false) + defer mesh.Close() + + first := key.NewNode() + second := key.NewNode() + firstClient, err := derphttp.NewClient(first, serverURL, tailnet.Logger(slogtest.Make(t, nil))) + require.NoError(t, err) + firstClient.TLSConfig = tlsConfig + secondClient, err := derphttp.NewClient(second, serverURL, tailnet.Logger(slogtest.Make(t, nil))) + require.NoError(t, err) + secondClient.TLSConfig = tlsConfig + err = secondClient.Connect(context.Background()) + require.NoError(t, err) + + closed := make(chan struct{}) + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + sent := []byte("hello world") + go func() { + defer close(closed) + ticker := time.NewTicker(50 * time.Millisecond) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + } + err = firstClient.Send(second.Public(), sent) + require.NoError(t, err) + } + }() + got := recvData(t, secondClient) + require.Equal(t, sent, got) + cancelFunc() + <-closed + }) + t.Run("TwentyMeshes", func(t *testing.T) { + t.Parallel() + meshes := make([]*derpmesh.Mesh, 0, 20) + serverURLs := make([]string, 0, 20) + for i := 0; i < 20; i++ { + server, url := startDERP(t, tlsConfig) + mesh := derpmesh.New(slogtest.Make(t, nil).Named("mesh").Leveled(slog.LevelDebug), server, tlsConfig) + t.Cleanup(func() { + _ = server.Close() + _ = mesh.Close() + }) + serverURLs = append(serverURLs, url) + meshes = append(meshes, mesh) + } + for _, mesh := range meshes { + mesh.SetAddresses(serverURLs, true) + } + + first := key.NewNode() + second := key.NewNode() + firstClient, err := derphttp.NewClient(first, serverURLs[9], tailnet.Logger(slogtest.Make(t, nil))) + require.NoError(t, err) + firstClient.TLSConfig = tlsConfig + secondClient, err := derphttp.NewClient(second, serverURLs[16], tailnet.Logger(slogtest.Make(t, nil))) + require.NoError(t, err) + secondClient.TLSConfig = tlsConfig + err = secondClient.Connect(context.Background()) + require.NoError(t, err) + + closed := make(chan struct{}) + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + sent := []byte("hello world") + go func() { + defer close(closed) + ticker := time.NewTicker(50 * time.Millisecond) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + } + err = firstClient.Send(second.Public(), sent) + require.NoError(t, err) + } + }() + + got := recvData(t, secondClient) + require.Equal(t, sent, got) + cancelFunc() + <-closed + }) +} + +func recvData(t *testing.T, client *derphttp.Client) []byte { + for { + msg, err := client.Recv() + if errors.Is(err, io.EOF) { + return nil + } + assert.NoError(t, err) + t.Logf("derp: %T", msg) + switch msg := msg.(type) { + case derp.ReceivedPacket: + return msg.Data + default: + // Drop all others! 
+ } + } +} + +func startDERP(t *testing.T, tlsConfig *tls.Config) (*derp.Server, string) { + logf := tailnet.Logger(slogtest.Make(t, nil)) + d := derp.NewServer(key.NewNode(), logf) + d.SetMeshKey("some-key") + server := httptest.NewUnstartedServer(derphttp.Handler(d)) + server.TLS = tlsConfig + server.StartTLS() + t.Cleanup(func() { + _ = d.Close() + }) + t.Cleanup(server.Close) + return d, server.URL +} diff --git a/enterprise/replicasync/replicasync.go b/enterprise/replicasync/replicasync.go new file mode 100644 index 0000000000000..0534c55246824 --- /dev/null +++ b/enterprise/replicasync/replicasync.go @@ -0,0 +1,391 @@ +package replicasync + +import ( + "context" + "crypto/tls" + "database/sql" + "errors" + "fmt" + "net/http" + "os" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog" + + "github.com/coder/coder/buildinfo" + "github.com/coder/coder/coderd/database" +) + +var ( + PubsubEvent = "replica" +) + +type Options struct { + CleanupInterval time.Duration + UpdateInterval time.Duration + PeerTimeout time.Duration + RelayAddress string + RegionID int32 + TLSConfig *tls.Config +} + +// New registers the replica with the database and periodically updates its row +// to show it's healthy. It contacts all other alive replicas to ensure they are reachable. +func New(ctx context.Context, logger slog.Logger, db database.Store, pubsub database.Pubsub, options *Options) (*Manager, error) { + if options == nil { + options = &Options{} + } + if options.PeerTimeout == 0 { + options.PeerTimeout = 3 * time.Second + } + if options.UpdateInterval == 0 { + options.UpdateInterval = 5 * time.Second + } + if options.CleanupInterval == 0 { + // The cleanup interval can be quite long, because its + // primary purpose is to clean up dead replicas. + options.CleanupInterval = 30 * time.Minute + } + hostname, err := os.Hostname() + if err != nil { + return nil, xerrors.Errorf("get hostname: %w", err) + } + databaseLatency, err := db.Ping(ctx) + if err != nil { + return nil, xerrors.Errorf("ping database: %w", err) + } + id := uuid.New() + replica, err := db.InsertReplica(ctx, database.InsertReplicaParams{ + ID: id, + CreatedAt: database.Now(), + StartedAt: database.Now(), + UpdatedAt: database.Now(), + Hostname: hostname, + RegionID: options.RegionID, + RelayAddress: options.RelayAddress, + Version: buildinfo.Version(), + DatabaseLatency: int32(databaseLatency.Microseconds()), + }) + if err != nil { + return nil, xerrors.Errorf("insert replica: %w", err) + } + err = pubsub.Publish(PubsubEvent, []byte(id.String())) + if err != nil { + return nil, xerrors.Errorf("publish new replica: %w", err) + } + ctx, cancelFunc := context.WithCancel(ctx) + manager := &Manager{ + id: id, + options: options, + db: db, + pubsub: pubsub, + self: replica, + logger: logger, + closed: make(chan struct{}), + closeCancel: cancelFunc, + } + err = manager.syncReplicas(ctx) + if err != nil { + return nil, xerrors.Errorf("run replica: %w", err) + } + peers := manager.Regional() + if len(peers) > 0 { + self := manager.Self() + if self.RelayAddress == "" { + return nil, xerrors.Errorf("a relay address must be specified when running multiple replicas in the same region") + } + } + + err = manager.subscribe(ctx) + if err != nil { + return nil, xerrors.Errorf("subscribe: %w", err) + } + manager.closeWait.Add(1) + go manager.loop(ctx) + return manager, nil +} + +// Manager keeps the replica up to date and in sync with other replicas. 
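+// A minimal wiring sketch (hypothetical; the real server builds these
+// options from its deployment flags):
+//
+//	mgr, err := replicasync.New(ctx, logger, db, pubsub, &replicasync.Options{
+//		RelayAddress: "http://10.0.0.1:3000", // must be reachable by peer replicas
+//		RegionID:     0,
+//	})
+//	if err != nil {
+//		return err
+//	}
+//	defer mgr.Close()
+//	mgr.SetCallback(func() {
+//		// e.g. re-mesh DERP servers against mgr.Regional().
+//	})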
+type Manager struct { + id uuid.UUID + options *Options + db database.Store + pubsub database.Pubsub + logger slog.Logger + + closeWait sync.WaitGroup + closeMutex sync.Mutex + closed chan struct{} + closeCancel context.CancelFunc + + self database.Replica + mutex sync.Mutex + peers []database.Replica + callback func() +} + +// updateInterval returns the cutoff time used to determine a replica's state. +// A replica updated after this time is considered healthy; one updated before +// it is considered stale. +func (m *Manager) updateInterval() time.Time { + return database.Now().Add(-3 * m.options.UpdateInterval) +} + +// loop runs the replica update sequence on an update interval. +func (m *Manager) loop(ctx context.Context) { + defer m.closeWait.Done() + updateTicker := time.NewTicker(m.options.UpdateInterval) + defer updateTicker.Stop() + deleteTicker := time.NewTicker(m.options.CleanupInterval) + defer deleteTicker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-deleteTicker.C: + err := m.db.DeleteReplicasUpdatedBefore(ctx, m.updateInterval()) + if err != nil { + m.logger.Warn(ctx, "delete old replicas", slog.Error(err)) + } + continue + case <-updateTicker.C: + } + err := m.syncReplicas(ctx) + if err != nil && !errors.Is(err, context.Canceled) { + m.logger.Warn(ctx, "run replica update loop", slog.Error(err)) + } + } +} + +// subscribe listens for new replica information! +func (m *Manager) subscribe(ctx context.Context) error { + var ( + needsUpdate = false + updating = false + updateMutex = sync.Mutex{} + ) + + // This loop will continually update replicas as updates are processed. + // The intent is to always be up to date without spamming the run + // function, so if a new update comes in while one is being processed, + // it will reprocess afterwards. + var update func() + update = func() { + err := m.syncReplicas(ctx) + if err != nil && !errors.Is(err, context.Canceled) { + m.logger.Warn(ctx, "run replica from subscribe", slog.Error(err)) + } + updateMutex.Lock() + if needsUpdate { + needsUpdate = false + updateMutex.Unlock() + update() + return + } + updating = false + updateMutex.Unlock() + } + cancelFunc, err := m.pubsub.Subscribe(PubsubEvent, func(ctx context.Context, message []byte) { + updateMutex.Lock() + defer updateMutex.Unlock() + id, err := uuid.Parse(string(message)) + if err != nil { + return + } + // Don't process updates for ourself! + if id == m.id { + return + } + if updating { + needsUpdate = true + return + } + updating = true + go update() + }) + if err != nil { + return err + } + go func() { + <-ctx.Done() + cancelFunc() + }() + return nil +} + +func (m *Manager) syncReplicas(ctx context.Context) error { + m.closeMutex.Lock() + m.closeWait.Add(1) + m.closeMutex.Unlock() + defer m.closeWait.Done() + // Expect replicas to update once every three times the interval... + // If they don't, assume death! 
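+ // For example, with the default 5s UpdateInterval the cutoff is
+ // database.Now() minus 15s, so any replica that hasn't written its row
+ // within the last 15 seconds is excluded from the peer list.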
+ replicas, err := m.db.GetReplicasUpdatedAfter(ctx, m.updateInterval()) + if err != nil { + return xerrors.Errorf("get replicas: %w", err) + } + + m.mutex.Lock() + m.peers = make([]database.Replica, 0, len(replicas)) + for _, replica := range replicas { + if replica.ID == m.id { + continue + } + m.peers = append(m.peers, replica) + } + m.mutex.Unlock() + + client := http.Client{ + Timeout: m.options.PeerTimeout, + Transport: &http.Transport{ + TLSClientConfig: m.options.TLSConfig, + }, + } + defer client.CloseIdleConnections() + var wg sync.WaitGroup + var mu sync.Mutex + failed := make([]string, 0) + for _, peer := range m.Regional() { + wg.Add(1) + go func(peer database.Replica) { + defer wg.Done() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, peer.RelayAddress, nil) + if err != nil { + m.logger.Warn(ctx, "create http request for relay probe", + slog.F("relay_address", peer.RelayAddress), slog.Error(err)) + return + } + res, err := client.Do(req) + if err != nil { + mu.Lock() + failed = append(failed, fmt.Sprintf("relay %s (%s): %s", peer.Hostname, peer.RelayAddress, err)) + mu.Unlock() + return + } + _ = res.Body.Close() + }(peer) + } + wg.Wait() + replicaError := "" + if len(failed) > 0 { + replicaError = fmt.Sprintf("Failed to dial peers: %s", strings.Join(failed, ", ")) + } + + databaseLatency, err := m.db.Ping(ctx) + if err != nil { + return xerrors.Errorf("ping database: %w", err) + } + + replica, err := m.db.UpdateReplica(ctx, database.UpdateReplicaParams{ + ID: m.self.ID, + UpdatedAt: database.Now(), + StartedAt: m.self.StartedAt, + StoppedAt: m.self.StoppedAt, + RelayAddress: m.self.RelayAddress, + RegionID: m.self.RegionID, + Hostname: m.self.Hostname, + Version: m.self.Version, + Error: replicaError, + DatabaseLatency: int32(databaseLatency.Microseconds()), + }) + if err != nil { + return xerrors.Errorf("update replica: %w", err) + } + m.mutex.Lock() + defer m.mutex.Unlock() + if m.self.Error != replica.Error { + // Publish that an update occurred! + err = m.pubsub.Publish(PubsubEvent, []byte(m.self.ID.String())) + if err != nil { + return xerrors.Errorf("publish replica update: %w", err) + } + } + m.self = replica + if m.callback != nil { + go m.callback() + } + return nil +} + +// Self returns the current replica. +func (m *Manager) Self() database.Replica { + m.mutex.Lock() + defer m.mutex.Unlock() + return m.self +} + +// All returns every replica, including itself. +func (m *Manager) All() []database.Replica { + m.mutex.Lock() + defer m.mutex.Unlock() + replicas := make([]database.Replica, 0, len(m.peers)+1) + replicas = append(replicas, m.peers...) + return append(replicas, m.self) +} + +// Regional returns all replicas in the same region excluding itself. +func (m *Manager) Regional() []database.Replica { + m.mutex.Lock() + defer m.mutex.Unlock() + replicas := make([]database.Replica, 0) + for _, replica := range m.peers { + if replica.RegionID != m.self.RegionID { + continue + } + replicas = append(replicas, replica) + } + return replicas +} + +// SetCallback sets a function to execute whenever new peers +// are refreshed or updated. +func (m *Manager) SetCallback(callback func()) { + m.mutex.Lock() + defer m.mutex.Unlock() + m.callback = callback + // Instantly call the callback to inform the consumer! 
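+ // The call runs in a goroutine so SetCallback never blocks while m.mutex is held.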
+ go callback() +} + +func (m *Manager) Close() error { + m.closeMutex.Lock() + select { + case <-m.closed: + m.closeMutex.Unlock() + return nil + default: + } + close(m.closed) + m.closeCancel() + m.closeWait.Wait() + m.closeMutex.Unlock() + m.mutex.Lock() + defer m.mutex.Unlock() + ctx, cancelFunc := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFunc() + _, err := m.db.UpdateReplica(ctx, database.UpdateReplicaParams{ + ID: m.self.ID, + UpdatedAt: database.Now(), + StartedAt: m.self.StartedAt, + StoppedAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + RelayAddress: m.self.RelayAddress, + RegionID: m.self.RegionID, + Hostname: m.self.Hostname, + Version: m.self.Version, + Error: m.self.Error, + }) + if err != nil { + return xerrors.Errorf("update replica: %w", err) + } + err = m.pubsub.Publish(PubsubEvent, []byte(m.self.ID.String())) + if err != nil { + return xerrors.Errorf("publish replica update: %w", err) + } + return nil +} diff --git a/enterprise/replicasync/replicasync_test.go b/enterprise/replicasync/replicasync_test.go new file mode 100644 index 0000000000000..b7709c1f6f814 --- /dev/null +++ b/enterprise/replicasync/replicasync_test.go @@ -0,0 +1,239 @@ +package replicasync_test + +import ( + "context" + "crypto/tls" + "crypto/x509" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/database/dbtestutil" + "github.com/coder/coder/enterprise/replicasync" + "github.com/coder/coder/testutil" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestReplica(t *testing.T) { + t.Parallel() + t.Run("CreateOnNew", func(t *testing.T) { + // This ensures that a new replica is created on New. + t.Parallel() + db, pubsub := dbtestutil.NewDB(t) + closeChan := make(chan struct{}, 1) + cancel, err := pubsub.Subscribe(replicasync.PubsubEvent, func(ctx context.Context, message []byte) { + closeChan <- struct{}{} + }) + require.NoError(t, err) + defer cancel() + server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, nil) + require.NoError(t, err) + <-closeChan + err = server.Close() + require.NoError(t, err) + }) + t.Run("ErrorsWithoutRelayAddress", func(t *testing.T) { + // Ensures New fails when a replica exists in the same region + // but no relay address is set to reach it. + t.Parallel() + db, pubsub := dbtestutil.NewDB(t) + _, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + StartedAt: database.Now(), + UpdatedAt: database.Now(), + Hostname: "something", + }) + require.NoError(t, err) + _, err = replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, nil) + require.Error(t, err) + require.Equal(t, "a relay address must be specified when running multiple replicas in the same region", err.Error()) + }) + t.Run("ConnectsToPeerReplica", func(t *testing.T) { + // Ensures that the replica reports a successful status for + // accessing all of its peers. 
+ t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + db, pubsub := dbtestutil.NewDB(t) + peer, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + StartedAt: database.Now(), + UpdatedAt: database.Now(), + Hostname: "something", + RelayAddress: srv.URL, + }) + require.NoError(t, err) + server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, &replicasync.Options{ + RelayAddress: "http://169.254.169.254", + }) + require.NoError(t, err) + require.Len(t, server.Regional(), 1) + require.Equal(t, peer.ID, server.Regional()[0].ID) + require.Empty(t, server.Self().Error) + _ = server.Close() + }) + t.Run("ConnectsToPeerReplicaTLS", func(t *testing.T) { + // Ensures that the replica reports a successful status for + // accessing all of its peers. + t.Parallel() + rawCert := testutil.GenerateTLSCertificate(t, "hello.org") + certificate, err := x509.ParseCertificate(rawCert.Certificate[0]) + require.NoError(t, err) + pool := x509.NewCertPool() + pool.AddCert(certificate) + // nolint:gosec + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{rawCert}, + ServerName: "hello.org", + RootCAs: pool, + } + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + srv.TLS = tlsConfig + srv.StartTLS() + defer srv.Close() + db, pubsub := dbtestutil.NewDB(t) + peer, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + StartedAt: database.Now(), + UpdatedAt: database.Now(), + Hostname: "something", + RelayAddress: srv.URL, + }) + require.NoError(t, err) + server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, &replicasync.Options{ + RelayAddress: "http://169.254.169.254", + TLSConfig: tlsConfig, + }) + require.NoError(t, err) + require.Len(t, server.Regional(), 1) + require.Equal(t, peer.ID, server.Regional()[0].ID) + require.Empty(t, server.Self().Error) + _ = server.Close() + }) + t.Run("ConnectsToFakePeerWithError", func(t *testing.T) { + t.Parallel() + db, pubsub := dbtestutil.NewDB(t) + peer, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ + ID: uuid.New(), + CreatedAt: database.Now().Add(time.Minute), + StartedAt: database.Now().Add(time.Minute), + UpdatedAt: database.Now().Add(time.Minute), + Hostname: "something", + // Fake address to dial! + RelayAddress: "http://127.0.0.1:1", + }) + require.NoError(t, err) + server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, &replicasync.Options{ + PeerTimeout: 1 * time.Millisecond, + RelayAddress: "http://127.0.0.1:1", + }) + require.NoError(t, err) + require.Len(t, server.Regional(), 1) + require.Equal(t, peer.ID, server.Regional()[0].ID) + require.NotEmpty(t, server.Self().Error) + require.Contains(t, server.Self().Error, "Failed to dial peers") + _ = server.Close() + }) + t.Run("RefreshOnPublish", func(t *testing.T) { + // Refresh when a new replica appears! 
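+ // Publishing the same peer ID twice below also exercises the coalescing
+ // in subscribe(): a message arriving mid-sync only flags needsUpdate and
+ // is folded into a single follow-up sync.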
+ t.Parallel() + db, pubsub := dbtestutil.NewDB(t) + server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, nil) + require.NoError(t, err) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + peer, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ + ID: uuid.New(), + RelayAddress: srv.URL, + UpdatedAt: database.Now(), + }) + require.NoError(t, err) + // Publish multiple times to ensure it can handle that case. + err = pubsub.Publish(replicasync.PubsubEvent, []byte(peer.ID.String())) + require.NoError(t, err) + err = pubsub.Publish(replicasync.PubsubEvent, []byte(peer.ID.String())) + require.NoError(t, err) + require.Eventually(t, func() bool { + return len(server.Regional()) == 1 + }, testutil.WaitShort, testutil.IntervalFast) + _ = server.Close() + }) + t.Run("DeletesOld", func(t *testing.T) { + t.Parallel() + db, pubsub := dbtestutil.NewDB(t) + _, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ + ID: uuid.New(), + UpdatedAt: database.Now().Add(-time.Hour), + }) + require.NoError(t, err) + server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, &replicasync.Options{ + RelayAddress: "google.com", + CleanupInterval: time.Millisecond, + }) + require.NoError(t, err) + defer server.Close() + require.Eventually(t, func() bool { + return len(server.Regional()) == 0 + }, testutil.WaitShort, testutil.IntervalFast) + }) + t.Run("TwentyConcurrent", func(t *testing.T) { + // Ensures that twenty concurrent replicas can spawn and all + // discover each other in parallel! + t.Parallel() + // This uses the in-memory database fake because creating + // this many real PostgreSQL connections takes some + // configuration tweaking. + db := databasefake.New() + pubsub := database.NewPubsubInMemory() + logger := slogtest.Make(t, nil) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + var wg sync.WaitGroup + count := 20 + wg.Add(count) + for i := 0; i < count; i++ { + server, err := replicasync.New(context.Background(), logger, db, pubsub, &replicasync.Options{ + RelayAddress: srv.URL, + }) + require.NoError(t, err) + t.Cleanup(func() { + _ = server.Close() + }) + var once sync.Once + server.SetCallback(func() { + if len(server.All()) != count { + return + } + // Callbacks fire concurrently, so guard wg.Done with sync.Once. + once.Do(wg.Done) + }) + } + wg.Wait() + }) +} diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go new file mode 100644 index 0000000000000..5749d9ef47c7a --- /dev/null +++ b/enterprise/tailnet/coordinator.go @@ -0,0 +1,575 @@ +package tailnet + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net" + "sync" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog" + "github.com/coder/coder/coderd/database" + agpl "github.com/coder/coder/tailnet" +) + +// NewCoordinator creates a new high availability coordinator +// that uses PostgreSQL pubsub to exchange handshakes. 
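+// A rough usage sketch, mirroring the tests further below (the conns are
+// assumed to come from WebSocket upgrades in coderd):
+//
+//	coord, err := tailnet.NewCoordinator(logger, pubsub)
+//	if err != nil {
+//		return err
+//	}
+//	defer coord.Close()
+//	go func() { _ = coord.ServeAgent(agentConn, agentID) }()
+//	_ = coord.ServeClient(clientConn, clientID, agentID)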
+func NewCoordinator(logger slog.Logger, pubsub database.Pubsub) (agpl.Coordinator, error) { + ctx, cancelFunc := context.WithCancel(context.Background()) + coord := &haCoordinator{ + id: uuid.New(), + log: logger, + pubsub: pubsub, + closeFunc: cancelFunc, + close: make(chan struct{}), + nodes: map[uuid.UUID]*agpl.Node{}, + agentSockets: map[uuid.UUID]net.Conn{}, + agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{}, + } + + if err := coord.runPubsub(ctx); err != nil { + return nil, xerrors.Errorf("run coordinator pubsub: %w", err) + } + + return coord, nil +} + +type haCoordinator struct { + id uuid.UUID + log slog.Logger + mutex sync.RWMutex + pubsub database.Pubsub + close chan struct{} + closeFunc context.CancelFunc + + // nodes maps agent and connection IDs to their respective node. + nodes map[uuid.UUID]*agpl.Node + // agentSockets maps agent IDs to their open websocket. + agentSockets map[uuid.UUID]net.Conn + // agentToConnectionSockets maps agent IDs to connection IDs of conns that + // are subscribed to updates for that agent. + agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn +} + +// Node returns an in-memory node by ID. +func (c *haCoordinator) Node(id uuid.UUID) *agpl.Node { + c.mutex.Lock() + defer c.mutex.Unlock() + node := c.nodes[id] + return node +} + +// ServeClient accepts a WebSocket connection that wants to connect to an agent +// with the specified ID. +func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { + c.mutex.Lock() + // When a new connection is requested, we update it with the latest + // node of the agent. This allows the connection to establish. + node, ok := c.nodes[agent] + c.mutex.Unlock() + if ok { + data, err := json.Marshal([]*agpl.Node{node}) + if err != nil { + return xerrors.Errorf("marshal node: %w", err) + } + _, err = conn.Write(data) + if err != nil { + return xerrors.Errorf("write nodes: %w", err) + } + } else { + err := c.publishClientHello(agent) + if err != nil { + return xerrors.Errorf("publish client hello: %w", err) + } + } + + c.mutex.Lock() + connectionSockets, ok := c.agentToConnectionSockets[agent] + if !ok { + connectionSockets = map[uuid.UUID]net.Conn{} + c.agentToConnectionSockets[agent] = connectionSockets + } + + // Insert this connection into a map so the agent can publish node updates. + connectionSockets[id] = conn + c.mutex.Unlock() + + defer func() { + c.mutex.Lock() + defer c.mutex.Unlock() + // Clean all traces of this connection from the map. + delete(c.nodes, id) + connectionSockets, ok := c.agentToConnectionSockets[agent] + if !ok { + return + } + delete(connectionSockets, id) + if len(connectionSockets) != 0 { + return + } + delete(c.agentToConnectionSockets, agent) + }() + + decoder := json.NewDecoder(conn) + // Indefinitely handle messages from the client websocket. + for { + err := c.handleNextClientMessage(id, agent, decoder) + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) { + return nil + } + return xerrors.Errorf("handle next client message: %w", err) + } + } +} + +func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error { + var node agpl.Node + err := decoder.Decode(&node) + if err != nil { + return xerrors.Errorf("read json: %w", err) + } + + c.mutex.Lock() + // Update the node of this client in our in-memory map. If an agent entirely + // shuts down and reconnects, it needs to be aware of all clients attempting + // to establish connections. 
+ c.nodes[id] = &node + // Look up the socket for the agent, if it's connected to this replica. + agentSocket, ok := c.agentSockets[agent] + c.mutex.Unlock() + if !ok { + // If we don't own the agent locally, send it over pubsub to a node that + // owns the agent. + err := c.publishNodesToAgent(agent, []*agpl.Node{&node}) + if err != nil { + return xerrors.Errorf("publish node to agent: %w", err) + } + return nil + } + + // Write the new node from this client to the actively + // connected agent. + data, err := json.Marshal([]*agpl.Node{&node}) + if err != nil { + return xerrors.Errorf("marshal nodes: %w", err) + } + + _, err = agentSocket.Write(data) + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) { + return nil + } + return xerrors.Errorf("write json: %w", err) + } + + return nil +} + +// ServeAgent accepts a WebSocket connection to an agent that listens to +// incoming connections and publishes node updates. +func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID) error { + // Tell clients on other instances to send a callmemaybe to us. + err := c.publishAgentHello(id) + if err != nil { + return xerrors.Errorf("publish agent hello: %w", err) + } + + // Publish all nodes on this instance that want to connect to this agent. + nodes := c.nodesSubscribedToAgent(id) + if len(nodes) > 0 { + data, err := json.Marshal(nodes) + if err != nil { + return xerrors.Errorf("marshal json: %w", err) + } + _, err = conn.Write(data) + if err != nil { + return xerrors.Errorf("write nodes: %w", err) + } + } + + // If an old agent socket is connected, we close it + // to avoid any leaks. This shouldn't ever occur because + // we expect one agent to be running. + c.mutex.Lock() + oldAgentSocket, ok := c.agentSockets[id] + if ok { + _ = oldAgentSocket.Close() + } + c.agentSockets[id] = conn + c.mutex.Unlock() + defer func() { + c.mutex.Lock() + defer c.mutex.Unlock() + delete(c.agentSockets, id) + delete(c.nodes, id) + }() + + decoder := json.NewDecoder(conn) + for { + node, err := c.handleAgentUpdate(id, decoder) + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) { + return nil + } + return xerrors.Errorf("handle next agent message: %w", err) + } + + err = c.publishAgentToNodes(id, node) + if err != nil { + return xerrors.Errorf("publish agent to nodes: %w", err) + } + } +} + +func (c *haCoordinator) nodesSubscribedToAgent(agentID uuid.UUID) []*agpl.Node { + c.mutex.Lock() + defer c.mutex.Unlock() + sockets, ok := c.agentToConnectionSockets[agentID] + if !ok { + return nil + } + + nodes := make([]*agpl.Node, 0, len(sockets)) + for targetID := range sockets { + node, ok := c.nodes[targetID] + if !ok { + continue + } + nodes = append(nodes, node) + } + + return nodes +} + +func (c *haCoordinator) handleClientHello(id uuid.UUID) error { + c.mutex.Lock() + node, ok := c.nodes[id] + c.mutex.Unlock() + if !ok { + return nil + } + return c.publishAgentToNodes(id, node) +} + +func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (*agpl.Node, error) { + var node agpl.Node + err := decoder.Decode(&node) + if err != nil { + return nil, xerrors.Errorf("read json: %w", err) + } + + c.mutex.Lock() + oldNode := c.nodes[id] + if oldNode != nil { + if oldNode.AsOf.After(node.AsOf) { + c.mutex.Unlock() + return oldNode, nil + } + } + c.nodes[id] = &node + connectionSockets, ok := c.agentToConnectionSockets[id] + if !ok { + c.mutex.Unlock() + return &node, nil + } + + data, err := json.Marshal([]*agpl.Node{&node}) + if err != nil { + 
c.mutex.Unlock() + return nil, xerrors.Errorf("marshal nodes: %w", err) + } + + // Publish the new node to every listening socket. + var wg sync.WaitGroup + wg.Add(len(connectionSockets)) + for _, connectionSocket := range connectionSockets { + connectionSocket := connectionSocket + go func() { + defer wg.Done() + _ = connectionSocket.SetWriteDeadline(time.Now().Add(5 * time.Second)) + _, _ = connectionSocket.Write(data) + }() + } + c.mutex.Unlock() + wg.Wait() + return &node, nil +} + +// Close closes all of the open connections in the coordinator and stops the +// coordinator from accepting new connections. +func (c *haCoordinator) Close() error { + c.mutex.Lock() + defer c.mutex.Unlock() + select { + case <-c.close: + return nil + default: + } + close(c.close) + c.closeFunc() + + wg := sync.WaitGroup{} + + wg.Add(len(c.agentSockets)) + for _, socket := range c.agentSockets { + socket := socket + go func() { + _ = socket.Close() + wg.Done() + }() + } + + for _, connMap := range c.agentToConnectionSockets { + wg.Add(len(connMap)) + for _, socket := range connMap { + socket := socket + go func() { + _ = socket.Close() + wg.Done() + }() + } + } + + wg.Wait() + return nil +} + +func (c *haCoordinator) publishNodesToAgent(recipient uuid.UUID, nodes []*agpl.Node) error { + msg, err := c.formatCallMeMaybe(recipient, nodes) + if err != nil { + return xerrors.Errorf("format publish message: %w", err) + } + + err = c.pubsub.Publish("wireguard_peers", msg) + if err != nil { + return xerrors.Errorf("publish message: %w", err) + } + + return nil +} + +func (c *haCoordinator) publishAgentHello(id uuid.UUID) error { + msg, err := c.formatAgentHello(id) + if err != nil { + return xerrors.Errorf("format publish message: %w", err) + } + + err = c.pubsub.Publish("wireguard_peers", msg) + if err != nil { + return xerrors.Errorf("publish message: %w", err) + } + + return nil +} + +func (c *haCoordinator) publishClientHello(id uuid.UUID) error { + msg, err := c.formatClientHello(id) + if err != nil { + return xerrors.Errorf("format client hello: %w", err) + } + err = c.pubsub.Publish("wireguard_peers", msg) + if err != nil { + return xerrors.Errorf("publish client hello: %w", err) + } + return nil +} + +func (c *haCoordinator) publishAgentToNodes(id uuid.UUID, node *agpl.Node) error { + msg, err := c.formatAgentUpdate(id, node) + if err != nil { + return xerrors.Errorf("format publish message: %w", err) + } + + err = c.pubsub.Publish("wireguard_peers", msg) + if err != nil { + return xerrors.Errorf("publish message: %w", err) + } + + return nil +} + +func (c *haCoordinator) runPubsub(ctx context.Context) error { + messageQueue := make(chan []byte, 64) + cancelSub, err := c.pubsub.Subscribe("wireguard_peers", func(ctx context.Context, message []byte) { + select { + case messageQueue <- message: + case <-ctx.Done(): + return + } + }) + if err != nil { + return xerrors.Errorf("subscribe wireguard peers: %w", err) + } + go func() { + for { + var message []byte + select { + case <-ctx.Done(): + return + case message = <-messageQueue: + } + c.handlePubsubMessage(ctx, message) + } + }() + + go func() { + defer cancelSub() + <-c.close + }() + + return nil +} + +func (c *haCoordinator) handlePubsubMessage(ctx context.Context, message []byte) { + sp := bytes.SplitN(message, []byte("|"), 4) + if len(sp) != 4 { + c.log.Error(ctx, "invalid wireguard peer message", slog.F("msg", string(message))) + return + } + + var ( + coordinatorID = sp[0] + eventType = sp[1] + agentID = sp[2] + nodeJSON = sp[3] + ) + + sender, err := 
uuid.ParseBytes(coordinatorID) + if err != nil { + c.log.Error(ctx, "invalid sender id", slog.F("id", string(coordinatorID)), slog.F("msg", string(message))) + return + } + + // We sent this message! + if sender == c.id { + return + } + + switch string(eventType) { + case "callmemaybe": + agentUUID, err := uuid.ParseBytes(agentID) + if err != nil { + c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID))) + return + } + + c.mutex.Lock() + agentSocket, ok := c.agentSockets[agentUUID] + if !ok { + c.mutex.Unlock() + return + } + c.mutex.Unlock() + + // The nodes are already JSON-encoded as an array by the sender, so + // the payload can be written to the agent as-is. + _, err = agentSocket.Write(nodeJSON) + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) { + return + } + c.log.Error(ctx, "send callmemaybe to agent", slog.Error(err)) + return + } + case "clienthello": + agentUUID, err := uuid.ParseBytes(agentID) + if err != nil { + c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID))) + return + } + + err = c.handleClientHello(agentUUID) + if err != nil { + c.log.Error(ctx, "handle agent request node", slog.Error(err)) + return + } + case "agenthello": + agentUUID, err := uuid.ParseBytes(agentID) + if err != nil { + c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID))) + return + } + + nodes := c.nodesSubscribedToAgent(agentUUID) + if len(nodes) > 0 { + err := c.publishNodesToAgent(agentUUID, nodes) + if err != nil { + c.log.Error(ctx, "publish nodes to agent", slog.Error(err)) + return + } + } + case "agentupdate": + agentUUID, err := uuid.ParseBytes(agentID) + if err != nil { + c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID))) + return + } + + decoder := json.NewDecoder(bytes.NewReader(nodeJSON)) + _, err = c.handleAgentUpdate(agentUUID, decoder) + if err != nil { + c.log.Error(ctx, "handle agent update", slog.Error(err)) + return + } + default: + c.log.Error(ctx, "unknown peer event", slog.F("name", string(eventType))) + } +} + +// format: <coordinator id>|callmemaybe|<recipient id>|<node json> +func (c *haCoordinator) formatCallMeMaybe(recipient uuid.UUID, nodes []*agpl.Node) ([]byte, error) { + buf := bytes.Buffer{} + + buf.WriteString(c.id.String() + "|") + buf.WriteString("callmemaybe|") + buf.WriteString(recipient.String() + "|") + err := json.NewEncoder(&buf).Encode(nodes) + if err != nil { + return nil, xerrors.Errorf("encode node: %w", err) + } + + return buf.Bytes(), nil +} + +// format: <coordinator id>|agenthello|<agent id>| +func (c *haCoordinator) formatAgentHello(id uuid.UUID) ([]byte, error) { + buf := bytes.Buffer{} + + buf.WriteString(c.id.String() + "|") + buf.WriteString("agenthello|") + buf.WriteString(id.String() + "|") + + return buf.Bytes(), nil +} + +// format: <coordinator id>|clienthello|<agent id>| +func (c *haCoordinator) formatClientHello(id uuid.UUID) ([]byte, error) { + buf := bytes.Buffer{} + + buf.WriteString(c.id.String() + "|") + buf.WriteString("clienthello|") + buf.WriteString(id.String() + "|") + + return buf.Bytes(), nil +} + +// format: <coordinator id>|agentupdate|<agent id>|<node json> +func (c *haCoordinator) formatAgentUpdate(id uuid.UUID, node *agpl.Node) ([]byte, error) { + buf := bytes.Buffer{} + + buf.WriteString(c.id.String() + "|") + buf.WriteString("agentupdate|") + buf.WriteString(id.String() + "|") + err := json.NewEncoder(&buf).Encode(node) + if err != nil { + return nil, xerrors.Errorf("encode node: %w", err) + } + + return buf.Bytes(), nil +} diff --git a/enterprise/tailnet/coordinator_test.go b/enterprise/tailnet/coordinator_test.go new file mode 100644 index 0000000000000..86cee94dbdf5b --- /dev/null +++ 
b/enterprise/tailnet/coordinator_test.go @@ -0,0 +1,261 @@ +package tailnet_test + +import ( + "net" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbtestutil" + "github.com/coder/coder/enterprise/tailnet" + agpl "github.com/coder/coder/tailnet" + "github.com/coder/coder/testutil" +) + +func TestCoordinatorSingle(t *testing.T) { + t.Parallel() + t.Run("ClientWithoutAgent", func(t *testing.T) { + t.Parallel() + coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory()) + require.NoError(t, err) + defer coordinator.Close() + + client, server := net.Pipe() + sendNode, errChan := agpl.ServeCoordinator(client, func(node []*agpl.Node) error { + return nil + }) + id := uuid.New() + closeChan := make(chan struct{}) + go func() { + err := coordinator.ServeClient(server, id, uuid.New()) + assert.NoError(t, err) + close(closeChan) + }() + sendNode(&agpl.Node{}) + require.Eventually(t, func() bool { + return coordinator.Node(id) != nil + }, testutil.WaitShort, testutil.IntervalFast) + + err = client.Close() + require.NoError(t, err) + <-errChan + <-closeChan + }) + + t.Run("AgentWithoutClients", func(t *testing.T) { + t.Parallel() + coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory()) + require.NoError(t, err) + defer coordinator.Close() + + client, server := net.Pipe() + sendNode, errChan := agpl.ServeCoordinator(client, func(node []*agpl.Node) error { + return nil + }) + id := uuid.New() + closeChan := make(chan struct{}) + go func() { + err := coordinator.ServeAgent(server, id) + assert.NoError(t, err) + close(closeChan) + }() + sendNode(&agpl.Node{}) + require.Eventually(t, func() bool { + return coordinator.Node(id) != nil + }, testutil.WaitShort, testutil.IntervalFast) + err = client.Close() + require.NoError(t, err) + <-errChan + <-closeChan + }) + + t.Run("AgentWithClient", func(t *testing.T) { + t.Parallel() + + coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory()) + require.NoError(t, err) + defer coordinator.Close() + + agentWS, agentServerWS := net.Pipe() + defer agentWS.Close() + agentNodeChan := make(chan []*agpl.Node) + sendAgentNode, agentErrChan := agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error { + agentNodeChan <- nodes + return nil + }) + agentID := uuid.New() + closeAgentChan := make(chan struct{}) + go func() { + err := coordinator.ServeAgent(agentServerWS, agentID) + assert.NoError(t, err) + close(closeAgentChan) + }() + sendAgentNode(&agpl.Node{}) + require.Eventually(t, func() bool { + return coordinator.Node(agentID) != nil + }, testutil.WaitShort, testutil.IntervalFast) + + clientWS, clientServerWS := net.Pipe() + defer clientWS.Close() + defer clientServerWS.Close() + clientNodeChan := make(chan []*agpl.Node) + sendClientNode, clientErrChan := agpl.ServeCoordinator(clientWS, func(nodes []*agpl.Node) error { + clientNodeChan <- nodes + return nil + }) + clientID := uuid.New() + closeClientChan := make(chan struct{}) + go func() { + err := coordinator.ServeClient(clientServerWS, clientID, agentID) + assert.NoError(t, err) + close(closeClientChan) + }() + agentNodes := <-clientNodeChan + require.Len(t, agentNodes, 1) + sendClientNode(&agpl.Node{}) + clientNodes := <-agentNodeChan + require.Len(t, clientNodes, 1) + + // Ensure an update to the agent 
node reaches the client! + sendAgentNode(&agpl.Node{}) + agentNodes = <-clientNodeChan + require.Len(t, agentNodes, 1) + + // Close the agent WebSocket so a new one can connect. + err = agentWS.Close() + require.NoError(t, err) + <-agentErrChan + <-closeAgentChan + + // Create a new agent connection. This is to simulate a reconnect! + agentWS, agentServerWS = net.Pipe() + defer agentWS.Close() + agentNodeChan = make(chan []*agpl.Node) + _, agentErrChan = agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error { + agentNodeChan <- nodes + return nil + }) + closeAgentChan = make(chan struct{}) + go func() { + err := coordinator.ServeAgent(agentServerWS, agentID) + assert.NoError(t, err) + close(closeAgentChan) + }() + // Ensure the existing listening client sends its node immediately! + clientNodes = <-agentNodeChan + require.Len(t, clientNodes, 1) + + err = agentWS.Close() + require.NoError(t, err) + <-agentErrChan + <-closeAgentChan + + err = clientWS.Close() + require.NoError(t, err) + <-clientErrChan + <-closeClientChan + }) +} + +func TestCoordinatorHA(t *testing.T) { + t.Parallel() + + t.Run("AgentWithClient", func(t *testing.T) { + t.Parallel() + + _, pubsub := dbtestutil.NewDB(t) + + coordinator1, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub) + require.NoError(t, err) + defer coordinator1.Close() + + agentWS, agentServerWS := net.Pipe() + defer agentWS.Close() + agentNodeChan := make(chan []*agpl.Node) + sendAgentNode, agentErrChan := agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error { + agentNodeChan <- nodes + return nil + }) + agentID := uuid.New() + closeAgentChan := make(chan struct{}) + go func() { + err := coordinator1.ServeAgent(agentServerWS, agentID) + assert.NoError(t, err) + close(closeAgentChan) + }() + sendAgentNode(&agpl.Node{}) + require.Eventually(t, func() bool { + return coordinator1.Node(agentID) != nil + }, testutil.WaitShort, testutil.IntervalFast) + + coordinator2, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub) + require.NoError(t, err) + defer coordinator2.Close() + + clientWS, clientServerWS := net.Pipe() + defer clientWS.Close() + defer clientServerWS.Close() + clientNodeChan := make(chan []*agpl.Node) + sendClientNode, clientErrChan := agpl.ServeCoordinator(clientWS, func(nodes []*agpl.Node) error { + clientNodeChan <- nodes + return nil + }) + clientID := uuid.New() + closeClientChan := make(chan struct{}) + go func() { + err := coordinator2.ServeClient(clientServerWS, clientID, agentID) + assert.NoError(t, err) + close(closeClientChan) + }() + agentNodes := <-clientNodeChan + require.Len(t, agentNodes, 1) + sendClientNode(&agpl.Node{}) + clientNodes := <-agentNodeChan + require.Len(t, clientNodes, 1) + + // Ensure an update to the agent node reaches the client! + sendAgentNode(&agpl.Node{}) + agentNodes = <-clientNodeChan + require.Len(t, agentNodes, 1) + + // Close the agent WebSocket so a new one can connect. + require.NoError(t, agentWS.Close()) + require.NoError(t, agentServerWS.Close()) + <-agentErrChan + <-closeAgentChan + + // Create a new agent connection. This is to simulate a reconnect! 
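+ // On reconnect, the coordinator must immediately replay the latest
+ // client nodes to the new agent socket, which the assertions below verify.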
+ agentWS, agentServerWS = net.Pipe() + defer agentWS.Close() + agentNodeChan = make(chan []*agpl.Node) + _, agentErrChan = agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error { + agentNodeChan <- nodes + return nil + }) + closeAgentChan = make(chan struct{}) + go func() { + err := coordinator1.ServeAgent(agentServerWS, agentID) + assert.NoError(t, err) + close(closeAgentChan) + }() + // Ensure the existing listening client sends its node immediately! + clientNodes = <-agentNodeChan + require.Len(t, clientNodes, 1) + + err = agentWS.Close() + require.NoError(t, err) + <-agentErrChan + <-closeAgentChan + + err = clientWS.Close() + require.NoError(t, err) + <-clientErrChan + <-closeClientChan + }) +} diff --git a/go.mod b/go.mod index 9834e27e5f39c..195a09ae2b8fd 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,7 @@ replace github.com/tcnksm/go-httpstat => github.com/kylecarbs/go-httpstat v0.0.0 // There are a few minor changes we make to Tailscale that we're slowly upstreaming. Compare here: // https://github.com/tailscale/tailscale/compare/main...coder:tailscale:main -replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20220926024748-50f068456c6c +replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20221015033036-5861cbbf7bf5 // Switch to our fork that imports fixes from http://github.com/tailscale/ssh. // See: https://github.com/coder/coder/issues/3371 diff --git a/go.sum b/go.sum index 13fdc5724f6b6..b80c0d4173a5f 100644 --- a/go.sum +++ b/go.sum @@ -351,8 +351,8 @@ github.com/coder/retry v1.3.0 h1:5lAAwt/2Cm6lVmnfBY7sOMXcBOwcwJhmV5QGSELIVWY= github.com/coder/retry v1.3.0/go.mod h1:tXuRgZgWjUnU5LZPT4lJh4ew2elUhexhlnXzrJWdyFY= github.com/coder/ssh v0.0.0-20220811105153-fcea99919338 h1:tN5GKFT68YLVzJoA8AHuiMNJ0qlhoD3pGN3JY9gxSko= github.com/coder/ssh v0.0.0-20220811105153-fcea99919338/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914= -github.com/coder/tailscale v1.1.1-0.20220926024748-50f068456c6c h1:xa6lr5Pj87Is26tgpzwBsEGKL7aVz7/fRGgY9QIbf3E= -github.com/coder/tailscale v1.1.1-0.20220926024748-50f068456c6c/go.mod h1:5amxy08qijEa8bcTW2SeIy4MIqcmd7LMsuOxqOlj2Ak= +github.com/coder/tailscale v1.1.1-0.20221015033036-5861cbbf7bf5 h1:WVH6e/qK3Wpl0wbmpORD2oQ1qLJborF3fsFHyO1ps0Y= +github.com/coder/tailscale v1.1.1-0.20221015033036-5861cbbf7bf5/go.mod h1:5amxy08qijEa8bcTW2SeIy4MIqcmd7LMsuOxqOlj2Ak= github.com/containerd/aufs v0.0.0-20200908144142-dab0cbea06f4/go.mod h1:nukgQABAEopAHvB6j7cnP5zJ+/3aVcE7hCYqvIwAHyE= github.com/containerd/aufs v0.0.0-20201003224125-76a6863f2989/go.mod h1:AkGGQs9NM2vtYHaUen+NljV0/baGCAPELGm2q9ZXpWU= github.com/containerd/aufs v0.0.0-20210316121734-20793ff83c97/go.mod h1:kL5kd6KM5TzQjR79jljyi4olc1Vrx6XBlcyj3gNv2PU= diff --git a/helm/templates/coder.yaml b/helm/templates/coder.yaml index 45f3f6e29a32e..1165251fc885b 100644 --- a/helm/templates/coder.yaml +++ b/helm/templates/coder.yaml @@ -14,10 +14,7 @@ metadata: {{- include "coder.labels" . | nindent 4 }} annotations: {{ toYaml .Values.coder.annotations | nindent 4}} spec: - # NOTE: this is currently not used as coder v2 does not support high - # availability yet. - # replicas: {{ .Values.coder.replicaCount }} - replicas: 1 + replicas: {{ .Values.coder.replicaCount }} selector: matchLabels: {{- include "coder.selectorLabels" . | nindent 6 }} @@ -38,6 +35,13 @@ spec: env: - name: CODER_ADDRESS value: "0.0.0.0:{{ include "coder.port" . }}" + # Used for inter-pod communication with high-availability. 
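+ # The Kubernetes downward API injects this pod's IP so each replica can
+ # advertise a relay address that its peers inside the cluster can reach.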
+ - name: KUBE_POD_IP + valueFrom: + fieldRef: + fieldPath: status.podIP + - name: CODER_DERP_SERVER_RELAY_ADDRESS + value: "{{ include "coder.portName" . }}://$(KUBE_POD_IP):{{ include "coder.port" . }}" {{- include "coder.tlsEnv" . | nindent 12 }} {{- with .Values.coder.env -}} {{ toYaml . | nindent 12 }} diff --git a/helm/templates/service.yaml b/helm/templates/service.yaml index 28fe0e9f9aa8c..b9a7e9a2f0886 100644 --- a/helm/templates/service.yaml +++ b/helm/templates/service.yaml @@ -10,6 +10,7 @@ metadata: {{- toYaml .Values.coder.service.annotations | nindent 4 }} spec: type: {{ .Values.coder.service.type }} + sessionAffinity: ClientIP ports: - name: {{ include "coder.portName" . | quote }} port: {{ include "coder.servicePort" . }} diff --git a/helm/values.yaml b/helm/values.yaml index 30a21a8985d23..392a53c187492 100644 --- a/helm/values.yaml +++ b/helm/values.yaml @@ -1,9 +1,9 @@ # coder -- Primary configuration for `coder server`. coder: - # NOTE: this is currently not used as coder v2 does not support high - # availability yet. - # # coder.replicaCount -- The number of Kubernetes deployment replicas. - # replicaCount: 1 + # coder.replicaCount -- The number of Kubernetes deployment replicas. + # This should only be increased if High Availability is enabled. + # This is an Enterprise feature. Contact sales@coder.com. + replicaCount: 1 # coder.image -- The image to use for Coder. image: diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 2e60a88b8469c..fb12571fd91ae 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -28,6 +28,7 @@ export const defaultEntitlements = (): TypesGen.Entitlements => { return { features: features, has_license: false, + errors: [], warnings: [], experimental: false, trial: false, diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 5347613e77e86..a4b2cf83a9581 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -274,6 +274,7 @@ export interface DeploymentFlags { readonly derp_server_region_code: StringFlag readonly derp_server_region_name: StringFlag readonly derp_server_stun_address: StringArrayFlag + readonly derp_server_relay_address: StringFlag readonly derp_config_url: StringFlag readonly derp_config_path: StringFlag readonly prom_enabled: BoolFlag @@ -337,6 +338,7 @@ export interface DurationFlag { export interface Entitlements { readonly features: Record<string, Feature> readonly warnings: string[] + readonly errors: string[] readonly has_license: boolean readonly experimental: boolean readonly trial: boolean @@ -528,6 +530,17 @@ export interface PutExtendWorkspaceRequest { readonly deadline: string } +// From codersdk/replicas.go +export interface Replica { + readonly id: string + readonly hostname: string + readonly created_at: string + readonly relay_address: string + readonly region_id: number + readonly error: string + readonly database_latency: number +} + // From codersdk/error.go export interface Response { readonly message: string diff --git a/site/src/components/LicenseBanner/LicenseBanner.tsx b/site/src/components/LicenseBanner/LicenseBanner.tsx index 8532bfca2ecbe..7ecfc2a2a2fac 100644 --- a/site/src/components/LicenseBanner/LicenseBanner.tsx +++ b/site/src/components/LicenseBanner/LicenseBanner.tsx @@ -8,15 +8,15 @@ export const LicenseBanner: React.FC = () => { const [entitlementsState, entitlementsSend] = useActor( xServices.entitlementsXService, ) - const { warnings } = entitlementsState.context.entitlements + const { errors, warnings } = 
entitlementsState.context.entitlements /** Gets license data on app mount because LicenseBanner is mounted in App */ useEffect(() => { entitlementsSend("GET_ENTITLEMENTS") }, [entitlementsSend]) - if (warnings.length > 0) { - return + if (errors.length > 0 || warnings.length > 0) { + return } else { return null } diff --git a/site/src/components/LicenseBanner/LicenseBannerView.stories.tsx b/site/src/components/LicenseBanner/LicenseBannerView.stories.tsx index c37653eff7bd5..c7ee69c261e38 100644 --- a/site/src/components/LicenseBanner/LicenseBannerView.stories.tsx +++ b/site/src/components/LicenseBanner/LicenseBannerView.stories.tsx @@ -12,13 +12,23 @@ const Template: Story = (args) => ( export const OneWarning = Template.bind({}) OneWarning.args = { + errors: [], warnings: ["You have exceeded the number of seats in your license."], } export const TwoWarnings = Template.bind({}) TwoWarnings.args = { + errors: [], warnings: [ "You have exceeded the number of seats in your license.", "You are flying too close to the sun.", ], } + +export const OneError = Template.bind({}) +OneError.args = { + errors: [ + "You have multiple replicas but high availability is an Enterprise feature. You will be unable to connect to workspaces.", + ], + warnings: [], +} diff --git a/site/src/components/LicenseBanner/LicenseBannerView.tsx b/site/src/components/LicenseBanner/LicenseBannerView.tsx index 49276b1f0d5ed..792bc191a0a2a 100644 --- a/site/src/components/LicenseBanner/LicenseBannerView.tsx +++ b/site/src/components/LicenseBanner/LicenseBannerView.tsx @@ -2,47 +2,56 @@ import { makeStyles } from "@material-ui/core/styles" import { Expander } from "components/Expander/Expander" import { Pill } from "components/Pill/Pill" import { useState } from "react" +import { colors } from "theme/colors" export const Language = { licenseIssue: "License Issue", licenseIssues: (num: number): string => `${num} License Issues`, - upgrade: "Contact us to upgrade your license.", + upgrade: "Contact sales@coder.com.", exceeded: "It looks like you've exceeded some limits of your license.", lessDetails: "Less", moreDetails: "More", } export interface LicenseBannerViewProps { + errors: string[] warnings: string[] } export const LicenseBannerView: React.FC = ({ + errors, warnings, }) => { const styles = useStyles() const [showDetails, setShowDetails] = useState(false) - if (warnings.length === 1) { + const isError = errors.length > 0 + const messages = [...errors, ...warnings] + const type = isError ? "error" : "warning" + + if (messages.length === 1) { return ( -
-      <div className={styles.container}>
-        <Pill text={Language.licenseIssue} type="warning" lightBorder />
-        <span className={styles.text}>{warnings[0]}</span>
-        &nbsp;
-        <a href="mailto:sales@coder.com" className={styles.link}>
-          {Language.upgrade}
-        </a>
-      </div>
+      <div className={`${styles.container} ${type}`}>
+        <Pill text={Language.licenseIssue} type={type} lightBorder />
+        <div className={styles.leftContent}>
+          {messages[0]}
+          &nbsp;
+          <a href="mailto:sales@coder.com" className={styles.link}>
+            {Language.upgrade}
+          </a>
+        </div>
+      </div>
     )
   } else {
     return (
-      <div className={styles.container}>
-        <Pill text={Language.licenseIssues(warnings.length)} type="warning" lightBorder />
-        <div className={styles.leftContent}>
-          {Language.exceeded}
+      <div className={`${styles.container} ${type}`}>
+        <Pill
+          text={Language.licenseIssues(messages.length)}
+          type={type}
+          lightBorder
+        />
+        <div className={styles.leftContent}>
+          <div>
+            {Language.exceeded}
            &nbsp;
            <a href="mailto:sales@coder.com" className={styles.link}>
              {Language.upgrade}
            </a>
+          </div>
           <Expander expanded={showDetails} setExpanded={setShowDetails}>
             <ul className={styles.list}>
-              {warnings.map((warning) => (
-                <li className={styles.listItem} key={`${warning}`}>
-                  {warning}
+              {messages.map((message) => (
+                <li className={styles.listItem} key={`${message}`}>
+                  {message}
                 </li>
               ))}
             </ul>
           </Expander>
         </div>
       </div>
     )
   }
 }
@@ -67,14 +76,18 @@ const useStyles = makeStyles((theme) => ({
   container: {
     padding: theme.spacing(1.5),
     backgroundColor: theme.palette.warning.main,
+    display: "flex",
+    alignItems: "center",
+
+    "&.error": {
+      backgroundColor: colors.red[12],
+    },
   },
   flex: {
-    display: "flex",
+    display: "column",
   },
   leftContent: {
     marginRight: theme.spacing(1),
-  },
-  text: {
     marginLeft: theme.spacing(1),
   },
   link: {
@@ -83,9 +96,10 @@ const useStyles = makeStyles((theme) => ({
     fontWeight: "bold",
   },
   list: {
-    margin: theme.spacing(1.5),
+    padding: theme.spacing(1),
+    margin: 0,
   },
   listItem: {
-    margin: theme.spacing(1),
+    margin: theme.spacing(0.5),
   },
 }))
diff --git a/site/src/testHelpers/entities.ts b/site/src/testHelpers/entities.ts
index 59abb4a913222..8d0358bc5826e 100644
--- a/site/src/testHelpers/entities.ts
+++ b/site/src/testHelpers/entities.ts
@@ -821,6 +821,7 @@ export const makeMockApiError = ({
 })

 export const MockEntitlements: TypesGen.Entitlements = {
+  errors: [],
   warnings: [],
   has_license: false,
   features: {},
@@ -829,6 +830,7 @@
 }

 export const MockEntitlementsWithWarnings: TypesGen.Entitlements = {
+  errors: [],
   warnings: ["You are over your active user limit.", "And another thing."],
   has_license: true,
   experimental: false,
@@ -852,6 +854,7 @@
 }

 export const MockEntitlementsWithAuditLog: TypesGen.Entitlements = {
+  errors: [],
   warnings: [],
   has_license: true,
   experimental: false,
diff --git a/site/src/xServices/entitlements/entitlementsXService.ts b/site/src/xServices/entitlements/entitlementsXService.ts
index 83ed44d12052d..a1e8bb0d9b895 100644
--- a/site/src/xServices/entitlements/entitlementsXService.ts
+++ b/site/src/xServices/entitlements/entitlementsXService.ts
@@ -20,6 +20,7 @@ export type EntitlementsEvent =
   | { type: "HIDE_MOCK_BANNER" }

 const emptyEntitlements = {
+  errors: [],
   warnings: [],
   features: {},
   has_license: false,
diff --git a/tailnet/conn.go b/tailnet/conn.go
index 1b454d6346b97..e3af3786ec92f 100644
--- a/tailnet/conn.go
+++ b/tailnet/conn.go
@@ -48,7 +48,10 @@ type Options struct {
 	Addresses []netip.Prefix
 	DERPMap   *tailcfg.DERPMap

-	Logger slog.Logger
+	// BlockEndpoints specifies whether P2P endpoints are blocked.
+	// If so, only DERPs can establish connections.
+	BlockEndpoints bool
+	Logger         slog.Logger
 }

 // NewConn constructs a new Wireguard server that will accept connections from the addresses provided.
@@ -175,6 +178,7 @@ func NewConn(options *Options) (*Conn, error) {
 	wireguardEngine.SetFilter(filter.New(netMap.PacketFilter, localIPs, logIPs, nil, Logger(options.Logger.Named("packet-filter"))))
 	dialContext, dialCancel := context.WithCancel(context.Background())
 	server := &Conn{
+		blockEndpoints: options.BlockEndpoints,
 		dialContext:    dialContext,
 		dialCancel:     dialCancel,
 		closed:         make(chan struct{}),
@@ -240,11 +244,12 @@ func IP() netip.Addr {

 // Conn is an actively listening Wireguard connection.
 type Conn struct {
-	dialContext context.Context
-	dialCancel  context.CancelFunc
-	mutex       sync.Mutex
-	closed      chan struct{}
-	logger      slog.Logger
+	dialContext    context.Context
+	dialCancel     context.CancelFunc
+	mutex          sync.Mutex
+	closed         chan struct{}
+	logger         slog.Logger
+	blockEndpoints bool

 	dialer    *tsdial.Dialer
 	tunDevice *tstun.Wrapper
@@ -323,6 +328,8 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
 		delete(c.peerMap, peer.ID)
 	}
 	for _, node := range nodes {
+		c.logger.Debug(context.Background(), "adding node", slog.F("node", node))
+
 		peerStatus, ok := status.Peer[node.Key]
 		peerNode := &tailcfg.Node{
 			ID:         node.ID,
@@ -339,6 +346,13 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
 			// reason. TODO: @kylecarbs debug this!
 			KeepAlive: ok && peerStatus.Active,
 		}
+		// If no preferred DERP is provided, don't set an IP!
+		if node.PreferredDERP == 0 {
+			peerNode.DERP = ""
+		}
+		if c.blockEndpoints {
+			peerNode.Endpoints = nil
+		}
 		c.peerMap[node.ID] = peerNode
 	}
 	c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap))
@@ -421,6 +435,7 @@ func (c *Conn) sendNode() {
 	}
 	node := &Node{
 		ID:            c.netMap.SelfNode.ID,
+		AsOf:          c.lastStatus,
 		Key:           c.netMap.SelfNode.Key,
 		Addresses:     c.netMap.SelfNode.Addresses,
 		AllowedIPs:    c.netMap.SelfNode.AllowedIPs,
@@ -429,6 +444,9 @@ func (c *Conn) sendNode() {
 		PreferredDERP: c.lastPreferredDERP,
 		DERPLatency:   c.lastDERPLatency,
 	}
+	if c.blockEndpoints {
+		node.Endpoints = nil
+	}
 	nodeCallback := c.nodeCallback
 	if nodeCallback == nil {
 		return
diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go
index ee696b0925e3c..4216bbc624d48 100644
--- a/tailnet/coordinator.go
+++ b/tailnet/coordinator.go
@@ -7,6 +7,7 @@ import (
 	"net"
 	"net/netip"
 	"sync"
+	"time"

 	"github.com/google/uuid"
 	"golang.org/x/xerrors"
@@ -14,10 +15,30 @@ import (
 	"tailscale.com/types/key"
 )

+// Coordinator exchanges nodes with agents to establish connections.
+// ┌──────────────────┐   ┌────────────────────┐   ┌───────────────────┐   ┌──────────────────┐
+// │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│
+// └──────────────────┘   └────────────────────┘   └───────────────────┘   └──────────────────┘
+// Coordinators have different guarantees for HA support.
+type Coordinator interface {
+	// Node returns an in-memory node by ID.
+	Node(id uuid.UUID) *Node
+	// ServeClient accepts a WebSocket connection that wants to connect to an agent
+	// with the specified ID.
+	ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error
+	// ServeAgent accepts a WebSocket connection to an agent that listens to
+	// incoming connections and publishes node updates.
+	ServeAgent(conn net.Conn, id uuid.UUID) error
+	// Close closes the coordinator.
+	Close() error
+}
+
 // Node represents a node in the network.
 type Node struct {
 	// ID is used to identify the connection.
 	ID tailcfg.NodeID `json:"id"`
+	// AsOf is the time the node was created.
+	AsOf time.Time `json:"as_of"`
 	// Key is the Wireguard public key of the node.
 	Key key.NodePublic `json:"key"`
 	// DiscoKey is used for discovery messages over DERP to establish peer-to-peer connections.
@@ -75,48 +96,59 @@ func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func(
 	}, errChan
 }

-// NewCoordinator constructs a new in-memory connection coordinator.
-func NewCoordinator() *Coordinator {
-	return &Coordinator{
+// NewCoordinator constructs a new in-memory connection coordinator. This
+// coordinator is incompatible with multiple Coder replicas as all node data is
+// in-memory.
+func NewCoordinator() Coordinator {
+	return &coordinator{
+		closed:                   false,
 		nodes:                    map[uuid.UUID]*Node{},
 		agentSockets:             map[uuid.UUID]net.Conn{},
 		agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{},
 	}
 }

-// Coordinator exchanges nodes with agents to establish connections.
+// coordinator exchanges nodes with agents to establish connections entirely in-memory.
+// The Enterprise implementation provides this for high-availability.
 // ┌──────────────────┐   ┌────────────────────┐   ┌───────────────────┐   ┌──────────────────┐
 // │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│
 // └──────────────────┘   └────────────────────┘   └───────────────────┘   └──────────────────┘
 // This coordinator is incompatible with multiple Coder
 // replicas as all node data is in-memory.
-type Coordinator struct {
-	mutex sync.Mutex
+type coordinator struct {
+	mutex  sync.Mutex
+	closed bool

-	// Maps agent and connection IDs to a node.
+	// nodes maps agent and connection IDs their respective node.
 	nodes map[uuid.UUID]*Node
-	// Maps agent ID to an open socket.
+	// agentSockets maps agent IDs to their open websocket.
 	agentSockets map[uuid.UUID]net.Conn
-	// Maps agent ID to connection ID for sending
-	// new node data as it comes in!
+	// agentToConnectionSockets maps agent IDs to connection IDs of conns that
+	// are subscribed to updates for that agent.
 	agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn
 }

 // Node returns an in-memory node by ID.
-func (c *Coordinator) Node(id uuid.UUID) *Node {
+// If the node does not exist, nil is returned.
+func (c *coordinator) Node(id uuid.UUID) *Node {
 	c.mutex.Lock()
 	defer c.mutex.Unlock()
-	node := c.nodes[id]
-	return node
+	return c.nodes[id]
 }

-// ServeClient accepts a WebSocket connection that wants to
-// connect to an agent with the specified ID.
-func (c *Coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
+// ServeClient accepts a WebSocket connection that wants to connect to an agent
+// with the specified ID.
+func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
 	c.mutex.Lock()
+	if c.closed {
+		c.mutex.Unlock()
+		return xerrors.New("coordinator is closed")
+	}
+
 	// When a new connection is requested, we update it with the latest
 	// node of the agent. This allows the connection to establish.
 	node, ok := c.nodes[agent]
+	c.mutex.Unlock()
 	if ok {
 		data, err := json.Marshal([]*Node{node})
 		if err != nil {
@@ -129,6 +161,7 @@ func (c *Coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID)
 			return xerrors.Errorf("write nodes: %w", err)
 		}
 	}
+	c.mutex.Lock()
 	connectionSockets, ok := c.agentToConnectionSockets[agent]
 	if !ok {
 		connectionSockets = map[uuid.UUID]net.Conn{}
@@ -156,47 +189,62 @@ func (c *Coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID)

 	decoder := json.NewDecoder(conn)
 	for {
-		var node Node
-		err := decoder.Decode(&node)
-		if errors.Is(err, io.EOF) {
-			return nil
-		}
+		err := c.handleNextClientMessage(id, agent, decoder)
 		if err != nil {
-			return xerrors.Errorf("read json: %w", err)
-		}
-		c.mutex.Lock()
-		// Update the node of this client in our in-memory map.
-		// If an agent entirely shuts down and reconnects, it
-		// needs to be aware of all clients attempting to
-		// establish connections.
-		c.nodes[id] = &node
-		agentSocket, ok := c.agentSockets[agent]
-		if !ok {
-			c.mutex.Unlock()
-			continue
+			if errors.Is(err, io.EOF) {
+				return nil
+			}
+			return xerrors.Errorf("handle next client message: %w", err)
 		}
+	}
+}
+
+func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error {
+	var node Node
+	err := decoder.Decode(&node)
+	if err != nil {
+		return xerrors.Errorf("read json: %w", err)
+	}
+
+	c.mutex.Lock()
+	// Update the node of this client in our in-memory map. If an agent entirely
+	// shuts down and reconnects, it needs to be aware of all clients attempting
+	// to establish connections.
+	c.nodes[id] = &node
+
+	agentSocket, ok := c.agentSockets[agent]
+	if !ok {
 		c.mutex.Unlock()
-		// Write the new node from this client to the actively
-		// connected agent.
-		data, err := json.Marshal([]*Node{&node})
-		if err != nil {
-			c.mutex.Unlock()
-			return xerrors.Errorf("marshal nodes: %w", err)
-		}
-		_, err = agentSocket.Write(data)
+		return nil
+	}
+	c.mutex.Unlock()
+
+	// Write the new node from this client to the actively connected agent.
+	data, err := json.Marshal([]*Node{&node})
+	if err != nil {
+		return xerrors.Errorf("marshal nodes: %w", err)
+	}
+
+	_, err = agentSocket.Write(data)
+	if err != nil {
 		if errors.Is(err, io.EOF) {
 			return nil
 		}
-		if err != nil {
-			return xerrors.Errorf("write json: %w", err)
-		}
+		return xerrors.Errorf("write json: %w", err)
 	}
+
+	return nil
 }

 // ServeAgent accepts a WebSocket connection to an agent that
 // listens to incoming connections and publishes node updates.
-func (c *Coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
+func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
 	c.mutex.Lock()
+	if c.closed {
+		c.mutex.Unlock()
+		return xerrors.New("coordinator is closed")
+	}
+
 	sockets, ok := c.agentToConnectionSockets[id]
 	if ok {
 		// Publish all nodes that want to connect to the
@@ -209,16 +257,16 @@ func (c *Coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
 			}
 			nodes = append(nodes, node)
 		}
+		c.mutex.Unlock()
 		data, err := json.Marshal(nodes)
 		if err != nil {
-			c.mutex.Unlock()
 			return xerrors.Errorf("marshal json: %w", err)
 		}
 		_, err = conn.Write(data)
 		if err != nil {
-			c.mutex.Unlock()
 			return xerrors.Errorf("write nodes: %w", err)
 		}
+		c.mutex.Lock()
 	}

 	// If an old agent socket is connected, we close it
@@ -239,36 +287,84 @@ func (c *Coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {

 	decoder := json.NewDecoder(conn)
 	for {
-		var node Node
-		err := decoder.Decode(&node)
-		if errors.Is(err, io.EOF) {
-			return nil
-		}
-		if err != nil {
-			return xerrors.Errorf("read json: %w", err)
-		}
-		c.mutex.Lock()
-		c.nodes[id] = &node
-		connectionSockets, ok := c.agentToConnectionSockets[id]
-		if !ok {
-			c.mutex.Unlock()
-			continue
-		}
-		data, err := json.Marshal([]*Node{&node})
+		err := c.handleNextAgentMessage(id, decoder)
 		if err != nil {
-			return xerrors.Errorf("marshal nodes: %w", err)
+			if errors.Is(err, io.EOF) {
+				return nil
+			}
+			return xerrors.Errorf("handle next agent message: %w", err)
 		}
-		// Publish the new node to every listening socket.
-		var wg sync.WaitGroup
-		wg.Add(len(connectionSockets))
-		for _, connectionSocket := range connectionSockets {
-			connectionSocket := connectionSocket
+	}
+}
+
+func (c *coordinator) handleNextAgentMessage(id uuid.UUID, decoder *json.Decoder) error {
+	var node Node
+	err := decoder.Decode(&node)
+	if err != nil {
+		return xerrors.Errorf("read json: %w", err)
+	}
+
+	c.mutex.Lock()
+	c.nodes[id] = &node
+	connectionSockets, ok := c.agentToConnectionSockets[id]
+	if !ok {
+		c.mutex.Unlock()
+		return nil
+	}
+	data, err := json.Marshal([]*Node{&node})
+	if err != nil {
+		return xerrors.Errorf("marshal nodes: %w", err)
+	}
+
+	// Publish the new node to every listening socket.
+	var wg sync.WaitGroup
+	wg.Add(len(connectionSockets))
+	for _, connectionSocket := range connectionSockets {
+		connectionSocket := connectionSocket
+		go func() {
+			_ = connectionSocket.SetWriteDeadline(time.Now().Add(5 * time.Second))
+			_, _ = connectionSocket.Write(data)
+			wg.Done()
+		}()
+	}
+
+	c.mutex.Unlock()
+	wg.Wait()
+	return nil
+}
+
+// Close closes all of the open connections in the coordinator and stops the
+// coordinator from accepting new connections.
+func (c *coordinator) Close() error {
+	c.mutex.Lock()
+	if c.closed {
+		c.mutex.Unlock()
+		return nil
+	}
+	c.closed = true
+	c.mutex.Unlock()
+
+	wg := sync.WaitGroup{}
+
+	wg.Add(len(c.agentSockets))
+	for _, socket := range c.agentSockets {
+		socket := socket
+		go func() {
+			_ = socket.Close()
+			wg.Done()
+		}()
+	}
+
+	for _, connMap := range c.agentToConnectionSockets {
+		wg.Add(len(connMap))
+		for _, socket := range connMap {
+			socket := socket
 			go func() {
-				_, _ = connectionSocket.Write(data)
+				_ = socket.Close()
 				wg.Done()
 			}()
 		}
-		c.mutex.Unlock()
-		wg.Wait()
 	}
+
+	wg.Wait()
+	return nil
 }
diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go
index f3fdab88d5ef8..a4a020deadf93 100644
--- a/tailnet/coordinator_test.go
+++ b/tailnet/coordinator_test.go
@@ -32,8 +32,8 @@ func TestCoordinator(t *testing.T) {
 		require.Eventually(t, func() bool {
 			return coordinator.Node(id) != nil
 		}, testutil.WaitShort, testutil.IntervalFast)
-		err := client.Close()
-		require.NoError(t, err)
+		require.NoError(t, client.Close())
+		require.NoError(t, server.Close())
 		<-errChan
 		<-closeChan
 	})
diff --git a/testutil/certificate.go b/testutil/certificate.go
new file mode 100644
index 0000000000000..1edc975746958
--- /dev/null
+++ b/testutil/certificate.go
@@ -0,0 +1,53 @@
+package testutil
+
+import (
+	"bytes"
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
+	"crypto/tls"
+	"crypto/x509"
+	"crypto/x509/pkix"
+	"encoding/pem"
+	"math/big"
+	"net"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+func GenerateTLSCertificate(t testing.TB, commonName string) tls.Certificate {
+	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+	require.NoError(t, err)
+	template := x509.Certificate{
+		SerialNumber: big.NewInt(1),
+		Subject: pkix.Name{
+			Organization: []string{"Acme Co"},
+			CommonName:   commonName,
+		},
+		DNSNames:  []string{commonName},
+		NotBefore: time.Now(),
+		NotAfter:  time.Now().Add(time.Hour * 24 * 180),
+
+		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
+		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+		BasicConstraintsValid: true,
+		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
+	}
+
+	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
+	require.NoError(t, err)
+	var certFile bytes.Buffer
+	require.NoError(t, err)
+	_, err = certFile.Write(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
+	require.NoError(t, err)
+	privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
+	require.NoError(t, err)
+	var keyFile bytes.Buffer
+	err = pem.Encode(&keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyBytes})
+	require.NoError(t, err)
+	cert, err := tls.X509KeyPair(certFile.Bytes(), keyFile.Bytes())
+	require.NoError(t, err)
+	return cert
+}
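Reviewer note: the `BlockEndpoints` option added to `tailnet.Options` only takes effect in `UpdateNodes` and `sendNode`, where P2P endpoints are stripped from node payloads in both directions. A minimal sketch of how a caller might use it; this is not code from this patch, and `newDERPOnlyConn` is a hypothetical helper whose `derpMap` and `logger` values are assumed to come from the surrounding application:

```go
package main

import (
	"net/netip"

	"cdr.dev/slog"
	"tailscale.com/tailcfg"

	"github.com/coder/coder/tailnet"
)

// newDERPOnlyConn builds a tailnet connection that can never negotiate
// direct peer-to-peer paths: with BlockEndpoints set, endpoints are
// removed from incoming and outgoing node updates, so all traffic must
// flow through DERP relays.
func newDERPOnlyConn(derpMap *tailcfg.DERPMap, logger slog.Logger) (*tailnet.Conn, error) {
	return tailnet.NewConn(&tailnet.Options{
		Addresses:      []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
		DERPMap:        derpMap,
		BlockEndpoints: true,
		Logger:         logger,
	})
}
```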
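The coordinator's wire protocol is visible in the handlers above: agents and clients each write bare JSON-encoded `Node` objects, and the coordinator fans updates back out as JSON arrays (`[]*Node`). A rough sketch of driving the in-memory implementation directly over `net.Pipe`, mirroring what `coordinator_test.go` does; the polling loop, IDs, and printed output are illustrative only, not part of this patch:

```go
package main

import (
	"encoding/json"
	"fmt"
	"net"
	"time"

	"github.com/google/uuid"

	"github.com/coder/coder/tailnet"
)

func main() {
	coordinator := tailnet.NewCoordinator()
	defer coordinator.Close()

	agentID, clientID := uuid.New(), uuid.New()

	// The coordinator serves one end of the pipe; our "agent" holds the other.
	agentConn, agentServe := net.Pipe()
	go func() { _ = coordinator.ServeAgent(agentServe, agentID) }()

	// The agent publishes its node as a single JSON object, which the
	// coordinator records in its in-memory map.
	_ = json.NewEncoder(agentConn).Encode(tailnet.Node{})
	for coordinator.Node(agentID) == nil {
		time.Sleep(time.Millisecond)
	}

	// A client that connects afterwards is immediately sent the agent's
	// latest node, framed as a JSON array.
	clientConn, clientServe := net.Pipe()
	go func() { _ = coordinator.ServeClient(clientServe, clientID, agentID) }()

	var nodes []*tailnet.Node
	_ = json.NewDecoder(clientConn).Decode(&nodes)
	fmt.Printf("client received %d node(s) from the coordinator\n", len(nodes))
}
```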
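And a small usage sketch for the new `testutil.GenerateTLSCertificate` helper. The test below is hypothetical, not part of this patch; it only shows that the returned `tls.Certificate` is immediately usable for a listener bound to 127.0.0.1, which the certificate's IP SANs cover:

```go
package testutil_test

import (
	"crypto/tls"
	"testing"

	"github.com/coder/coder/testutil"
)

func TestGenerateTLSCertificate(t *testing.T) {
	t.Parallel()

	// The helper returns a ready-to-use tls.Certificate whose SANs cover
	// the given common name plus 127.0.0.1.
	cert := testutil.GenerateTLSCertificate(t, "localhost")

	listener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
		Certificates: []tls.Certificate{cert},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer listener.Close()
	t.Logf("TLS listener running at %s", listener.Addr())
}
```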