Skip to content

Commit 8cbd44a

Browse files
committed
chore: populate connectionlog count using a separate query
1 parent 482a5d3 commit 8cbd44a

File tree

14 files changed

+671
-5
lines changed

14 files changed

+671
-5
lines changed

coderd/database/dbauthz/dbauthz.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,6 +1323,21 @@ func (q *querier) CleanTailnetTunnels(ctx context.Context) error {
13231323
return q.db.CleanTailnetTunnels(ctx)
13241324
}
13251325

1326+
func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) {
1327+
// Just like the actual query, shortcut if the user is an owner.
1328+
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
1329+
if err == nil {
1330+
return q.db.CountConnectionLogs(ctx, arg)
1331+
}
1332+
1333+
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceConnectionLog.Type)
1334+
if err != nil {
1335+
return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
1336+
}
1337+
1338+
return q.db.CountAuthorizedConnectionLogs(ctx, arg, prep)
1339+
}
1340+
13261341
func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
13271342
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil {
13281343
return nil, err
@@ -5301,3 +5316,7 @@ func (q *querier) GetAuthorizedAuditLogsOffset(ctx context.Context, arg database
53015316
func (q *querier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, _ rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) {
53025317
return q.GetConnectionLogsOffset(ctx, arg)
53035318
}
5319+
5320+
func (q *querier) CountAuthorizedConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams, _ rbac.PreparedAuthorized) (int64, error) {
5321+
return q.CountConnectionLogs(ctx, arg)
5322+
}

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,42 @@ func (s *MethodTestSuite) TestConnectionLogs() {
391391
LimitOpt: 10,
392392
}, emptyPreparedAuthorized{}).Asserts(rbac.ResourceConnectionLog, policy.ActionRead)
393393
}))
394+
s.Run("CountConnectionLogs", s.Subtest(func(db database.Store, check *expects) {
395+
ws := createWorkspace(s.T(), db)
396+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
397+
Type: database.ConnectionTypeSsh,
398+
WorkspaceID: ws.ID,
399+
OrganizationID: ws.OrganizationID,
400+
WorkspaceOwnerID: ws.OwnerID,
401+
})
402+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
403+
Type: database.ConnectionTypeSsh,
404+
WorkspaceID: ws.ID,
405+
OrganizationID: ws.OrganizationID,
406+
WorkspaceOwnerID: ws.OwnerID,
407+
})
408+
check.Args(database.CountConnectionLogsParams{}).Asserts(
409+
rbac.ResourceConnectionLog, policy.ActionRead,
410+
).WithNotAuthorized("nil")
411+
}))
412+
s.Run("CountAuthorizedConnectionLogs", s.Subtest(func(db database.Store, check *expects) {
413+
ws := createWorkspace(s.T(), db)
414+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
415+
Type: database.ConnectionTypeSsh,
416+
WorkspaceID: ws.ID,
417+
OrganizationID: ws.OrganizationID,
418+
WorkspaceOwnerID: ws.OwnerID,
419+
})
420+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
421+
Type: database.ConnectionTypeSsh,
422+
WorkspaceID: ws.ID,
423+
OrganizationID: ws.OrganizationID,
424+
WorkspaceOwnerID: ws.OwnerID,
425+
})
426+
check.Args(database.CountConnectionLogsParams{}, emptyPreparedAuthorized{}).Asserts(
427+
rbac.ResourceConnectionLog, policy.ActionRead,
428+
)
429+
}))
394430
}
395431

396432
func (s *MethodTestSuite) TestFile() {

coderd/database/dbauthz/setup_test.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
271271

272272
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
273273
// any case where the error is nil and the response is an empty slice.
274-
if err != nil || !hasEmptySliceResponse(resp) {
274+
if err != nil || !hasEmptyResponse(resp) {
275275
// Expect the default error
276276
if testCase.notAuthorizedExpect == "" {
277277
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
@@ -297,7 +297,7 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
297297

298298
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
299299
// any case where the error is nil and the response is an empty slice.
300-
if err != nil || !hasEmptySliceResponse(resp) {
300+
if err != nil || !hasEmptyResponse(resp) {
301301
if testCase.cancelledCtxExpect == "" {
302302
s.Errorf(err, "method should an error with cancellation")
303303
s.ErrorIsf(err, context.Canceled, "error should match context.Canceled")
@@ -308,13 +308,20 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
308308
})
309309
}
310310

311-
func hasEmptySliceResponse(values []reflect.Value) bool {
311+
func hasEmptyResponse(values []reflect.Value) bool {
312312
for _, r := range values {
313313
if r.Kind() == reflect.Slice || r.Kind() == reflect.Array {
314314
if r.Len() == 0 {
315315
return true
316316
}
317317
}
318+
319+
// Special case for int64, as it's the return type for count queries.
320+
if r.Kind() == reflect.Int64 {
321+
if r.Int() == 0 {
322+
return true
323+
}
324+
}
318325
}
319326
return false
320327
}

coderd/database/dbmem/dbmem.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1780,6 +1780,10 @@ func (*FakeQuerier) CleanTailnetTunnels(context.Context) error {
17801780
return ErrUnimplemented
17811781
}
17821782

1783+
func (q *FakeQuerier) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) {
1784+
return q.CountAuthorizedConnectionLogs(ctx, arg, nil)
1785+
}
1786+
17831787
func (q *FakeQuerier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
17841788
return nil, ErrUnimplemented
17851789
}
@@ -14160,3 +14164,97 @@ func (q *FakeQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg
1416014164

1416114165
return logs, nil
1416214166
}
14167+
14168+
func (q *FakeQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error) {
14169+
if err := validateDatabaseType(arg); err != nil {
14170+
return 0, err
14171+
}
14172+
14173+
// Call this to match the same function calls as the SQL implementation.
14174+
// It functionally does nothing for filtering.
14175+
if prepared != nil {
14176+
_, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
14177+
VariableConverter: regosql.ConnectionLogConverter(),
14178+
})
14179+
if err != nil {
14180+
return 0, err
14181+
}
14182+
}
14183+
14184+
q.mutex.RLock()
14185+
defer q.mutex.RUnlock()
14186+
14187+
var count int64
14188+
14189+
for _, clog := range q.connectionLogs {
14190+
if arg.OrganizationID != uuid.Nil && clog.OrganizationID != arg.OrganizationID {
14191+
continue
14192+
}
14193+
if arg.WorkspaceOwner != "" {
14194+
workspaceOwner, err := q.getUserByIDNoLock(clog.WorkspaceOwnerID)
14195+
if err == nil && !strings.EqualFold(arg.WorkspaceOwner, workspaceOwner.Username) {
14196+
continue
14197+
}
14198+
}
14199+
if arg.Type != "" && string(clog.Type) != arg.Type {
14200+
continue
14201+
}
14202+
if arg.UserID != uuid.Nil && (!clog.UserID.Valid || clog.UserID.UUID != arg.UserID) {
14203+
continue
14204+
}
14205+
if arg.Username != "" {
14206+
if !clog.UserID.Valid {
14207+
continue
14208+
}
14209+
user, err := q.getUserByIDNoLock(clog.UserID.UUID)
14210+
if err != nil || user.Username != arg.Username {
14211+
continue
14212+
}
14213+
}
14214+
if arg.Email != "" {
14215+
if !clog.UserID.Valid {
14216+
continue
14217+
}
14218+
user, err := q.getUserByIDNoLock(clog.UserID.UUID)
14219+
if err != nil || user.Email != arg.Email {
14220+
continue
14221+
}
14222+
}
14223+
if !arg.StartedAfter.IsZero() && clog.Time.Before(arg.StartedAfter) {
14224+
continue
14225+
}
14226+
if !arg.StartedBefore.IsZero() && clog.Time.After(arg.StartedBefore) {
14227+
continue
14228+
}
14229+
if !arg.ClosedAfter.IsZero() && (!clog.CloseTime.Valid || clog.CloseTime.Time.Before(arg.ClosedAfter)) {
14230+
continue
14231+
}
14232+
if !arg.ClosedBefore.IsZero() && (!clog.CloseTime.Valid || clog.CloseTime.Time.After(arg.ClosedBefore)) {
14233+
continue
14234+
}
14235+
if arg.WorkspaceID != uuid.Nil && clog.WorkspaceID != arg.WorkspaceID {
14236+
continue
14237+
}
14238+
if arg.ConnectionID != uuid.Nil && (!clog.ConnectionID.Valid || clog.ConnectionID.UUID != arg.ConnectionID) {
14239+
continue
14240+
}
14241+
if arg.Status != "" {
14242+
if clog.Type == database.ConnectionTypeWorkspaceApp ||
14243+
clog.Type == database.ConnectionTypePortForwarding {
14244+
continue
14245+
}
14246+
isConnected := !clog.CloseTime.Valid
14247+
if (arg.Status == "connected" && !isConnected) || (arg.Status == "disconnected" && isConnected) {
14248+
continue
14249+
}
14250+
}
14251+
14252+
if prepared != nil && prepared.Authorize(ctx, clog.RBACObject()) != nil {
14253+
continue
14254+
}
14255+
14256+
count++
14257+
}
14258+
14259+
return count, nil
14260+
}

coderd/database/dbmetrics/querymetrics.go

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 30 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/modelqueries.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,7 @@ func (q *sqlQuerier) GetAuthorizedAuditLogsOffset(ctx context.Context, arg GetAu
566566

567567
type connectionLogQuerier interface {
568568
GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error)
569+
CountAuthorizedConnectionLogs(ctx context.Context, arg CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error)
569570
}
570571

571572
func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error) {
@@ -653,6 +654,53 @@ func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg
653654
return items, nil
654655
}
655656

657+
func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error) {
658+
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
659+
VariableConverter: regosql.ConnectionLogConverter(),
660+
})
661+
if err != nil {
662+
return 0, xerrors.Errorf("compile authorized filter: %w", err)
663+
}
664+
filtered, err := insertAuthorizedFilter(countConnectionLogs, fmt.Sprintf(" AND %s", authorizedFilter))
665+
if err != nil {
666+
return 0, xerrors.Errorf("insert authorized filter: %w", err)
667+
}
668+
669+
query := fmt.Sprintf("-- name: CountAuthorizedConnectionLogs :one\n%s", filtered)
670+
rows, err := q.db.QueryContext(ctx, query,
671+
arg.OrganizationID,
672+
arg.WorkspaceOwner,
673+
arg.Type,
674+
arg.UserID,
675+
arg.Username,
676+
arg.Email,
677+
arg.StartedAfter,
678+
arg.StartedBefore,
679+
arg.ClosedAfter,
680+
arg.ClosedBefore,
681+
arg.WorkspaceID,
682+
arg.ConnectionID,
683+
arg.Status,
684+
)
685+
if err != nil {
686+
return 0, err
687+
}
688+
defer rows.Close()
689+
var count int64
690+
for rows.Next() {
691+
if err := rows.Scan(&count); err != nil {
692+
return 0, err
693+
}
694+
}
695+
if err := rows.Close(); err != nil {
696+
return 0, err
697+
}
698+
if err := rows.Err(); err != nil {
699+
return 0, err
700+
}
701+
return count, nil
702+
}
703+
656704
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
657705
if !strings.Contains(query, authorizedQueryPlaceholder) {
658706
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")

coderd/database/modelqueries_internal_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package database
22

33
import (
4+
"regexp"
5+
"strings"
46
"testing"
57
"time"
68

@@ -54,3 +56,35 @@ func TestWorkspaceTableConvert(t *testing.T) {
5456
"'workspace.WorkspaceTable()' is not missing at least 1 field when converting to 'WorkspaceTable'. "+
5557
"To resolve this, go to the 'func (w Workspace) WorkspaceTable()' and ensure all fields are converted.")
5658
}
59+
60+
func TestConnectionLogsQueryConsistency(t *testing.T) {
61+
t.Parallel()
62+
63+
getWhereClause := extractWhereClause(getConnectionLogsOffset)
64+
require.NotEmpty(t, getWhereClause, "getConnectionLogsOffset query should have a WHERE clause")
65+
66+
countWhereClause := extractWhereClause(countConnectionLogs)
67+
require.NotEmpty(t, countWhereClause, "countConnectionLogs query should have a WHERE clause")
68+
69+
require.Equal(t, getWhereClause, countWhereClause, "getConnectionLogsOffset and countConnectionLogs queries should have the same WHERE clause")
70+
}
71+
72+
// extractWhereClause extracts the WHERE clause from a SQL query string
73+
func extractWhereClause(query string) string {
74+
// Find WHERE and get everything after it
75+
wherePattern := regexp.MustCompile(`(?is)WHERE\s+(.*)`)
76+
whereMatches := wherePattern.FindStringSubmatch(query)
77+
if len(whereMatches) < 2 {
78+
return ""
79+
}
80+
81+
whereClause := whereMatches[1]
82+
83+
// Remove ORDER BY, LIMIT, OFFSET clauses from the end
84+
whereClause = regexp.MustCompile(`(?is)\s+(ORDER BY|LIMIT|OFFSET).*$`).ReplaceAllString(whereClause, "")
85+
86+
// Remove SQL comments
87+
whereClause = regexp.MustCompile(`(?m)--.*$`).ReplaceAllString(whereClause, "")
88+
89+
return strings.TrimSpace(whereClause)
90+
}

coderd/database/querier.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy