Skip to content

Commit e9b7463

Browse files
committed
Add workspace route proxying endpoint
- Makes the workspace conn cache concurrency-safe - Reduces unnecessary open checks in `peer.Channel` - Fixes the use of a temporary context when dialing a workspace agent
1 parent 4d8b257 commit e9b7463

File tree

13 files changed

+247
-182
lines changed

13 files changed

+247
-182
lines changed

agent/conn.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func (c *Conn) DialContext(ctx context.Context, network string, addr string) (ne
102102
var res dialResponse
103103
err = dec.Decode(&res)
104104
if err != nil {
105-
return nil, xerrors.Errorf("failed to decode initial packet: %w", err)
105+
return nil, xerrors.Errorf("decode agent dial response: %w", err)
106106
}
107107
if res.Error != "" {
108108
_ = channel.Close()

coderd/coderd.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ import (
1515
"golang.org/x/xerrors"
1616
"google.golang.org/api/idtoken"
1717

18-
"github.com/go-chi/cors"
19-
2018
sdktrace "go.opentelemetry.io/otel/sdk/trace"
2119

2220
"cdr.dev/slog"
@@ -97,7 +95,7 @@ func New(options *Options) *API {
9795
tracing.HTTPMW(api.TracerProvider, "coderd.http"),
9896
)
9997

100-
r.Route("/{user}/{workspaceagent}/{application}", func(r chi.Router) {
98+
r.Route("/@{user}/{workspaceagent}/apps/{application}", func(r chi.Router) {
10199
r.Use(
102100
httpmw.RateLimitPerMinute(options.APIRateLimit),
103101
apiKeyMiddleware,
@@ -327,9 +325,6 @@ func New(options *Options) *API {
327325
r.Put("/extend", api.putExtendWorkspace)
328326
})
329327
})
330-
r.Route("/wildcardauth", func(r chi.Router) {
331-
r.Use(cors.Handler(cors.Options{}))
332-
})
333328
r.Route("/workspacebuilds/{workspacebuild}", func(r chi.Router) {
334329
r.Use(
335330
apiKeyMiddleware,
@@ -357,10 +352,12 @@ type API struct {
357352
}
358353

359354
// Close waits for all WebSocket connections to drain before returning.
360-
func (api *API) Close() {
355+
func (api *API) Close() error {
361356
api.websocketWaitMutex.Lock()
362357
api.websocketWaitGroup.Wait()
363358
api.websocketWaitMutex.Unlock()
359+
360+
return api.workspaceAgentCache.Close()
364361
}
365362

366363
func debugLogRequest(log slog.Logger) func(http.Handler) http.Handler {

coderd/coderdtest/coderdtest.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, *coderd.API)
172172
cancelFunc()
173173
_ = turnServer.Close()
174174
srv.Close()
175-
coderAPI.Close()
175+
_ = coderAPI.Close()
176176
})
177177

178178
return codersdk.New(serverURL), coderAPI

coderd/database/databasefake/databasefake.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,7 @@ func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, resourc
10611061
return workspaceAgents, nil
10621062
}
10631063

1064-
func (q *fakeQuerier) GetWorkspaceAppByAgentIDAndName(ctx context.Context, arg database.GetWorkspaceAppByAgentIDAndNameParams) (database.WorkspaceApp, error) {
1064+
func (q *fakeQuerier) GetWorkspaceAppByAgentIDAndName(_ context.Context, arg database.GetWorkspaceAppByAgentIDAndNameParams) (database.WorkspaceApp, error) {
10651065
q.mutex.RLock()
10661066
defer q.mutex.RUnlock()
10671067

coderd/workspaceagents.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package coderd
22

33
import (
4+
"context"
45
"database/sql"
56
"encoding/json"
67
"fmt"
@@ -382,12 +383,12 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
382383
}()
383384
// Accept text connections, because it's more developer friendly.
384385
wsNetConn := websocket.NetConn(r.Context(), conn, websocket.MessageBinary)
385-
agentConn, err := api.dialWorkspaceAgent(r, workspaceAgent.ID)
386+
agentConn, release, err := api.workspaceAgentCache.Acquire(r, workspaceAgent.ID)
386387
if err != nil {
387388
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
388389
return
389390
}
390-
defer agentConn.Close()
391+
defer release()
391392
ptNetConn, err := agentConn.ReconnectingPTY(reconnect.String(), uint16(height), uint16(width), "")
392393
if err != nil {
393394
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
@@ -404,8 +405,9 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
404405
// dialWorkspaceAgent connects to a workspace agent by ID.
405406
func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.Conn, error) {
406407
client, server := provisionersdk.TransportPipe()
408+
ctx, cancelFunc := context.WithCancel(context.Background())
407409
go func() {
408-
_ = peerbroker.ProxyListen(r.Context(), server, peerbroker.ProxyOptions{
410+
_ = peerbroker.ProxyListen(ctx, server, peerbroker.ProxyOptions{
409411
ChannelID: agentID.String(),
410412
Logger: api.Logger.Named("peerbroker-proxy-dial"),
411413
Pubsub: api.Pubsub,
@@ -415,8 +417,9 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
415417
}()
416418

417419
peerClient := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
418-
stream, err := peerClient.NegotiateConnection(r.Context())
420+
stream, err := peerClient.NegotiateConnection(ctx)
419421
if err != nil {
422+
cancelFunc()
420423
return nil, xerrors.Errorf("negotiate: %w", err)
421424
}
422425
options := &peer.ConnOptions{
@@ -452,8 +455,13 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
452455
}))
453456
peerConn, err := peerbroker.Dial(stream, append(api.ICEServers, turnconn.Proxy), options)
454457
if err != nil {
458+
cancelFunc()
455459
return nil, xerrors.Errorf("dial: %w", err)
456460
}
461+
go func() {
462+
<-peerConn.Closed()
463+
cancelFunc()
464+
}()
457465
return &agent.Conn{
458466
Negotiator: peerClient,
459467
Conn: peerConn,

coderd/workspaceapps.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,22 @@ func (api *API) workspaceAppsProxyPath(rw http.ResponseWriter, r *http.Request)
123123
defer release()
124124

125125
proxy := httputil.NewSingleHostReverseProxy(appURL)
126+
// Write the error directly using our format!
127+
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
128+
httpapi.Write(w, http.StatusBadGateway, httpapi.Response{
129+
Message: err.Error(),
130+
})
131+
}
126132
proxy.Transport = conn.HTTPTransport()
127-
r.URL.Path = chi.URLParam(r, "*")
133+
path := chi.URLParam(r, "*")
134+
if !strings.HasSuffix(r.URL.Path, "/") && path == "" {
135+
// Web applications typically request paths relative to the
136+
// root URL. This allows for routing behind a proxy or subpath.
137+
// See https://github.com/coder/code-server/issues/241 for examples.
138+
r.URL.Path += "/"
139+
http.Redirect(rw, r, r.URL.String(), http.StatusTemporaryRedirect)
140+
return
141+
}
142+
r.URL.Path = path
128143
proxy.ServeHTTP(rw, r)
129144
}

coderd/workspaceapps_test.go

Lines changed: 68 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -21,69 +21,81 @@ import (
2121

2222
func TestWorkspaceAppsProxyPath(t *testing.T) {
2323
t.Parallel()
24-
t.Run("Proxies", func(t *testing.T) {
25-
t.Parallel()
26-
// #nosec
27-
ln, err := net.Listen("tcp", ":0")
28-
require.NoError(t, err)
29-
server := http.Server{
30-
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
31-
w.WriteHeader(http.StatusOK)
32-
}),
33-
}
34-
t.Cleanup(func() {
35-
_ = server.Close()
36-
_ = ln.Close()
37-
})
38-
go server.Serve(ln)
39-
tcpAddr, _ := ln.Addr().(*net.TCPAddr)
24+
// #nosec
25+
ln, err := net.Listen("tcp", ":0")
26+
require.NoError(t, err)
27+
server := http.Server{
28+
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
29+
w.WriteHeader(http.StatusOK)
30+
}),
31+
}
32+
t.Cleanup(func() {
33+
_ = server.Close()
34+
_ = ln.Close()
35+
})
36+
go server.Serve(ln)
37+
tcpAddr, _ := ln.Addr().(*net.TCPAddr)
4038

41-
client, coderAPI := coderdtest.NewWithAPI(t, nil)
42-
user := coderdtest.CreateFirstUser(t, client)
43-
daemonCloser := coderdtest.NewProvisionerDaemon(t, coderAPI)
44-
authToken := uuid.NewString()
45-
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
46-
Parse: echo.ParseComplete,
47-
ProvisionDryRun: echo.ProvisionComplete,
48-
Provision: []*proto.Provision_Response{{
49-
Type: &proto.Provision_Response_Complete{
50-
Complete: &proto.Provision_Complete{
51-
Resources: []*proto.Resource{{
52-
Name: "example",
53-
Type: "aws_instance",
54-
Agents: []*proto.Agent{{
55-
Id: uuid.NewString(),
56-
Auth: &proto.Agent_Token{
57-
Token: authToken,
58-
},
59-
Apps: []*proto.App{{
60-
Name: "example",
61-
Url: fmt.Sprintf("http://127.0.0.1:%d", tcpAddr.Port),
62-
}},
39+
client, coderAPI := coderdtest.NewWithAPI(t, nil)
40+
user := coderdtest.CreateFirstUser(t, client)
41+
daemonCloser := coderdtest.NewProvisionerDaemon(t, coderAPI)
42+
authToken := uuid.NewString()
43+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
44+
Parse: echo.ParseComplete,
45+
ProvisionDryRun: echo.ProvisionComplete,
46+
Provision: []*proto.Provision_Response{{
47+
Type: &proto.Provision_Response_Complete{
48+
Complete: &proto.Provision_Complete{
49+
Resources: []*proto.Resource{{
50+
Name: "example",
51+
Type: "aws_instance",
52+
Agents: []*proto.Agent{{
53+
Id: uuid.NewString(),
54+
Auth: &proto.Agent_Token{
55+
Token: authToken,
56+
},
57+
Apps: []*proto.App{{
58+
Name: "example",
59+
Url: fmt.Sprintf("http://127.0.0.1:%d", tcpAddr.Port),
6360
}},
6461
}},
65-
},
62+
}},
6663
},
67-
}},
68-
})
69-
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
70-
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
71-
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
72-
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
73-
daemonCloser.Close()
64+
},
65+
}},
66+
})
67+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
68+
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
69+
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
70+
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
71+
daemonCloser.Close()
7472

75-
agentClient := codersdk.New(client.URL)
76-
agentClient.SessionToken = authToken
77-
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
78-
Logger: slogtest.Make(t, nil),
79-
})
80-
t.Cleanup(func() {
81-
_ = agentCloser.Close()
82-
})
83-
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
73+
agentClient := codersdk.New(client.URL)
74+
agentClient.SessionToken = authToken
75+
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
76+
Logger: slogtest.Make(t, nil),
77+
})
78+
t.Cleanup(func() {
79+
_ = agentCloser.Close()
80+
})
81+
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
82+
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
83+
return http.ErrUseLastResponse
84+
}
8485

85-
resp, err := client.Request(context.Background(), http.MethodGet, "/me/"+workspace.Name+"/example", nil)
86+
t.Run("RedirectsWithSlash", func(t *testing.T) {
87+
t.Parallel()
88+
resp, err := client.Request(context.Background(), http.MethodGet, "/@me/"+workspace.Name+"/apps/example", nil)
89+
require.NoError(t, err)
90+
defer resp.Body.Close()
91+
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
92+
})
93+
94+
t.Run("Proxies", func(t *testing.T) {
95+
t.Parallel()
96+
resp, err := client.Request(context.Background(), http.MethodGet, "/@me/"+workspace.Name+"/apps/example/", nil)
8697
require.NoError(t, err)
98+
defer resp.Body.Close()
8799
body, err := io.ReadAll(resp.Body)
88100
require.NoError(t, err)
89101
require.Equal(t, "", string(body))

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy