diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index cfcdea0404683..90396235993cb 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -143,16 +143,49 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { defer api.websocketWaitGroup.Done() workspaceAgent := httpmw.WorkspaceAgent(r) - conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{ - CompressionMode: websocket.CompressionDisabled, - }) + resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID) if err != nil { httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ - Message: fmt.Sprintf("accept websocket: %s", err), + Message: fmt.Sprintf("get workspace resource: %s", err), }) return } - resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID) + + build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID) + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("get workspace build job: %s", err), + }) + return + } + // Ensure the resource is still valid! + // We only accept agents for resources on the latest build. + ensureLatestBuild := func() error { + latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID) + if err != nil { + return err + } + if build.ID != latestBuild.ID { + return xerrors.New("build is outdated") + } + return nil + } + + err = ensureLatestBuild() + if err != nil { + api.Logger.Debug(r.Context(), "agent tried to connect from non-latest built", + slog.F("resource", resource), + slog.F("agent", workspaceAgent), + ) + httpapi.Write(rw, http.StatusForbidden, httpapi.Response{ + Message: fmt.Sprintf("ensure latest build: %s", err), + }) + return + } + + conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{ + CompressionMode: websocket.CompressionDisabled, + }) if err != nil { httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ Message: fmt.Sprintf("accept websocket: %s", err), @@ -163,6 +196,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { defer func() { _ = conn.Close(websocket.StatusNormalClosure, "") }() + config := yamux.DefaultConfig() config.LogOutput = io.Discard session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config) @@ -170,6 +204,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) return } + closer, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)), peerbroker.ProxyOptions{ ChannelID: workspaceAgent.ID.String(), Pubsub: api.Pubsub, @@ -180,6 +215,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { return } defer closer.Close() + firstConnectedAt := workspaceAgent.FirstConnectedAt if !firstConnectedAt.Valid { firstConnectedAt = sql.NullTime{ @@ -204,23 +240,6 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { } return nil } - build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID) - if err != nil { - _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) - return - } - // Ensure the resource is still valid! - // We only accept agents for resources on the latest build. - ensureLatestBuild := func() error { - latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID) - if err != nil { - return err - } - if build.ID != latestBuild.ID { - return xerrors.New("build is outdated") - } - return nil - } defer func() { disconnectedAt = sql.NullTime{ @@ -230,11 +249,6 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { _ = updateConnectionTimes() }() - err = ensureLatestBuild() - if err != nil { - _ = conn.Close(websocket.StatusGoingAway, "") - return - } err = updateConnectionTimes() if err != nil { _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index b14ac43bac4ce..360abb3431156 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -68,52 +68,130 @@ func TestWorkspaceAgent(t *testing.T) { func TestWorkspaceAgentListen(t *testing.T) { t.Parallel() - client, coderAPI := coderdtest.NewWithAPI(t, nil) - user := coderdtest.CreateFirstUser(t, client) - daemonCloser := coderdtest.NewProvisionerDaemon(t, coderAPI) - authToken := uuid.NewString() - version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ - Parse: echo.ParseComplete, - ProvisionDryRun: echo.ProvisionComplete, - Provision: []*proto.Provision_Response{{ - Type: &proto.Provision_Response_Complete{ - Complete: &proto.Provision_Complete{ - Resources: []*proto.Resource{{ - Name: "example", - Type: "aws_instance", - Agents: []*proto.Agent{{ - Id: uuid.NewString(), - Auth: &proto.Agent_Token{ - Token: authToken, - }, + + t.Run("Connect", func(t *testing.T) { + t.Parallel() + + client, coderAPI := coderdtest.NewWithAPI(t, nil) + user := coderdtest.CreateFirstUser(t, client) + daemonCloser := coderdtest.NewProvisionerDaemon(t, coderAPI) + authToken := uuid.NewString() + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionDryRun: echo.ProvisionComplete, + Provision: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + Agents: []*proto.Agent{{ + Id: uuid.NewString(), + Auth: &proto.Agent_Token{ + Token: authToken, + }, + }}, }}, - }}, + }, }, - }, - }}, - }) - template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - coderdtest.AwaitTemplateVersionJob(t, client, version.ID) - workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) - coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - daemonCloser.Close() + }}, + }) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) + coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + daemonCloser.Close() - agentClient := codersdk.New(client.URL) - agentClient.SessionToken = authToken - agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{ - Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug), - }) - t.Cleanup(func() { - _ = agentCloser.Close() + agentClient := codersdk.New(client.URL) + agentClient.SessionToken = authToken + agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{ + Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug), + }) + t.Cleanup(func() { + _ = agentCloser.Close() + }) + resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) + conn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.Close() + }) + _, err = conn.Ping() + require.NoError(t, err) }) - resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) - conn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil) - require.NoError(t, err) - t.Cleanup(func() { - _ = conn.Close() + + t.Run("FailNonLatestBuild", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client, coderAPI := coderdtest.NewWithAPI(t, nil) + user := coderdtest.CreateFirstUser(t, client) + daemonCloser := coderdtest.NewProvisionerDaemon(t, coderAPI) + defer daemonCloser.Close() + + authToken := uuid.NewString() + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionDryRun: echo.ProvisionComplete, + Provision: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + Agents: []*proto.Agent{{ + Id: uuid.NewString(), + Auth: &proto.Agent_Token{ + Token: authToken, + }, + }}, + }}, + }, + }, + }}, + }) + + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) + coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + + version = coderdtest.UpdateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionDryRun: echo.ProvisionComplete, + Provision: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + Agents: []*proto.Agent{{ + Id: uuid.NewString(), + Auth: &proto.Agent_Token{ + Token: uuid.NewString(), + }, + }}, + }}, + }, + }, + }}, + }, template.ID) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + + stopBuild, err := client.CreateWorkspaceBuild(context.Background(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{ + TemplateVersionID: version.ID, + Transition: codersdk.WorkspaceTransitionStop, + }) + require.NoError(t, err) + coderdtest.AwaitWorkspaceBuildJob(t, client, stopBuild.ID) + + agentClient := codersdk.New(client.URL) + agentClient.SessionToken = authToken + + _, _, err = agentClient.ListenWorkspaceAgent(ctx, slogtest.Make(t, nil)) + require.Error(t, err) + require.ErrorContains(t, err, "build is outdated") }) - _, err = conn.Ping() - require.NoError(t, err) } func TestWorkspaceAgentTURN(t *testing.T) {
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: