Skip to content

Commit f349edc

Browse files
refactor: create tasks in coderd instead of frontend (#19280)
Instead of creating tasks with a specialized call to `CreateWorkspace` on the frontend, we instead lift this to the backend and allow the frontend to simply call `CreateAITask`.
1 parent cda1a3a commit f349edc

File tree

14 files changed

+362
-4
lines changed

14 files changed

+362
-4
lines changed

coderd/aitasks.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
package coderd
22

33
import (
4+
"database/sql"
5+
"errors"
46
"fmt"
57
"net/http"
8+
"slices"
69
"strings"
710

811
"github.com/google/uuid"
912

13+
"github.com/coder/coder/v2/coderd/audit"
14+
"github.com/coder/coder/v2/coderd/database"
1015
"github.com/coder/coder/v2/coderd/httpapi"
16+
"github.com/coder/coder/v2/coderd/httpmw"
17+
"github.com/coder/coder/v2/coderd/rbac"
1118
"github.com/coder/coder/v2/codersdk"
1219
)
1320

@@ -61,3 +68,106 @@ func (api *API) aiTasksPrompts(rw http.ResponseWriter, r *http.Request) {
6168
Prompts: promptsByBuildID,
6269
})
6370
}
71+
72+
// This endpoint is experimental and not guaranteed to be stable, so we're not
73+
// generating public-facing documentation for it.
74+
func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
75+
var (
76+
ctx = r.Context()
77+
apiKey = httpmw.APIKey(r)
78+
auditor = api.Auditor.Load()
79+
mems = httpmw.OrganizationMembersParam(r)
80+
)
81+
82+
var req codersdk.CreateTaskRequest
83+
if !httpapi.Read(ctx, rw, r, &req) {
84+
return
85+
}
86+
87+
hasAITask, err := api.Database.GetTemplateVersionHasAITask(ctx, req.TemplateVersionID)
88+
if err != nil {
89+
if errors.Is(err, sql.ErrNoRows) || rbac.IsUnauthorizedError(err) {
90+
httpapi.ResourceNotFound(rw)
91+
return
92+
}
93+
94+
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
95+
Message: "Internal error fetching whether the template version has an AI task.",
96+
Detail: err.Error(),
97+
})
98+
return
99+
}
100+
if !hasAITask {
101+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
102+
Message: fmt.Sprintf(`Template does not have required parameter %q`, codersdk.AITaskPromptParameterName),
103+
})
104+
return
105+
}
106+
107+
createReq := codersdk.CreateWorkspaceRequest{
108+
Name: req.Name,
109+
TemplateVersionID: req.TemplateVersionID,
110+
TemplateVersionPresetID: req.TemplateVersionPresetID,
111+
RichParameterValues: []codersdk.WorkspaceBuildParameter{
112+
{Name: codersdk.AITaskPromptParameterName, Value: req.Prompt},
113+
},
114+
}
115+
116+
var owner workspaceOwner
117+
if mems.User != nil {
118+
// This user fetch is an optimization path for the most common case of creating a
119+
// task for 'Me'.
120+
//
121+
// This is also required to allow `owners` to create workspaces for users
122+
// that are not in an organization.
123+
owner = workspaceOwner{
124+
ID: mems.User.ID,
125+
Username: mems.User.Username,
126+
AvatarURL: mems.User.AvatarURL,
127+
}
128+
} else {
129+
// A task can still be created if the caller can read the organization
130+
// member. The organization is required, which can be sourced from the
131+
// template.
132+
//
133+
// TODO: This code gets called twice for each workspace build request.
134+
// This is inefficient and costs at most 2 extra RTTs to the DB.
135+
// This can be optimized. It exists as it is now for code simplicity.
136+
// The most common case is to create a workspace for 'Me'. Which does
137+
// not enter this code branch.
138+
template, ok := requestTemplate(ctx, rw, createReq, api.Database)
139+
if !ok {
140+
return
141+
}
142+
143+
// If the caller can find the organization membership in the same org
144+
// as the template, then they can continue.
145+
orgIndex := slices.IndexFunc(mems.Memberships, func(mem httpmw.OrganizationMember) bool {
146+
return mem.OrganizationID == template.OrganizationID
147+
})
148+
if orgIndex == -1 {
149+
httpapi.ResourceNotFound(rw)
150+
return
151+
}
152+
153+
member := mems.Memberships[orgIndex]
154+
owner = workspaceOwner{
155+
ID: member.UserID,
156+
Username: member.Username,
157+
AvatarURL: member.AvatarURL,
158+
}
159+
}
160+
161+
aReq, commitAudit := audit.InitRequest[database.WorkspaceTable](rw, &audit.RequestParams{
162+
Audit: *auditor,
163+
Log: api.Logger,
164+
Request: r,
165+
Action: database.AuditActionCreate,
166+
AdditionalFields: audit.AdditionalFields{
167+
WorkspaceOwner: owner.Username,
168+
},
169+
})
170+
171+
defer commitAudit()
172+
createWorkspace(ctx, aReq, apiKey.UserID, api, owner, createReq, rw, r)
173+
}

coderd/aitasks_test.go

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package coderd_test
22

33
import (
4+
"net/http"
45
"testing"
56

67
"github.com/google/uuid"
8+
"github.com/stretchr/testify/assert"
79
"github.com/stretchr/testify/require"
810

911
"github.com/coder/coder/v2/coderd/coderdtest"
@@ -139,3 +141,125 @@ func TestAITasksPrompts(t *testing.T) {
139141
require.Empty(t, prompts.Prompts)
140142
})
141143
}
144+
145+
func TestTaskCreate(t *testing.T) {
146+
t.Parallel()
147+
148+
t.Run("OK", func(t *testing.T) {
149+
t.Parallel()
150+
151+
var (
152+
ctx = testutil.Context(t, testutil.WaitShort)
153+
154+
taskName = "task-foo-bar-baz"
155+
taskPrompt = "Some task prompt"
156+
)
157+
158+
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
159+
user := coderdtest.CreateFirstUser(t, client)
160+
161+
// Given: A template with an "AI Prompt" parameter
162+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
163+
Parse: echo.ParseComplete,
164+
ProvisionApply: echo.ApplyComplete,
165+
ProvisionPlan: []*proto.Response{
166+
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
167+
Parameters: []*proto.RichParameter{{Name: "AI Prompt", Type: "string"}},
168+
HasAiTasks: true,
169+
}}},
170+
},
171+
})
172+
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
173+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
174+
175+
expClient := codersdk.NewExperimentalClient(client)
176+
177+
// When: We attempt to create a Task.
178+
workspace, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
179+
Name: taskName,
180+
TemplateVersionID: template.ActiveVersionID,
181+
Prompt: taskPrompt,
182+
})
183+
require.NoError(t, err)
184+
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
185+
186+
// Then: We expect a workspace to have been created.
187+
assert.Equal(t, taskName, workspace.Name)
188+
assert.Equal(t, template.ID, workspace.TemplateID)
189+
190+
// And: We expect it to have the "AI Prompt" parameter correctly set.
191+
parameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID)
192+
require.NoError(t, err)
193+
require.Len(t, parameters, 1)
194+
assert.Equal(t, codersdk.AITaskPromptParameterName, parameters[0].Name)
195+
assert.Equal(t, taskPrompt, parameters[0].Value)
196+
})
197+
198+
t.Run("FailsOnNonTaskTemplate", func(t *testing.T) {
199+
t.Parallel()
200+
201+
var (
202+
ctx = testutil.Context(t, testutil.WaitShort)
203+
204+
taskName = "task-foo-bar-baz"
205+
taskPrompt = "Some task prompt"
206+
)
207+
208+
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
209+
user := coderdtest.CreateFirstUser(t, client)
210+
211+
// Given: A template without an "AI Prompt" parameter
212+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
213+
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
214+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
215+
216+
expClient := codersdk.NewExperimentalClient(client)
217+
218+
// When: We attempt to create a Task.
219+
_, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
220+
Name: taskName,
221+
TemplateVersionID: template.ActiveVersionID,
222+
Prompt: taskPrompt,
223+
})
224+
225+
// Then: We expect it to fail.
226+
var sdkErr *codersdk.Error
227+
require.Error(t, err)
228+
require.ErrorAsf(t, err, &sdkErr, "error should be of type *codersdk.Error")
229+
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
230+
})
231+
232+
t.Run("FailsOnInvalidTemplate", func(t *testing.T) {
233+
t.Parallel()
234+
235+
var (
236+
ctx = testutil.Context(t, testutil.WaitShort)
237+
238+
taskName = "task-foo-bar-baz"
239+
taskPrompt = "Some task prompt"
240+
)
241+
242+
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
243+
user := coderdtest.CreateFirstUser(t, client)
244+
245+
// Given: A template
246+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
247+
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
248+
_ = coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
249+
250+
expClient := codersdk.NewExperimentalClient(client)
251+
252+
// When: We attempt to create a Task with an invalid template version ID.
253+
_, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
254+
Name: taskName,
255+
TemplateVersionID: uuid.New(),
256+
Prompt: taskPrompt,
257+
})
258+
259+
// Then: We expect it to fail.
260+
var sdkErr *codersdk.Error
261+
require.Error(t, err)
262+
require.ErrorAsf(t, err, &sdkErr, "error should be of type *codersdk.Error")
263+
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
264+
})
265+
}

coderd/coderd.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,15 @@ func New(options *Options) *API {
995995
r.Route("/aitasks", func(r chi.Router) {
996996
r.Get("/prompts", api.aiTasksPrompts)
997997
})
998+
r.Route("/tasks", func(r chi.Router) {
999+
r.Use(apiRateLimiter)
1000+
1001+
r.Route("/{user}", func(r chi.Router) {
1002+
r.Use(httpmw.ExtractOrganizationMembersParam(options.Database, api.HTTPAuth.Authorize))
1003+
1004+
r.Post("/", api.tasksCreate)
1005+
})
1006+
})
9981007
r.Route("/mcp", func(r chi.Router) {
9991008
r.Use(
10001009
httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2, codersdk.ExperimentMCPServerHTTP),

coderd/database/dbauthz/dbauthz.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2863,6 +2863,17 @@ func (q *querier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg
28632863
return tv, nil
28642864
}
28652865

2866+
func (q *querier) GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) {
2867+
// If we can successfully call `GetTemplateVersionByID`, then
2868+
// we know the actor has sufficient permissions to know if the
2869+
// template has an AI task.
2870+
if _, err := q.GetTemplateVersionByID(ctx, id); err != nil {
2871+
return false, err
2872+
}
2873+
2874+
return q.db.GetTemplateVersionHasAITask(ctx, id)
2875+
}
2876+
28662877
func (q *querier) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) {
28672878
// An actor can read template version parameters if they can read the related template.
28682879
tv, err := q.db.GetTemplateVersionByID(ctx, templateVersionID)

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,6 +1443,20 @@ func (s *MethodTestSuite) TestTemplate() {
14431443
})
14441444
check.Args(now.Add(-time.Hour)).Asserts(rbac.ResourceTemplate.All(), policy.ActionRead)
14451445
}))
1446+
s.Run("GetTemplateVersionHasAITask", s.Subtest(func(db database.Store, check *expects) {
1447+
o := dbgen.Organization(s.T(), db, database.Organization{})
1448+
u := dbgen.User(s.T(), db, database.User{})
1449+
t := dbgen.Template(s.T(), db, database.Template{
1450+
OrganizationID: o.ID,
1451+
CreatedBy: u.ID,
1452+
})
1453+
tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{
1454+
OrganizationID: o.ID,
1455+
TemplateID: uuid.NullUUID{UUID: t.ID, Valid: true},
1456+
CreatedBy: u.ID,
1457+
})
1458+
check.Args(tv.ID).Asserts(t, policy.ActionRead)
1459+
}))
14461460
s.Run("GetTemplatesWithFilter", s.Subtest(func(db database.Store, check *expects) {
14471461
o := dbgen.Organization(s.T(), db, database.Organization{})
14481462
u := dbgen.User(s.T(), db, database.User{})

coderd/database/dbmetrics/querymetrics.go

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/querier.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/templateversions.sql

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,10 @@ FROM
234234
WHERE
235235
template_versions.id IN (archived_versions.id)
236236
RETURNING template_versions.id;
237+
238+
-- name: GetTemplateVersionHasAITask :one
239+
SELECT EXISTS (
240+
SELECT 1
241+
FROM template_versions
242+
WHERE id = $1 AND has_ai_task = TRUE
243+
);

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy