From ff931dcc2155bab9066594554125fc688064a94a Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 27 Sep 2022 13:25:06 -0400 Subject: [PATCH 01/18] feat: Convert rego queries into SQL clauses --- coderd/rbac/query.go | 269 +++++++++++++++++++++++++++++ coderd/rbac/query_internal_test.go | 38 ++++ 2 files changed, 307 insertions(+) create mode 100644 coderd/rbac/query.go create mode 100644 coderd/rbac/query_internal_test.go diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go new file mode 100644 index 0000000000000..1dab38ecea429 --- /dev/null +++ b/coderd/rbac/query.go @@ -0,0 +1,269 @@ +package rbac + +import ( + "context" + "fmt" + "strings" + + "github.com/open-policy-agent/opa/ast" + + "golang.org/x/xerrors" + + "github.com/open-policy-agent/opa/rego" +) + +// Example python: https://github.com/open-policy-agent/contrib/tree/main/data_filter_example +// + +func Compile(ctx context.Context, partialQueries *rego.PartialQueries) (Expression, error) { + if len(partialQueries.Support) > 0 { + return nil, xerrors.Errorf("cannot convert support rules, expect 0 found %d", len(partialQueries.Support)) + } + + result := make([]Expression, 0, len(partialQueries.Queries)) + var builder strings.Builder + for i := range partialQueries.Queries { + query, err := processQuery(partialQueries.Queries[i]) + if err != nil { + return nil, err + } + result = append(result, query) + if i != 0 { + builder.WriteString("\n") + } + builder.WriteString(partialQueries.Queries[i].String()) + } + return ExpOr{ + Base: Base{ + Rego: builder.String(), + }, + Expressions: result, + }, nil +} + +func processQuery(query ast.Body) (Expression, error) { + expressions := make([]Expression, 0, len(query)) + for _, astExpr := range query { + expr, err := processExpression(astExpr) + if err != nil { + return nil, err + } + expressions = append(expressions, expr) + } + + return ExpAnd{ + Base: Base{ + Rego: query.String(), + }, + Expressions: expressions, + }, nil +} + +func processExpression(expr *ast.Expr) (Expression, error) { + if !expr.IsCall() { + return nil, xerrors.Errorf("invalid expression: function calls not supported") + } + + op := expr.Operator().String() + base := Base{Rego: op} + switch op { + case "neq", "eq", "equal": + terms, err := processTerms(2, expr.Operands()) + if err != nil { + return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err) + } + return &OpEqual{ + Base: base, + Terms: [2]Term{terms[0], terms[1]}, + Not: op == "neq", + }, nil + case "internal.member_2": + terms, err := processTerms(2, expr.Operands()) + if err != nil { + return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err) + } + return &OpInternalMember2{ + Base: base, + Terms: [2]Term{terms[0], terms[1]}, + }, nil + + //case "eq", "equal": + default: + return nil, xerrors.Errorf("invalid expression: operator %s not supported", op) + } +} + +func processTerms(expected int, terms []*ast.Term) ([]Term, error) { + if len(terms) != expected { + return nil, xerrors.Errorf("too many arguments, expect %d found %d", expected, len(terms)) + } + + result := make([]Term, 0, len(terms)) + for _, term := range terms { + processed, err := processTerm(term) + if err != nil { + return nil, xerrors.Errorf("invalid term: %w", err) + } + result = append(result, processed) + } + + return result, nil +} + +func processTerm(term *ast.Term) (Term, error) { + base := Base{Rego: term.String()} + switch v := term.Value.(type) { + case ast.Ref: + // A ref is a set of terms. If the first term is a var, then the + // following terms are the path to the value. + if v0, ok := v[0].Value.(ast.Var); ok { + name := v0.String() + for _, p := range v[1:] { + name += "." + p.String() + } + return &TermVariable{ + Base: base, + Name: name, + }, nil + } else { + return nil, xerrors.Errorf("invalid term: ref must start with a var, started with %T", v[0]) + } + case ast.Var: + return &TermVariable{ + Name: v.String(), + Base: base, + }, nil + case ast.String: + return &TermString{ + Value: v.String(), + Base: base, + }, nil + case ast.Set: + return &TermSet{ + Value: v, + Base: base, + }, nil + default: + return nil, xerrors.Errorf("invalid term: %T not supported, %q", v, term.String()) + } +} + +type Base struct { + // Rego is the original rego string + Rego string +} + +func (b Base) RegoString() string { + return b.Rego +} + +// Expression comprises a set of terms, operators, and functions. All +// expressions return a boolean value. +// +// Eg: neq(input.object.org_owner, "") +type Expression interface { + RegoString() string + SQLString() string +} + +type ExpAnd struct { + Base + Expressions []Expression +} + +func (t ExpAnd) SQLString() string { + exprs := make([]string, 0, len(t.Expressions)) + for _, expr := range t.Expressions { + exprs = append(exprs, expr.SQLString()) + } + return strings.Join(exprs, " AND ") +} + +type ExpOr struct { + Base + Expressions []Expression +} + +func (t ExpOr) SQLString() string { + exprs := make([]string, 0, len(t.Expressions)) + for _, expr := range t.Expressions { + exprs = append(exprs, expr.SQLString()) + } + return strings.Join(exprs, " OR ") +} + +// Operator joins terms together to form an expression. +// Operators are also expressions. +// +// Eg: "=", "neq", "internal.member_2", etc. +type Operator interface { + RegoString() string + SQLString() string +} + +type OpEqual struct { + Base + Terms [2]Term + // For NotEqual + Not bool +} + +func (t OpEqual) SQLString() string { + op := "=" + if t.Not { + op = "!=" + } + return fmt.Sprintf("%s %s %s", t.Terms[0].SQLString(), op, t.Terms[1].SQLString()) +} + +type OpInternalMember2 struct { + Base + Terms [2]Term +} + +func (t OpInternalMember2) SQLString() string { + return fmt.Sprintf("%s = ANY(%s)", t.Terms[0].SQLString(), t.Terms[1].SQLString()) +} + +// Term is a single value in an expression. Terms can be variables or constants. +// +// Eg: "f9d6fb75-b59b-4363-ab6b-ae9d26b679d7", "input.object.org_owner", +// "{"f9d6fb75-b59b-4363-ab6b-ae9d26b679d7"}" +type Term interface { + SQLString() string + RegoString() string +} + +type TermString struct { + Base + Value string +} + +func (t TermString) SQLString() string { + return t.Value +} + +type TermVariable struct { + Base + Name string +} + +func (t TermVariable) SQLString() string { + return t.Name +} + +type TermSet struct { + Base + Value ast.Set +} + +func (t TermSet) SQLString() string { + values := t.Value.Slice() + elems := make([]string, 0, len(values)) + // TODO: Handle different typed terms? + for _, v := range t.Value.Slice() { + elems = append(elems, v.String()) + } + + return fmt.Sprintf("ARRAY [%s]", strings.Join(elems, ",")) +} diff --git a/coderd/rbac/query_internal_test.go b/coderd/rbac/query_internal_test.go new file mode 100644 index 0000000000000..c24a48d91ed17 --- /dev/null +++ b/coderd/rbac/query_internal_test.go @@ -0,0 +1,38 @@ +package rbac + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/google/uuid" +) + +func TestCompileQuery(t *testing.T) { + ctx := context.Background() + defOrg := uuid.New() + unuseID := uuid.New() + + user := subject{ + UserID: "me", + Scope: must(ScopeRole(ScopeAll)), + Roles: []Role{ + must(RoleByName(RoleMember())), + must(RoleByName(RoleOrgMember(defOrg))), + }, + } + var action Action = ActionRead + object := ResourceWorkspace.InOrg(defOrg).WithOwner(unuseID.String()) + + auth := NewAuthorizer() + part, err := auth.Prepare(ctx, user.UserID, user.Roles, user.Scope, action, object.Type) + require.NoError(t, err) + + result, err := Compile(ctx, part.partialQueries) + require.NoError(t, err) + + fmt.Println("Rego: ", result.RegoString()) + fmt.Println("SQL: ", result.SQLString()) +} From e535e5a2a6254b8515029453773b95eef9fa5156 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 27 Sep 2022 13:37:52 -0400 Subject: [PATCH 02/18] Fix postgres quotes to single quotes --- coderd/rbac/query.go | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 1dab38ecea429..04f0290e0159e 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -117,9 +117,9 @@ func processTerm(term *ast.Term) (Term, error) { // A ref is a set of terms. If the first term is a var, then the // following terms are the path to the value. if v0, ok := v[0].Value.(ast.Var); ok { - name := v0.String() + name := trimQuotes(v0.String()) for _, p := range v[1:] { - name += "." + p.String() + name += "." + trimQuotes(p.String()) } return &TermVariable{ Base: base, @@ -130,12 +130,12 @@ func processTerm(term *ast.Term) (Term, error) { } case ast.Var: return &TermVariable{ - Name: v.String(), + Name: trimQuotes(v.String()), Base: base, }, nil case ast.String: return &TermString{ - Value: v.String(), + Value: trimQuotes(v.String()), Base: base, }, nil case ast.Set: @@ -176,7 +176,7 @@ func (t ExpAnd) SQLString() string { for _, expr := range t.Expressions { exprs = append(exprs, expr.SQLString()) } - return strings.Join(exprs, " AND ") + return "(" + strings.Join(exprs, " AND ") + ")" } type ExpOr struct { @@ -189,7 +189,8 @@ func (t ExpOr) SQLString() string { for _, expr := range t.Expressions { exprs = append(exprs, expr.SQLString()) } - return strings.Join(exprs, " OR ") + + return "(" + strings.Join(exprs, " OR ") + ")" } // Operator joins terms together to form an expression. @@ -240,7 +241,7 @@ type TermString struct { } func (t TermString) SQLString() string { - return t.Value + return "'" + t.Value + "'" } type TermVariable struct { @@ -262,8 +263,16 @@ func (t TermSet) SQLString() string { elems := make([]string, 0, len(values)) // TODO: Handle different typed terms? for _, v := range t.Value.Slice() { - elems = append(elems, v.String()) + t, err := processTerm(v) + if err != nil { + panic(err) + } + elems = append(elems, t.SQLString()) } return fmt.Sprintf("ARRAY [%s]", strings.Join(elems, ",")) } + +func trimQuotes(s string) string { + return strings.Trim(s, "\"") +} From 8f9295316c131fd1680b90226599aa9be06010aa Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 27 Sep 2022 13:52:17 -0400 Subject: [PATCH 03/18] Ensure all test cases can compile into SQL clauses --- coderd/rbac/authz_internal_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/coderd/rbac/authz_internal_test.go b/coderd/rbac/authz_internal_test.go index bf5342c3b27e4..67049e350f8db 100644 --- a/coderd/rbac/authz_internal_test.go +++ b/coderd/rbac/authz_internal_test.go @@ -781,6 +781,9 @@ func testAuthorize(t *testing.T, name string, subject subject, sets ...[]authTes partialAuthz, err := authorizer.Prepare(ctx, subject.UserID, subject.Roles, subject.Scope, a, c.resource.Type) require.NoError(t, err, "make prepared authorizer") + _, err = Compile(ctx, partialAuthz.partialQueries) + require.NoError(t, err, "compile prepared authorizer") + // Also check the rego policy can form a valid partial query result. // This ensures we can convert the queries into SQL WHERE clauses in the future. // If this function returns 'Support' sections, then we cannot convert the query into SQL. From cb5d5198bf02d569a2a69616adda99438dcd219f Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 27 Sep 2022 14:53:53 -0400 Subject: [PATCH 04/18] Do not export extra types --- coderd/rbac/query.go | 132 ++++++++++++++++++++++++++----------------- 1 file changed, 80 insertions(+), 52 deletions(-) diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 04f0290e0159e..4c3388aa7e80a 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -1,25 +1,41 @@ package rbac import ( - "context" "fmt" + "strconv" "strings" "github.com/open-policy-agent/opa/ast" - - "golang.org/x/xerrors" - "github.com/open-policy-agent/opa/rego" + "golang.org/x/xerrors" ) -// Example python: https://github.com/open-policy-agent/contrib/tree/main/data_filter_example -// - -func Compile(ctx context.Context, partialQueries *rego.PartialQueries) (Expression, error) { +// Compile will convert a rego query AST into our custom types. The output is +// an AST that can be used to generate SQL. +func Compile(partialQueries *rego.PartialQueries) (Expression, error) { if len(partialQueries.Support) > 0 { return nil, xerrors.Errorf("cannot convert support rules, expect 0 found %d", len(partialQueries.Support)) } + // 0 queries means the result is "undefined". This is the same as "false". + if len(partialQueries.Queries) == 0 { + return &termBoolean{ + base: base{Rego: "false"}, + Value: false, + }, nil + } + + // Abort early if any of the "OR"'d expressions are the empty string. + // This is the same as "true". + for _, query := range partialQueries.Queries { + if query.String() == "" { + return &termBoolean{ + base: base{Rego: "true"}, + Value: true, + }, nil + } + } + result := make([]Expression, 0, len(partialQueries.Queries)) var builder strings.Builder for i := range partialQueries.Queries { @@ -33,14 +49,16 @@ func Compile(ctx context.Context, partialQueries *rego.PartialQueries) (Expressi } builder.WriteString(partialQueries.Queries[i].String()) } - return ExpOr{ - Base: Base{ + return expOr{ + base: base{ Rego: builder.String(), }, Expressions: result, }, nil } +// processQuery processes an entire set of expressions and joins them with +// "AND". func processQuery(query ast.Body) (Expression, error) { expressions := make([]Expression, 0, len(query)) for _, astExpr := range query { @@ -51,8 +69,8 @@ func processQuery(query ast.Body) (Expression, error) { expressions = append(expressions, expr) } - return ExpAnd{ - Base: Base{ + return expAnd{ + base: base{ Rego: query.String(), }, Expressions: expressions, @@ -65,15 +83,15 @@ func processExpression(expr *ast.Expr) (Expression, error) { } op := expr.Operator().String() - base := Base{Rego: op} + base := base{Rego: op} switch op { case "neq", "eq", "equal": terms, err := processTerms(2, expr.Operands()) if err != nil { return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err) } - return &OpEqual{ - Base: base, + return &opEqual{ + base: base, Terms: [2]Term{terms[0], terms[1]}, Not: op == "neq", }, nil @@ -82,12 +100,10 @@ func processExpression(expr *ast.Expr) (Expression, error) { if err != nil { return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err) } - return &OpInternalMember2{ - Base: base, + return &opInternalMember2{ + base: base, Terms: [2]Term{terms[0], terms[1]}, }, nil - - //case "eq", "equal": default: return nil, xerrors.Errorf("invalid expression: operator %s not supported", op) } @@ -111,7 +127,7 @@ func processTerms(expected int, terms []*ast.Term) ([]Term, error) { } func processTerm(term *ast.Term) (Term, error) { - base := Base{Rego: term.String()} + base := base{Rego: term.String()} switch v := term.Value.(type) { case ast.Ref: // A ref is a set of terms. If the first term is a var, then the @@ -121,57 +137,57 @@ func processTerm(term *ast.Term) (Term, error) { for _, p := range v[1:] { name += "." + trimQuotes(p.String()) } - return &TermVariable{ - Base: base, + return &termVariable{ + base: base, Name: name, }, nil } else { return nil, xerrors.Errorf("invalid term: ref must start with a var, started with %T", v[0]) } case ast.Var: - return &TermVariable{ + return &termVariable{ Name: trimQuotes(v.String()), - Base: base, + base: base, }, nil case ast.String: - return &TermString{ + return &termString{ Value: trimQuotes(v.String()), - Base: base, + base: base, }, nil case ast.Set: - return &TermSet{ + return &termSet{ Value: v, - Base: base, + base: base, }, nil default: return nil, xerrors.Errorf("invalid term: %T not supported, %q", v, term.String()) } } -type Base struct { +type base struct { // Rego is the original rego string Rego string } -func (b Base) RegoString() string { +func (b base) RegoString() string { return b.Rego } // Expression comprises a set of terms, operators, and functions. All // expressions return a boolean value. // -// Eg: neq(input.object.org_owner, "") +// Eg: neq(input.object.org_owner, "") AND input.object.org_owner == "foo" type Expression interface { RegoString() string SQLString() string } -type ExpAnd struct { - Base +type expAnd struct { + base Expressions []Expression } -func (t ExpAnd) SQLString() string { +func (t expAnd) SQLString() string { exprs := make([]string, 0, len(t.Expressions)) for _, expr := range t.Expressions { exprs = append(exprs, expr.SQLString()) @@ -179,12 +195,12 @@ func (t ExpAnd) SQLString() string { return "(" + strings.Join(exprs, " AND ") + ")" } -type ExpOr struct { - Base +type expOr struct { + base Expressions []Expression } -func (t ExpOr) SQLString() string { +func (t expOr) SQLString() string { exprs := make([]string, 0, len(t.Expressions)) for _, expr := range t.Expressions { exprs = append(exprs, expr.SQLString()) @@ -202,14 +218,14 @@ type Operator interface { SQLString() string } -type OpEqual struct { - Base +type opEqual struct { + base Terms [2]Term // For NotEqual Not bool } -func (t OpEqual) SQLString() string { +func (t opEqual) SQLString() string { op := "=" if t.Not { op = "!=" @@ -217,12 +233,14 @@ func (t OpEqual) SQLString() string { return fmt.Sprintf("%s %s %s", t.Terms[0].SQLString(), op, t.Terms[1].SQLString()) } -type OpInternalMember2 struct { - Base +// opInternalMember2 is checking if the first term is a member of the second term. +// The second term is a set or list. +type opInternalMember2 struct { + base Terms [2]Term } -func (t OpInternalMember2) SQLString() string { +func (t opInternalMember2) SQLString() string { return fmt.Sprintf("%s = ANY(%s)", t.Terms[0].SQLString(), t.Terms[1].SQLString()) } @@ -235,30 +253,31 @@ type Term interface { RegoString() string } -type TermString struct { - Base +type termString struct { + base Value string } -func (t TermString) SQLString() string { +func (t termString) SQLString() string { return "'" + t.Value + "'" } -type TermVariable struct { - Base +type termVariable struct { + base Name string } -func (t TermVariable) SQLString() string { +func (t termVariable) SQLString() string { return t.Name } -type TermSet struct { - Base +// termSet is a set of unique terms. +type termSet struct { + base Value ast.Set } -func (t TermSet) SQLString() string { +func (t termSet) SQLString() string { values := t.Value.Slice() elems := make([]string, 0, len(values)) // TODO: Handle different typed terms? @@ -273,6 +292,15 @@ func (t TermSet) SQLString() string { return fmt.Sprintf("ARRAY [%s]", strings.Join(elems, ",")) } +type termBoolean struct { + base + Value bool +} + +func (t termBoolean) SQLString() string { + return strconv.FormatBool(t.Value) +} + func trimQuotes(s string) string { return strings.Trim(s, "\"") } From 2bd01658a4c9706556bb5bb1a398cadb1deffdbe Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 27 Sep 2022 16:31:25 -0400 Subject: [PATCH 05/18] Add custom query with rbac filter --- coderd/database/custom_queries.go | 59 ++++++++++++++++++++ coderd/database/databasefake/databasefake.go | 10 ++++ coderd/database/db.go | 7 +++ coderd/rbac/authz_internal_test.go | 2 +- coderd/rbac/query.go | 49 +++++++++------- coderd/rbac/query_internal_test.go | 8 ++- 6 files changed, 113 insertions(+), 22 deletions(-) create mode 100644 coderd/database/custom_queries.go diff --git a/coderd/database/custom_queries.go b/coderd/database/custom_queries.go new file mode 100644 index 0000000000000..6dda9021e3317 --- /dev/null +++ b/coderd/database/custom_queries.go @@ -0,0 +1,59 @@ +package database + +import ( + "context" + "fmt" + + "github.com/coder/coder/coderd/rbac" + + "github.com/lib/pq" +) + +// AuthorizedGetWorkspaces returns all workspaces that the user is authorized to access. +func (q *sqlQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) { + query := fmt.Sprintf("%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.SQLConfig{ + VariableRenames: map[string]string{ + "input.object.org_owner": "organization_id", + "input.object.owner": "owner_id", + }, + })) + rows, err := q.db.QueryContext(ctx, query, + arg.Deleted, + arg.OwnerID, + arg.OwnerUsername, + arg.TemplateName, + pq.Array(arg.TemplateIds), + arg.Name, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Workspace + for rows.Next() { + var i Workspace + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.OwnerID, + &i.OrganizationID, + &i.TemplateID, + &i.Deleted, + &i.Name, + &i.AutostartSchedule, + &i.Ttl, + &i.LastUsedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index f0cdee99f254b..d7fa4cd248f4b 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -566,6 +566,16 @@ func (q *fakeQuerier) GetWorkspaces(_ context.Context, arg database.GetWorkspace return workspaces, nil } +func (q *fakeQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.Workspace, error) { + workspaces, err := q.GetWorkspaces(ctx, arg) + if err != nil { + return nil, err + } + + // TODO: Filter workspaces + return workspaces, nil +} + func (q *fakeQuerier) GetWorkspaceByID(_ context.Context, id uuid.UUID) (database.Workspace, error) { q.mutex.RLock() defer q.mutex.RUnlock() diff --git a/coderd/database/db.go b/coderd/database/db.go index 0a9e8928df253..80a5748de7263 100644 --- a/coderd/database/db.go +++ b/coderd/database/db.go @@ -13,6 +13,8 @@ import ( "database/sql" "errors" + "github.com/coder/coder/coderd/rbac" + "golang.org/x/xerrors" ) @@ -20,10 +22,15 @@ import ( // It extends the generated interface to add transaction support. type Store interface { querier + customQuerier InTx(func(Store) error) error } +type customQuerier interface { + AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) +} + // DBTX represents a database connection or transaction. type DBTX interface { ExecContext(context.Context, string, ...interface{}) (sql.Result, error) diff --git a/coderd/rbac/authz_internal_test.go b/coderd/rbac/authz_internal_test.go index 67049e350f8db..7a646de754ab5 100644 --- a/coderd/rbac/authz_internal_test.go +++ b/coderd/rbac/authz_internal_test.go @@ -781,7 +781,7 @@ func testAuthorize(t *testing.T, name string, subject subject, sets ...[]authTes partialAuthz, err := authorizer.Prepare(ctx, subject.UserID, subject.Roles, subject.Scope, a, c.resource.Type) require.NoError(t, err, "make prepared authorizer") - _, err = Compile(ctx, partialAuthz.partialQueries) + _, err = Compile(partialAuthz.partialQueries) require.NoError(t, err, "compile prepared authorizer") // Also check the rego policy can form a valid partial query result. diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 4c3388aa7e80a..e715160e5587c 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -10,6 +10,16 @@ import ( "golang.org/x/xerrors" ) +type SQLConfig struct { + // VariableRenames renames rego variables to sql columns + VariableRenames map[string]string +} + +type AuthorizeFilter interface { + RegoString() string + SQLString(cfg SQLConfig) string +} + // Compile will convert a rego query AST into our custom types. The output is // an AST that can be used to generate SQL. func Compile(partialQueries *rego.PartialQueries) (Expression, error) { @@ -178,8 +188,7 @@ func (b base) RegoString() string { // // Eg: neq(input.object.org_owner, "") AND input.object.org_owner == "foo" type Expression interface { - RegoString() string - SQLString() string + AuthorizeFilter } type expAnd struct { @@ -187,10 +196,10 @@ type expAnd struct { Expressions []Expression } -func (t expAnd) SQLString() string { +func (t expAnd) SQLString(cfg SQLConfig) string { exprs := make([]string, 0, len(t.Expressions)) for _, expr := range t.Expressions { - exprs = append(exprs, expr.SQLString()) + exprs = append(exprs, expr.SQLString(cfg)) } return "(" + strings.Join(exprs, " AND ") + ")" } @@ -200,10 +209,10 @@ type expOr struct { Expressions []Expression } -func (t expOr) SQLString() string { +func (t expOr) SQLString(cfg SQLConfig) string { exprs := make([]string, 0, len(t.Expressions)) for _, expr := range t.Expressions { - exprs = append(exprs, expr.SQLString()) + exprs = append(exprs, expr.SQLString(cfg)) } return "(" + strings.Join(exprs, " OR ") + ")" @@ -214,8 +223,7 @@ func (t expOr) SQLString() string { // // Eg: "=", "neq", "internal.member_2", etc. type Operator interface { - RegoString() string - SQLString() string + AuthorizeFilter } type opEqual struct { @@ -225,12 +233,12 @@ type opEqual struct { Not bool } -func (t opEqual) SQLString() string { +func (t opEqual) SQLString(cfg SQLConfig) string { op := "=" if t.Not { op = "!=" } - return fmt.Sprintf("%s %s %s", t.Terms[0].SQLString(), op, t.Terms[1].SQLString()) + return fmt.Sprintf("%s %s %s", t.Terms[0].SQLString(cfg), op, t.Terms[1].SQLString(cfg)) } // opInternalMember2 is checking if the first term is a member of the second term. @@ -240,8 +248,8 @@ type opInternalMember2 struct { Terms [2]Term } -func (t opInternalMember2) SQLString() string { - return fmt.Sprintf("%s = ANY(%s)", t.Terms[0].SQLString(), t.Terms[1].SQLString()) +func (t opInternalMember2) SQLString(cfg SQLConfig) string { + return fmt.Sprintf("%s = ANY(%s)", t.Terms[0].SQLString(cfg), t.Terms[1].SQLString(cfg)) } // Term is a single value in an expression. Terms can be variables or constants. @@ -249,8 +257,7 @@ func (t opInternalMember2) SQLString() string { // Eg: "f9d6fb75-b59b-4363-ab6b-ae9d26b679d7", "input.object.org_owner", // "{"f9d6fb75-b59b-4363-ab6b-ae9d26b679d7"}" type Term interface { - SQLString() string - RegoString() string + AuthorizeFilter } type termString struct { @@ -258,7 +265,7 @@ type termString struct { Value string } -func (t termString) SQLString() string { +func (t termString) SQLString(_ SQLConfig) string { return "'" + t.Value + "'" } @@ -267,7 +274,11 @@ type termVariable struct { Name string } -func (t termVariable) SQLString() string { +func (t termVariable) SQLString(cfg SQLConfig) string { + rename, ok := cfg.VariableRenames[t.Name] + if ok { + return rename + } return t.Name } @@ -277,7 +288,7 @@ type termSet struct { Value ast.Set } -func (t termSet) SQLString() string { +func (t termSet) SQLString(cfg SQLConfig) string { values := t.Value.Slice() elems := make([]string, 0, len(values)) // TODO: Handle different typed terms? @@ -286,7 +297,7 @@ func (t termSet) SQLString() string { if err != nil { panic(err) } - elems = append(elems, t.SQLString()) + elems = append(elems, t.SQLString(cfg)) } return fmt.Sprintf("ARRAY [%s]", strings.Join(elems, ",")) @@ -297,7 +308,7 @@ type termBoolean struct { Value bool } -func (t termBoolean) SQLString() string { +func (t termBoolean) SQLString(_ SQLConfig) string { return strconv.FormatBool(t.Value) } diff --git a/coderd/rbac/query_internal_test.go b/coderd/rbac/query_internal_test.go index c24a48d91ed17..67964e754f708 100644 --- a/coderd/rbac/query_internal_test.go +++ b/coderd/rbac/query_internal_test.go @@ -30,9 +30,13 @@ func TestCompileQuery(t *testing.T) { part, err := auth.Prepare(ctx, user.UserID, user.Roles, user.Scope, action, object.Type) require.NoError(t, err) - result, err := Compile(ctx, part.partialQueries) + result, err := Compile(part.partialQueries) require.NoError(t, err) fmt.Println("Rego: ", result.RegoString()) - fmt.Println("SQL: ", result.SQLString()) + fmt.Println("SQL: ", result.SQLString(SQLConfig{ + map[string]string{ + "input.object.org_owner": "organization_id", + }, + })) } From 364498c4cc8d5ffa4ec327e479cd6ff2b0cc07ed Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 27 Sep 2022 21:07:59 -0400 Subject: [PATCH 06/18] First draft of a custom authorized db call --- coderd/authorize.go | 15 +++ coderd/coderdtest/authorize.go | 18 ++++ coderd/database/custom_queries.go | 8 +- coderd/database/databasefake/databasefake.go | 21 ++-- coderd/rbac/authz.go | 15 +++ coderd/rbac/partial.go | 8 ++ coderd/rbac/query.go | 107 ++++++++++++++++--- coderd/workspaces.go | 7 +- 8 files changed, 168 insertions(+), 31 deletions(-) diff --git a/coderd/authorize.go b/coderd/authorize.go index 0a6953cb1231e..7ed0e404612d1 100644 --- a/coderd/authorize.go +++ b/coderd/authorize.go @@ -85,6 +85,21 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r return true } +func (a *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) { + roles := httpmw.UserAuthorization(r) + prepared, err := a.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objectType) + if err != nil { + return nil, xerrors.Errorf("prepare filter: %w", err) + } + + filter, err := prepared.Compile() + if err != nil { + return nil, xerrors.Errorf("compile filter: %w", err) + } + + return filter, nil +} + // checkAuthorization returns if the current API key can use the given // permissions, factoring in the current user's roles and the API key scopes. func (api *API) checkAuthorization(rw http.ResponseWriter, r *http.Request) { diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index 60fca71cd0062..9ee6313e66efc 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -553,3 +553,21 @@ type fakePreparedAuthorizer struct { func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error { return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Action, object) } + +// Compile returns a compiled version of the authorizer that will work for +// in memory databases. This fake version will not work against a SQL database. +func (f *fakePreparedAuthorizer) Compile() (rbac.AuthorizeFilter, error) { + return f, nil +} + +func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool { + return f.Original.ByRoleName(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Action, object) == nil +} + +func (f *fakePreparedAuthorizer) RegoString() string { + panic("not implemented") +} + +func (f *fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string { + panic("not implemented") +} diff --git a/coderd/database/custom_queries.go b/coderd/database/custom_queries.go index 6dda9021e3317..756f98bcae062 100644 --- a/coderd/database/custom_queries.go +++ b/coderd/database/custom_queries.go @@ -4,6 +4,8 @@ import ( "context" "fmt" + "golang.org/x/xerrors" + "github.com/coder/coder/coderd/rbac" "github.com/lib/pq" @@ -13,8 +15,8 @@ import ( func (q *sqlQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) { query := fmt.Sprintf("%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.SQLConfig{ VariableRenames: map[string]string{ - "input.object.org_owner": "organization_id", - "input.object.owner": "owner_id", + "input.object.org_owner": "organization_id::text", + "input.object.owner": "owner_id::text", }, })) rows, err := q.db.QueryContext(ctx, query, @@ -26,7 +28,7 @@ func (q *sqlQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspa arg.Name, ) if err != nil { - return nil, err + return nil, xerrors.Errorf("get authorized workspaces: %w", err) } defer rows.Close() var items []Workspace diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index d7fa4cd248f4b..1dcecf080dca5 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -520,7 +520,12 @@ func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U }, nil } -func (q *fakeQuerier) GetWorkspaces(_ context.Context, arg database.GetWorkspacesParams) ([]database.Workspace, error) { +func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams) ([]database.Workspace, error) { + workspaces, err := q.AuthorizedGetWorkspaces(ctx, arg, nil) + return workspaces, err +} + +func (q *fakeQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.Workspace, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -560,19 +565,13 @@ func (q *fakeQuerier) GetWorkspaces(_ context.Context, arg database.GetWorkspace continue } } - workspaces = append(workspaces, workspace) - } - - return workspaces, nil -} -func (q *fakeQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.Workspace, error) { - workspaces, err := q.GetWorkspaces(ctx, arg) - if err != nil { - return nil, err + if authorizedFilter != nil && !authorizedFilter.Eval(workspace.RBACObject()) { + continue + } + workspaces = append(workspaces, workspace) } - // TODO: Filter workspaces return workspaces, nil } diff --git a/coderd/rbac/authz.go b/coderd/rbac/authz.go index 3bc37d056400a..526e028c7359d 100644 --- a/coderd/rbac/authz.go +++ b/coderd/rbac/authz.go @@ -21,6 +21,21 @@ type Authorizer interface { type PreparedAuthorized interface { Authorize(ctx context.Context, object Object) error + Compile() (AuthorizeFilter, error) +} + +func (a *RegoAuthorizer) SQLFilter(ctx context.Context, subjID string, subjRoles []string, scope Scope, action Action, objectType string) (AuthorizeFilter, error) { + prepared, err := a.PrepareByRoleName(ctx, subjID, subjRoles, scope, action, objectType) + if err != nil { + return nil, xerrors.Errorf("filter: %w", err) + } + + filter, err := prepared.Compile() + if err != nil { + return nil, xerrors.Errorf("filter: %w", err) + } + + return filter, nil } // Filter takes in a list of objects, and will filter the list removing all diff --git a/coderd/rbac/partial.go b/coderd/rbac/partial.go index 7fea09d374e17..6dfd0827f39f4 100644 --- a/coderd/rbac/partial.go +++ b/coderd/rbac/partial.go @@ -28,6 +28,14 @@ type PartialAuthorizer struct { var _ PreparedAuthorized = (*PartialAuthorizer)(nil) +func (pa *PartialAuthorizer) Compile() (AuthorizeFilter, error) { + filter, err := Compile(pa.partialQueries) + if err != nil { + return nil, xerrors.Errorf("compile: %w", err) + } + return filter, nil +} + func (pa *PartialAuthorizer) Authorize(ctx context.Context, object Object) error { if pa.alwaysTrue { return nil diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index e715160e5587c..9d6dccf109a8a 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -18,6 +18,9 @@ type SQLConfig struct { type AuthorizeFilter interface { RegoString() string SQLString(cfg SQLConfig) string + // Eval is required for the fake in memory database to work. The in memory + // database can use this function to filter the results. + Eval(object Object) bool } // Compile will convert a rego query AST into our custom types. The output is @@ -165,8 +168,18 @@ func processTerm(term *ast.Term) (Term, error) { base: base, }, nil case ast.Set: + slice := v.Slice() + set := make([]Term, 0, len(slice)) + for _, elem := range slice { + processed, err := processTerm(elem) + if err != nil { + return nil, xerrors.Errorf("invalid set term %s: %w", elem.String(), err) + } + set = append(set, processed) + } + return &termSet{ - Value: v, + Value: set, base: base, }, nil default: @@ -204,6 +217,15 @@ func (t expAnd) SQLString(cfg SQLConfig) string { return "(" + strings.Join(exprs, " AND ") + ")" } +func (t expAnd) Eval(object Object) bool { + for _, expr := range t.Expressions { + if !expr.Eval(object) { + return false + } + } + return true +} + type expOr struct { base Expressions []Expression @@ -218,12 +240,21 @@ func (t expOr) SQLString(cfg SQLConfig) string { return "(" + strings.Join(exprs, " OR ") + ")" } +func (t expOr) Eval(object Object) bool { + for _, expr := range t.Expressions { + if expr.Eval(object) { + return true + } + } + return false +} + // Operator joins terms together to form an expression. // Operators are also expressions. // // Eg: "=", "neq", "internal.member_2", etc. type Operator interface { - AuthorizeFilter + Expression } type opEqual struct { @@ -241,6 +272,14 @@ func (t opEqual) SQLString(cfg SQLConfig) string { return fmt.Sprintf("%s %s %s", t.Terms[0].SQLString(cfg), op, t.Terms[1].SQLString(cfg)) } +func (t opEqual) Eval(object Object) bool { + a, b := t.Terms[0].Eval(object), t.Terms[1].Eval(object) + if t.Not { + return a != b + } + return a == b +} + // opInternalMember2 is checking if the first term is a member of the second term. // The second term is a set or list. type opInternalMember2 struct { @@ -248,6 +287,20 @@ type opInternalMember2 struct { Terms [2]Term } +func (t opInternalMember2) Eval(object Object) bool { + a, b := t.Terms[0].Eval(object), t.Terms[1].Eval(object) + bset, ok := b.([]interface{}) + if !ok { + return false + } + for _, elem := range bset { + if a == elem { + return true + } + } + return false +} + func (t opInternalMember2) SQLString(cfg SQLConfig) string { return fmt.Sprintf("%s = ANY(%s)", t.Terms[0].SQLString(cfg), t.Terms[1].SQLString(cfg)) } @@ -257,7 +310,11 @@ func (t opInternalMember2) SQLString(cfg SQLConfig) string { // Eg: "f9d6fb75-b59b-4363-ab6b-ae9d26b679d7", "input.object.org_owner", // "{"f9d6fb75-b59b-4363-ab6b-ae9d26b679d7"}" type Term interface { - AuthorizeFilter + RegoString() string + SQLString(cfg SQLConfig) string + // Eval will evaluate the term + // Terms can eval to any type. The operator/expression will type check. + Eval(object Object) interface{} } type termString struct { @@ -265,6 +322,10 @@ type termString struct { Value string } +func (t termString) Eval(_ Object) interface{} { + return t.Value +} + func (t termString) SQLString(_ SQLConfig) string { return "'" + t.Value + "'" } @@ -274,6 +335,19 @@ type termVariable struct { Name string } +func (t termVariable) Eval(obj Object) interface{} { + switch t.Name { + case "input.object.org_owner": + return obj.OrgID + case "input.object.owner": + return obj.Owner + case "input.object.type": + return obj.Type + default: + return fmt.Sprintf("'Unknown variable %s'", t.Name) + } +} + func (t termVariable) SQLString(cfg SQLConfig) string { rename, ok := cfg.VariableRenames[t.Name] if ok { @@ -285,19 +359,22 @@ func (t termVariable) SQLString(cfg SQLConfig) string { // termSet is a set of unique terms. type termSet struct { base - Value ast.Set + Value []Term +} + +func (t termSet) Eval(obj Object) interface{} { + set := make([]interface{}, 0, len(t.Value)) + for _, term := range t.Value { + set = append(set, term.Eval(obj)) + } + + return set } func (t termSet) SQLString(cfg SQLConfig) string { - values := t.Value.Slice() - elems := make([]string, 0, len(values)) - // TODO: Handle different typed terms? - for _, v := range t.Value.Slice() { - t, err := processTerm(v) - if err != nil { - panic(err) - } - elems = append(elems, t.SQLString(cfg)) + elems := make([]string, 0, len(t.Value)) + for _, v := range t.Value { + elems = append(elems, v.SQLString(cfg)) } return fmt.Sprintf("ARRAY [%s]", strings.Join(elems, ",")) @@ -308,6 +385,10 @@ type termBoolean struct { Value bool } +func (t termBoolean) Eval(_ Object) bool { + return t.Value +} + func (t termBoolean) SQLString(_ SQLConfig) string { return strconv.FormatBool(t.Value) } diff --git a/coderd/workspaces.go b/coderd/workspaces.go index c809242d8fed5..f2b92fea54fa3 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -111,17 +111,16 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) { filter.OwnerUsername = "" } - workspaces, err := api.Database.GetWorkspaces(ctx, filter) + sqlFilter, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceWorkspace.Type) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error fetching workspaces.", + Message: "Internal error preparing sql filter.", Detail: err.Error(), }) return } - // Only return workspaces the user can read - workspaces, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, workspaces) + workspaces, err := api.Database.AuthorizedGetWorkspaces(ctx, filter, sqlFilter) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching workspaces.", From d516be7840b5c29b41786ce8b580000dbc10f089 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 28 Sep 2022 15:00:10 -0400 Subject: [PATCH 07/18] Add comments + tests --- coderd/authorize.go | 8 +++ coderd/database/custom_queries.go | 9 ++- coderd/database/databasefake/databasefake.go | 2 + coderd/database/db.go | 7 +-- coderd/rbac/authz.go | 14 ----- coderd/rbac/authz_internal_test.go | 2 + coderd/rbac/query.go | 50 +++++++++++++--- coderd/rbac/query_internal_test.go | 60 +++++++++----------- 8 files changed, 89 insertions(+), 63 deletions(-) diff --git a/coderd/authorize.go b/coderd/authorize.go index 7ed0e404612d1..166cae76e841c 100644 --- a/coderd/authorize.go +++ b/coderd/authorize.go @@ -13,6 +13,9 @@ import ( "github.com/coder/coder/codersdk" ) +// AuthorizeFilter takes a list of objects and returns the filtered list of +// objects that the user is authorized to perform the given action on. +// This is faster than calling Authorize() on each object. func AuthorizeFilter[O rbac.Objecter](h *HTTPAuthorizer, r *http.Request, action rbac.Action, objects []O) ([]O, error) { roles := httpmw.UserAuthorization(r) objects, err := rbac.Filter(r.Context(), h.Authorizer, roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objects) @@ -85,6 +88,11 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r return true } +// AuthorizeSQLFilter returns an authorization filter that can used in a +// SQL 'WHERE' clause. If the filter is used, the resulting rows returned +// from postgres are already authorized, and the caller does not need to +// call 'Authorize()' on the returned objects. +// Note the authorization is only for the given action and object type. func (a *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) { roles := httpmw.UserAuthorization(r) prepared, err := a.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objectType) diff --git a/coderd/database/custom_queries.go b/coderd/database/custom_queries.go index 756f98bcae062..956fbeb031a7f 100644 --- a/coderd/database/custom_queries.go +++ b/coderd/database/custom_queries.go @@ -4,14 +4,19 @@ import ( "context" "fmt" + "github.com/lib/pq" "golang.org/x/xerrors" "github.com/coder/coder/coderd/rbac" - - "github.com/lib/pq" ) +type customQuerier interface { + AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) +} + // AuthorizedGetWorkspaces returns all workspaces that the user is authorized to access. +// This code is copied from `GetWorkspaces` and adds the authorized filter WHERE +// clause. func (q *sqlQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) { query := fmt.Sprintf("%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.SQLConfig{ VariableRenames: map[string]string{ diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index d40194812f94d..973da51bd0a01 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -521,6 +521,7 @@ func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U } func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams) ([]database.Workspace, error) { + // A nil auth filter means no auth filter. workspaces, err := q.AuthorizedGetWorkspaces(ctx, arg, nil) return workspaces, err } @@ -566,6 +567,7 @@ func (q *fakeQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg database. } } + // If the filter exists, ensure the object is authorized. if authorizedFilter != nil && !authorizedFilter.Eval(workspace.RBACObject()) { continue } diff --git a/coderd/database/db.go b/coderd/database/db.go index 80a5748de7263..9997a88f1e148 100644 --- a/coderd/database/db.go +++ b/coderd/database/db.go @@ -13,8 +13,6 @@ import ( "database/sql" "errors" - "github.com/coder/coder/coderd/rbac" - "golang.org/x/xerrors" ) @@ -22,15 +20,12 @@ import ( // It extends the generated interface to add transaction support. type Store interface { querier + // customQuerier contains custom queries that are not generated. customQuerier InTx(func(Store) error) error } -type customQuerier interface { - AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) -} - // DBTX represents a database connection or transaction. type DBTX interface { ExecContext(context.Context, string, ...interface{}) (sql.Result, error) diff --git a/coderd/rbac/authz.go b/coderd/rbac/authz.go index 526e028c7359d..237fb7f365188 100644 --- a/coderd/rbac/authz.go +++ b/coderd/rbac/authz.go @@ -24,20 +24,6 @@ type PreparedAuthorized interface { Compile() (AuthorizeFilter, error) } -func (a *RegoAuthorizer) SQLFilter(ctx context.Context, subjID string, subjRoles []string, scope Scope, action Action, objectType string) (AuthorizeFilter, error) { - prepared, err := a.PrepareByRoleName(ctx, subjID, subjRoles, scope, action, objectType) - if err != nil { - return nil, xerrors.Errorf("filter: %w", err) - } - - filter, err := prepared.Compile() - if err != nil { - return nil, xerrors.Errorf("filter: %w", err) - } - - return filter, nil -} - // Filter takes in a list of objects, and will filter the list removing all // the elements the subject does not have permission for. All objects must be // of the same type. diff --git a/coderd/rbac/authz_internal_test.go b/coderd/rbac/authz_internal_test.go index 7a646de754ab5..e7b3b8522bcd1 100644 --- a/coderd/rbac/authz_internal_test.go +++ b/coderd/rbac/authz_internal_test.go @@ -781,6 +781,8 @@ func testAuthorize(t *testing.T, name string, subject subject, sets ...[]authTes partialAuthz, err := authorizer.Prepare(ctx, subject.UserID, subject.Roles, subject.Scope, a, c.resource.Type) require.NoError(t, err, "make prepared authorizer") + // Ensure the partial can compile to a SQL clause. + // This does not guarantee that the clause is valid SQL. _, err = Compile(partialAuthz.partialQueries) require.NoError(t, err, "compile prepared authorizer") diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 9d6dccf109a8a..8cc3e10d05aab 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -12,11 +12,16 @@ import ( type SQLConfig struct { // VariableRenames renames rego variables to sql columns + // Example: + // "input.object.org_owner": "organization_id::text" + // "input.object.owner": "owner_id::text" VariableRenames map[string]string } type AuthorizeFilter interface { + // RegoString is used in debugging to see the original rego expression. RegoString() string + // SQLString returns the SQL expression that can be used in a WHERE clause. SQLString(cfg SQLConfig) string // Eval is required for the fake in memory database to work. The in memory // database can use this function to filter the results. @@ -92,7 +97,18 @@ func processQuery(query ast.Body) (Expression, error) { func processExpression(expr *ast.Expr) (Expression, error) { if !expr.IsCall() { - return nil, xerrors.Errorf("invalid expression: function calls not supported") + // This could be a single term that is a valid expression. + if term, ok := expr.Terms.(*ast.Term); ok { + value, err := processTerm(term) + if err != nil { + return nil, xerrors.Errorf("single term expression: %w", err) + } + if boolExp, ok := value.(Expression); ok { + return boolExp, nil + } + // Default to error. + } + return nil, xerrors.Errorf("invalid expression: single non-boolean terms not supported") } op := expr.Operator().String() @@ -142,6 +158,11 @@ func processTerms(expected int, terms []*ast.Term) ([]Term, error) { func processTerm(term *ast.Term) (Term, error) { base := base{Rego: term.String()} switch v := term.Value.(type) { + case ast.Boolean: + return &termBoolean{ + base: base, + Value: bool(v), + }, nil case ast.Ref: // A ref is a set of terms. If the first term is a var, then the // following terms are the path to the value. @@ -210,6 +231,10 @@ type expAnd struct { } func (t expAnd) SQLString(cfg SQLConfig) string { + if len(t.Expressions) == 1 { + return t.Expressions[0].SQLString(cfg) + } + exprs := make([]string, 0, len(t.Expressions)) for _, expr := range t.Expressions { exprs = append(exprs, expr.SQLString(cfg)) @@ -232,11 +257,14 @@ type expOr struct { } func (t expOr) SQLString(cfg SQLConfig) string { + if len(t.Expressions) == 1 { + return t.Expressions[0].SQLString(cfg) + } + exprs := make([]string, 0, len(t.Expressions)) for _, expr := range t.Expressions { exprs = append(exprs, expr.SQLString(cfg)) } - return "(" + strings.Join(exprs, " OR ") + ")" } @@ -273,7 +301,7 @@ func (t opEqual) SQLString(cfg SQLConfig) string { } func (t opEqual) Eval(object Object) bool { - a, b := t.Terms[0].Eval(object), t.Terms[1].Eval(object) + a, b := t.Terms[0].EvalTerm(object), t.Terms[1].EvalTerm(object) if t.Not { return a != b } @@ -288,7 +316,7 @@ type opInternalMember2 struct { } func (t opInternalMember2) Eval(object Object) bool { - a, b := t.Terms[0].Eval(object), t.Terms[1].Eval(object) + a, b := t.Terms[0].EvalTerm(object), t.Terms[1].EvalTerm(object) bset, ok := b.([]interface{}) if !ok { return false @@ -314,7 +342,7 @@ type Term interface { SQLString(cfg SQLConfig) string // Eval will evaluate the term // Terms can eval to any type. The operator/expression will type check. - Eval(object Object) interface{} + EvalTerm(object Object) interface{} } type termString struct { @@ -322,7 +350,7 @@ type termString struct { Value string } -func (t termString) Eval(_ Object) interface{} { +func (t termString) EvalTerm(_ Object) interface{} { return t.Value } @@ -335,7 +363,7 @@ type termVariable struct { Name string } -func (t termVariable) Eval(obj Object) interface{} { +func (t termVariable) EvalTerm(obj Object) interface{} { switch t.Name { case "input.object.org_owner": return obj.OrgID @@ -362,10 +390,10 @@ type termSet struct { Value []Term } -func (t termSet) Eval(obj Object) interface{} { +func (t termSet) EvalTerm(obj Object) interface{} { set := make([]interface{}, 0, len(t.Value)) for _, term := range t.Value { - set = append(set, term.Eval(obj)) + set = append(set, term.EvalTerm(obj)) } return set @@ -389,6 +417,10 @@ func (t termBoolean) Eval(_ Object) bool { return t.Value } +func (t termBoolean) EvalTerm(_ Object) interface{} { + return t.Value +} + func (t termBoolean) SQLString(_ SQLConfig) string { return strconv.FormatBool(t.Value) } diff --git a/coderd/rbac/query_internal_test.go b/coderd/rbac/query_internal_test.go index 67964e754f708..63e5f0c578856 100644 --- a/coderd/rbac/query_internal_test.go +++ b/coderd/rbac/query_internal_test.go @@ -1,42 +1,38 @@ package rbac import ( - "context" - "fmt" "testing" - "github.com/stretchr/testify/require" + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/rego" - "github.com/google/uuid" + "github.com/stretchr/testify/require" ) func TestCompileQuery(t *testing.T) { - ctx := context.Background() - defOrg := uuid.New() - unuseID := uuid.New() - - user := subject{ - UserID: "me", - Scope: must(ScopeRole(ScopeAll)), - Roles: []Role{ - must(RoleByName(RoleMember())), - must(RoleByName(RoleOrgMember(defOrg))), - }, - } - var action Action = ActionRead - object := ResourceWorkspace.InOrg(defOrg).WithOwner(unuseID.String()) - - auth := NewAuthorizer() - part, err := auth.Prepare(ctx, user.UserID, user.Roles, user.Scope, action, object.Type) - require.NoError(t, err) - - result, err := Compile(part.partialQueries) - require.NoError(t, err) - - fmt.Println("Rego: ", result.RegoString()) - fmt.Println("SQL: ", result.SQLString(SQLConfig{ - map[string]string{ - "input.object.org_owner": "organization_id", - }, - })) + t.Run("EmptyQuery", func(t *testing.T) { + expression, err := Compile(®o.PartialQueries{ + Queries: []ast.Body{ + must(ast.ParseBody("")), + }, + Support: []*ast.Module{}, + }) + require.NoError(t, err, "compile empty") + + require.Equal(t, "true", expression.RegoString(), "empty query is rego 'true'") + require.Equal(t, "true", expression.SQLString(SQLConfig{}), "empty query is sql 'true'") + }) + + t.Run("TrueQuery", func(t *testing.T) { + expression, err := Compile(®o.PartialQueries{ + Queries: []ast.Body{ + must(ast.ParseBody("true")), + }, + Support: []*ast.Module{}, + }) + require.NoError(t, err, "compile empty") + + require.Equal(t, "true", expression.RegoString(), "true query is rego 'true'") + require.Equal(t, "true", expression.SQLString(SQLConfig{}), "true query is sql 'true'") + }) } From 125931a4fbde70c227314d9626df25c36cfc0f5f Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 28 Sep 2022 17:15:13 -0400 Subject: [PATCH 08/18] Support better regex style matching for variables --- coderd/database/custom_queries.go | 7 +-- coderd/rbac/query.go | 79 ++++++++++++++++++++++++++++--- 2 files changed, 73 insertions(+), 13 deletions(-) diff --git a/coderd/database/custom_queries.go b/coderd/database/custom_queries.go index 956fbeb031a7f..bbcfc3e6ba813 100644 --- a/coderd/database/custom_queries.go +++ b/coderd/database/custom_queries.go @@ -18,12 +18,7 @@ type customQuerier interface { // This code is copied from `GetWorkspaces` and adds the authorized filter WHERE // clause. func (q *sqlQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) { - query := fmt.Sprintf("%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.SQLConfig{ - VariableRenames: map[string]string{ - "input.object.org_owner": "organization_id::text", - "input.object.owner": "owner_id::text", - }, - })) + query := fmt.Sprintf("%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.DefaultConfig())) rows, err := q.db.QueryContext(ctx, query, arg.Deleted, arg.OwnerID, diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 8cc3e10d05aab..3851fac76eec0 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -2,6 +2,7 @@ package rbac import ( "fmt" + "regexp" "strconv" "strings" @@ -10,12 +11,66 @@ import ( "golang.org/x/xerrors" ) +const ( + VarTypeJsonbArray = "jsonb-array" + VarTypeUUID = "uuid" + VarTypeText = "text" +) + +type SQLColumn struct { + // RegoMatch matches the original variable string. + // If it is a match, then this variable config will apply. + RegoMatch *regexp.Regexp + // ColumnSelect is the name of the postgres column to select. + // Can use capture groups from RegoMatch with $1, $2, etc. + ColumnSelect string + + // Type indicates the postgres type of the column. Some expressions will + // need to know this in order to determine what SQL to produce. + // An example is if the variable is a jsonb array, the "contains" SQL + // query is `"value"' @> variable.` instead of `'value' = ANY(variable)`. + // This type is only needed to be provided + Type string +} + type SQLConfig struct { - // VariableRenames renames rego variables to sql columns + // Variables is a map of rego variable names to SQL columns. // Example: - // "input.object.org_owner": "organization_id::text" - // "input.object.owner": "owner_id::text" - VariableRenames map[string]string + // "input\.object\.org_owner": SQLColumn{ + // ColumnSelect: "organization_id", + // Type: VarTypeUUID + // } + // "input\.object\.owner": SQLColumn{ + // ColumnSelect: "owner_id", + // Type: VarTypeUUID + // } + // "input\.object\.group_acl\.(.*)": SQLColumn{ + // ColumnSelect: "group_acl->$1", + // Type: VarTypeJsonb + // } + Variables []SQLColumn +} + +func DefaultConfig() SQLConfig { + return SQLConfig{ + Variables: []SQLColumn{ + { + RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.([^.]*)$`), + ColumnSelect: "group_acl->$1", + Type: VarTypeJsonbArray, + }, + { + RegoMatch: regexp.MustCompile(`^input\.object\.org_owner$`), + ColumnSelect: "organization_id :: text", + Type: VarTypeUUID, + }, + { + RegoMatch: regexp.MustCompile(`^input\.object\.owner$`), + ColumnSelect: "owner_id :: text", + Type: VarTypeUUID, + }, + }, + } } type AuthorizeFilter interface { @@ -377,10 +432,20 @@ func (t termVariable) EvalTerm(obj Object) interface{} { } func (t termVariable) SQLString(cfg SQLConfig) string { - rename, ok := cfg.VariableRenames[t.Name] - if ok { - return rename + for _, col := range cfg.Variables { + matches := col.RegoMatch.FindStringSubmatch(t.Name) + if len(matches) > 0 { + // This config matches this variable. + replace := make([]string, 0, len(matches)*2) + for i, m := range matches { + replace = append(replace, fmt.Sprintf("$%d", i)) + replace = append(replace, m) + } + replacer := strings.NewReplacer(replace...) + return replacer.Replace(col.ColumnSelect) + } } + return t.Name } From fc58da5cfd7cfc5b5d678d48c72cd8ac72c70f70 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 28 Sep 2022 18:27:41 -0400 Subject: [PATCH 09/18] Handle jsonb arrays --- coderd/authorize.go | 4 +- coderd/coderdtest/authorize.go | 4 +- coderd/database/databasefake/databasefake.go | 2 +- coderd/rbac/query.go | 73 +++++++++++++---- coderd/rbac/query_internal_test.go | 84 +++++++++++++++++++- 5 files changed, 146 insertions(+), 21 deletions(-) diff --git a/coderd/authorize.go b/coderd/authorize.go index 166cae76e841c..c0b8eaba757ed 100644 --- a/coderd/authorize.go +++ b/coderd/authorize.go @@ -93,9 +93,9 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r // from postgres are already authorized, and the caller does not need to // call 'Authorize()' on the returned objects. // Note the authorization is only for the given action and object type. -func (a *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) { +func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) { roles := httpmw.UserAuthorization(r) - prepared, err := a.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objectType) + prepared, err := h.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objectType) if err != nil { return nil, xerrors.Errorf("prepare filter: %w", err) } diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index 9ee6313e66efc..f3b759845a217 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -564,10 +564,10 @@ func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool { return f.Original.ByRoleName(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Action, object) == nil } -func (f *fakePreparedAuthorizer) RegoString() string { +func (fakePreparedAuthorizer) RegoString() string { panic("not implemented") } -func (f *fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string { +func (fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string { panic("not implemented") } diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 973da51bd0a01..d66cb662a6636 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -526,7 +526,7 @@ func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspa return workspaces, err } -func (q *fakeQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.Workspace, error) { +func (q *fakeQuerier) AuthorizedGetWorkspaces(_ context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.Workspace, error) { q.mutex.RLock() defer q.mutex.RUnlock() diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 3851fac76eec0..397b05bcc3552 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -11,10 +11,11 @@ import ( "golang.org/x/xerrors" ) +type TermType string + const ( - VarTypeJsonbArray = "jsonb-array" - VarTypeUUID = "uuid" - VarTypeText = "text" + VarTypeJsonbTextArray TermType = "jsonb-text-array" + VarTypeText TermType = "text" ) type SQLColumn struct { @@ -30,7 +31,7 @@ type SQLColumn struct { // An example is if the variable is a jsonb array, the "contains" SQL // query is `"value"' @> variable.` instead of `'value' = ANY(variable)`. // This type is only needed to be provided - Type string + Type TermType } type SQLConfig struct { @@ -57,17 +58,22 @@ func DefaultConfig() SQLConfig { { RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.([^.]*)$`), ColumnSelect: "group_acl->$1", - Type: VarTypeJsonbArray, + Type: VarTypeJsonbTextArray, + }, + { + RegoMatch: regexp.MustCompile(`^input\.object\.acl_user_list\.([^.]*)$`), + ColumnSelect: "user_acl->$1", + Type: VarTypeJsonbTextArray, }, { RegoMatch: regexp.MustCompile(`^input\.object\.org_owner$`), ColumnSelect: "organization_id :: text", - Type: VarTypeUUID, + Type: VarTypeText, }, { RegoMatch: regexp.MustCompile(`^input\.object\.owner$`), ColumnSelect: "owner_id :: text", - Type: VarTypeUUID, + Type: VarTypeText, }, }, } @@ -185,8 +191,9 @@ func processExpression(expr *ast.Expr) (Expression, error) { return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err) } return &opInternalMember2{ - base: base, - Terms: [2]Term{terms[0], terms[1]}, + base: base, + Needle: terms[0], + Haystack: terms[1], }, nil default: return nil, xerrors.Errorf("invalid expression: operator %s not supported", op) @@ -230,9 +237,8 @@ func processTerm(term *ast.Term) (Term, error) { base: base, Name: name, }, nil - } else { - return nil, xerrors.Errorf("invalid term: ref must start with a var, started with %T", v[0]) } + return nil, xerrors.Errorf("invalid term: ref must start with a var, started with %T", v[0]) case ast.Var: return &termVariable{ Name: trimQuotes(v.String()), @@ -367,11 +373,12 @@ func (t opEqual) Eval(object Object) bool { // The second term is a set or list. type opInternalMember2 struct { base - Terms [2]Term + Needle Term + Haystack Term } func (t opInternalMember2) Eval(object Object) bool { - a, b := t.Terms[0].EvalTerm(object), t.Terms[1].EvalTerm(object) + a, b := t.Needle.EvalTerm(object), t.Haystack.EvalTerm(object) bset, ok := b.([]interface{}) if !ok { return false @@ -385,7 +392,20 @@ func (t opInternalMember2) Eval(object Object) bool { } func (t opInternalMember2) SQLString(cfg SQLConfig) string { - return fmt.Sprintf("%s = ANY(%s)", t.Terms[0].SQLString(cfg), t.Terms[1].SQLString(cfg)) + if haystack, ok := t.Haystack.(*termVariable); ok { + // This is a special case where the haystack is a jsonb array. + // The more general way to solve this would be to implement a fuller type + // system and handle type conversions for each supported type. + // Then we could determine that the haystack is always an "array" and + // implement the "contains" function on the array type. + // But that requires a lot more code to handle a lot of cases we don't + // actually care about. + if haystack.SQLType(cfg) == VarTypeJsonbTextArray { + return fmt.Sprintf("%s ? %s", haystack.SQLString(cfg), t.Needle.SQLString(cfg)) + } + } + + return fmt.Sprintf("%s = ANY(%s)", t.Needle.SQLString(cfg), t.Haystack.SQLString(cfg)) } // Term is a single value in an expression. Terms can be variables or constants. @@ -413,6 +433,10 @@ func (t termString) SQLString(_ SQLConfig) string { return "'" + t.Value + "'" } +func (t termString) SQLType(_ SQLConfig) TermType { + return VarTypeText +} + type termVariable struct { base Name string @@ -431,8 +455,15 @@ func (t termVariable) EvalTerm(obj Object) interface{} { } } +func (t termVariable) SQLType(cfg SQLConfig) TermType { + if col := t.ColumnConfig(cfg); col != nil { + return col.Type + } + return VarTypeText +} + func (t termVariable) SQLString(cfg SQLConfig) string { - for _, col := range cfg.Variables { + if col := t.ColumnConfig(cfg); col != nil { matches := col.RegoMatch.FindStringSubmatch(t.Name) if len(matches) > 0 { // This config matches this variable. @@ -449,6 +480,18 @@ func (t termVariable) SQLString(cfg SQLConfig) string { return t.Name } +// ColumnConfig returns the correct SQLColumn settings for the +// term. If there is no configured column, it will return nil. +func (t termVariable) ColumnConfig(cfg SQLConfig) *SQLColumn { + for _, col := range cfg.Variables { + matches := col.RegoMatch.MatchString(t.Name) + if matches { + return &col + } + } + return nil +} + // termSet is a set of unique terms. type termSet struct { base diff --git a/coderd/rbac/query_internal_test.go b/coderd/rbac/query_internal_test.go index 63e5f0c578856..50c1efc35b7aa 100644 --- a/coderd/rbac/query_internal_test.go +++ b/coderd/rbac/query_internal_test.go @@ -1,8 +1,11 @@ package rbac import ( + "context" + "fmt" "testing" + "github.com/google/uuid" "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/rego" @@ -10,7 +13,12 @@ import ( ) func TestCompileQuery(t *testing.T) { + t.Parallel() + opts := ast.ParserOptions{ + AllFutureKeywords: true, + } t.Run("EmptyQuery", func(t *testing.T) { + t.Parallel() expression, err := Compile(®o.PartialQueries{ Queries: []ast.Body{ must(ast.ParseBody("")), @@ -24,15 +32,89 @@ func TestCompileQuery(t *testing.T) { }) t.Run("TrueQuery", func(t *testing.T) { + t.Parallel() expression, err := Compile(®o.PartialQueries{ Queries: []ast.Body{ must(ast.ParseBody("true")), }, Support: []*ast.Module{}, }) - require.NoError(t, err, "compile empty") + require.NoError(t, err, "compile") require.Equal(t, "true", expression.RegoString(), "true query is rego 'true'") require.Equal(t, "true", expression.SQLString(SQLConfig{}), "true query is sql 'true'") }) + + t.Run("ACLIn", func(t *testing.T) { + t.Parallel() + expression, err := Compile(®o.PartialQueries{ + Queries: []ast.Body{ + ast.MustParseBodyWithOpts(`"*" in input.object.acl_group_list.allUsers`, opts), + }, + Support: []*ast.Module{}, + }) + require.NoError(t, err, "compile") + + require.Equal(t, `internal.member_2("*", input.object.acl_group_list.allUsers)`, expression.RegoString(), "convert to internal_member") + require.Equal(t, `group_acl->allUsers ? '*'`, expression.SQLString(DefaultConfig()), "jsonb in") + }) + + t.Run("Complex", func(t *testing.T) { + t.Parallel() + expression, err := Compile(®o.PartialQueries{ + Queries: []ast.Body{ + ast.MustParseBodyWithOpts(`input.object.org_owner != ""`, opts), + ast.MustParseBodyWithOpts(`input.object.org_owner in {"a", "b", "c"}`, opts), + ast.MustParseBodyWithOpts(`input.object.org_owner != ""`, opts), + ast.MustParseBodyWithOpts(`"read" in input.object.acl_group_list.allUsers`, opts), + ast.MustParseBodyWithOpts(`"read" in input.object.acl_user_list.me`, opts), + }, + Support: []*ast.Module{}, + }) + require.NoError(t, err, "compile") + require.Equal(t, `(organization_id :: text != '' OR `+ + `organization_id :: text = ANY(ARRAY ['a','b','c']) OR `+ + `organization_id :: text != '' OR `+ + `group_acl->allUsers ? 'read' OR `+ + `user_acl->me ? 'read')`, + expression.SQLString(DefaultConfig()), "complex") + }) +} + +//func TestRE(t *testing.T) { +// // ^input\.object\.group_acl\.([^.]*)$ +// re := regexp.MustCompile(`^input\.object\.group_acl\.([^.]*)$`) +// +// x := []string{"test"} +// fmt.Sprintf("test", x) +// +// //re.FindStringSubmatch("input.object.group_acl.allUsers") +// fmt.Println(re.FindStringSubmatch("input.object.group_acl.allUsers")) +//} + +func TestPartialCompileQuery(t *testing.T) { + ctx := context.Background() + defOrg := uuid.New() + unuseID := uuid.New() + + user := subject{ + UserID: "me", + Scope: must(ScopeRole(ScopeAll)), + Roles: []Role{ + must(RoleByName(RoleMember())), + must(RoleByName(RoleOrgMember(defOrg))), + }, + } + var action Action = ActionRead + object := ResourceWorkspace.InOrg(defOrg).WithOwner(unuseID.String()) + + auth := NewAuthorizer() + part, err := auth.Prepare(ctx, user.UserID, user.Roles, user.Scope, action, object.Type) + require.NoError(t, err) + + result, err := Compile(part.partialQueries) + require.NoError(t, err) + + fmt.Println("Rego: ", result.RegoString()) + fmt.Println("SQL: ", result.SQLString(DefaultConfig())) } From 04c1d6a132eb3cfc4d2eb92dde3446f9eff6aa9e Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 28 Sep 2022 18:31:59 -0400 Subject: [PATCH 10/18] Remove auth call on workspaces --- coderd/coderdtest/authorize.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index f3b759845a217..ddba865e2c8ac 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -128,11 +128,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { AssertAction: rbac.ActionCreate, AssertObject: workspaceExecObj, }, - "GET:/api/v2/workspaces/": { - StatusCode: http.StatusOK, - AssertAction: rbac.ActionRead, - AssertObject: workspaceRBACObj, - }, "GET:/api/v2/organizations/{organization}/templates": { StatusCode: http.StatusOK, AssertAction: rbac.ActionRead, @@ -250,6 +245,9 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { "PUT:/api/v2/organizations/{organization}/members/{user}/roles": {NoAuthorize: true}, "POST:/api/v2/workspaces/{workspace}/builds": {StatusCode: http.StatusBadRequest, NoAuthorize: true}, "POST:/api/v2/organizations/{organization}/templateversions": {StatusCode: http.StatusBadRequest, NoAuthorize: true}, + + // Endpoints that use the SQLQuery filter. + "GET:/api/v2/workspaces/": {StatusCode: http.StatusOK}, } // Routes like proxy routes support all HTTP methods. A helper func to expand From 98b405e8339c7a633886b97217adcbdaf82b95b1 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 28 Sep 2022 19:02:22 -0400 Subject: [PATCH 11/18] Fix PG endpoints test --- coderd/coderdtest/authorize.go | 35 +++++++++++++++---------- coderd/rbac/query_internal_test.go | 41 ------------------------------ 2 files changed, 22 insertions(+), 54 deletions(-) diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index ddba865e2c8ac..63be497e0b622 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -247,7 +247,7 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { "POST:/api/v2/organizations/{organization}/templateversions": {StatusCode: http.StatusBadRequest, NoAuthorize: true}, // Endpoints that use the SQLQuery filter. - "GET:/api/v2/workspaces/": {StatusCode: http.StatusOK}, + "GET:/api/v2/workspaces/": {StatusCode: http.StatusOK, NoAuthorize: true}, } // Routes like proxy routes support all HTTP methods. A helper func to expand @@ -528,11 +528,12 @@ func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, ro func (r *RecordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) { return &fakePreparedAuthorizer{ - Original: r, - SubjectID: subjectID, - Roles: roles, - Scope: scope, - Action: action, + Original: r, + SubjectID: subjectID, + Roles: roles, + Scope: scope, + Action: action, + HardCodedSQLString: "true", }, nil } @@ -541,11 +542,13 @@ func (r *RecordingAuthorizer) reset() { } type fakePreparedAuthorizer struct { - Original *RecordingAuthorizer - SubjectID string - Roles []string - Scope rbac.Scope - Action rbac.Action + Original *RecordingAuthorizer + SubjectID string + Roles []string + Scope rbac.Scope + Action rbac.Action + HardCodedSQLString string + HardCodedRegoString string } func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error { @@ -562,10 +565,16 @@ func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool { return f.Original.ByRoleName(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Action, object) == nil } -func (fakePreparedAuthorizer) RegoString() string { +func (f fakePreparedAuthorizer) RegoString() string { + if f.HardCodedRegoString != "" { + return f.HardCodedRegoString + } panic("not implemented") } -func (fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string { +func (f fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string { + if f.HardCodedSQLString != "" { + return f.HardCodedSQLString + } panic("not implemented") } diff --git a/coderd/rbac/query_internal_test.go b/coderd/rbac/query_internal_test.go index 50c1efc35b7aa..f7923062b18eb 100644 --- a/coderd/rbac/query_internal_test.go +++ b/coderd/rbac/query_internal_test.go @@ -1,11 +1,8 @@ package rbac import ( - "context" - "fmt" "testing" - "github.com/google/uuid" "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/rego" @@ -80,41 +77,3 @@ func TestCompileQuery(t *testing.T) { expression.SQLString(DefaultConfig()), "complex") }) } - -//func TestRE(t *testing.T) { -// // ^input\.object\.group_acl\.([^.]*)$ -// re := regexp.MustCompile(`^input\.object\.group_acl\.([^.]*)$`) -// -// x := []string{"test"} -// fmt.Sprintf("test", x) -// -// //re.FindStringSubmatch("input.object.group_acl.allUsers") -// fmt.Println(re.FindStringSubmatch("input.object.group_acl.allUsers")) -//} - -func TestPartialCompileQuery(t *testing.T) { - ctx := context.Background() - defOrg := uuid.New() - unuseID := uuid.New() - - user := subject{ - UserID: "me", - Scope: must(ScopeRole(ScopeAll)), - Roles: []Role{ - must(RoleByName(RoleMember())), - must(RoleByName(RoleOrgMember(defOrg))), - }, - } - var action Action = ActionRead - object := ResourceWorkspace.InOrg(defOrg).WithOwner(unuseID.String()) - - auth := NewAuthorizer() - part, err := auth.Prepare(ctx, user.UserID, user.Roles, user.Scope, action, object.Type) - require.NoError(t, err) - - result, err := Compile(part.partialQueries) - require.NoError(t, err) - - fmt.Println("Rego: ", result.RegoString()) - fmt.Println("SQL: ", result.SQLString(DefaultConfig())) -} From 6ad0b51c01c175aa57c0a7601113449a2022705d Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 28 Sep 2022 20:32:23 -0400 Subject: [PATCH 12/18] Match psql implementation --- coderd/coderdtest/authorize.go | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index 63be497e0b622..3c291dfefcb58 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -515,17 +515,23 @@ type RecordingAuthorizer struct { var _ rbac.Authorizer = (*RecordingAuthorizer)(nil) -func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error { - r.Called = &authCall{ - SubjectID: subjectID, - Roles: roleNames, - Scope: scope, - Action: action, - Object: object, +func (r *RecordingAuthorizer) FakeByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object, record bool) error { + if record { + r.Called = &authCall{ + SubjectID: subjectID, + Roles: roleNames, + Scope: scope, + Action: action, + Object: object, + } } return r.AlwaysReturn } +func (r *RecordingAuthorizer) ByRoleName(ctx context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error { + return r.FakeByRoleName(ctx, subjectID, roleNames, scope, action, object, true) +} + func (r *RecordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) { return &fakePreparedAuthorizer{ Original: r, @@ -552,7 +558,7 @@ type fakePreparedAuthorizer struct { } func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error { - return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Action, object) + return f.Original.FakeByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Action, object, true) } // Compile returns a compiled version of the authorizer that will work for @@ -562,7 +568,7 @@ func (f *fakePreparedAuthorizer) Compile() (rbac.AuthorizeFilter, error) { } func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool { - return f.Original.ByRoleName(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Action, object) == nil + return f.Original.FakeByRoleName(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Action, object, false) == nil } func (f fakePreparedAuthorizer) RegoString() string { From 7cfad877ff795523c4a02e65ba8d51053e1b979c Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 28 Sep 2022 20:52:30 -0400 Subject: [PATCH 13/18] Add some comments --- coderd/coderdtest/authorize.go | 26 +++++++++++++------------- coderd/rbac/query.go | 20 +++++++++++--------- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index 3c291dfefcb58..2549f9179e5d3 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -515,21 +515,21 @@ type RecordingAuthorizer struct { var _ rbac.Authorizer = (*RecordingAuthorizer)(nil) -func (r *RecordingAuthorizer) FakeByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object, record bool) error { - if record { - r.Called = &authCall{ - SubjectID: subjectID, - Roles: roleNames, - Scope: scope, - Action: action, - Object: object, - } - } +// ByRoleNameSQL does not record the call. This matches the postgres behavior +// of not calling Authorize() +func (r *RecordingAuthorizer) ByRoleNameSQL(_ context.Context, _ string, _ []string, _ rbac.Scope, _ rbac.Action, _ rbac.Object) error { return r.AlwaysReturn } func (r *RecordingAuthorizer) ByRoleName(ctx context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error { - return r.FakeByRoleName(ctx, subjectID, roleNames, scope, action, object, true) + r.Called = &authCall{ + SubjectID: subjectID, + Roles: roleNames, + Scope: scope, + Action: action, + Object: object, + } + return r.AlwaysReturn } func (r *RecordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) { @@ -558,7 +558,7 @@ type fakePreparedAuthorizer struct { } func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error { - return f.Original.FakeByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Action, object, true) + return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Action, object) } // Compile returns a compiled version of the authorizer that will work for @@ -568,7 +568,7 @@ func (f *fakePreparedAuthorizer) Compile() (rbac.AuthorizeFilter, error) { } func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool { - return f.Original.FakeByRoleName(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Action, object, false) == nil + return f.Original.ByRoleNameSQL(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Action, object) == nil } func (f fakePreparedAuthorizer) RegoString() string { diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 397b05bcc3552..40866d9c58fb6 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -29,7 +29,7 @@ type SQLColumn struct { // Type indicates the postgres type of the column. Some expressions will // need to know this in order to determine what SQL to produce. // An example is if the variable is a jsonb array, the "contains" SQL - // query is `"value"' @> variable.` instead of `'value' = ANY(variable)`. + // query is `variable ? 'value'` instead of `'value' = ANY(variable)`. // This type is only needed to be provided Type TermType } @@ -47,7 +47,7 @@ type SQLConfig struct { // } // "input\.object\.group_acl\.(.*)": SQLColumn{ // ColumnSelect: "group_acl->$1", - // Type: VarTypeJsonb + // Type: VarTypeJsonbTextArray // } Variables []SQLColumn } @@ -394,12 +394,14 @@ func (t opInternalMember2) Eval(object Object) bool { func (t opInternalMember2) SQLString(cfg SQLConfig) string { if haystack, ok := t.Haystack.(*termVariable); ok { // This is a special case where the haystack is a jsonb array. - // The more general way to solve this would be to implement a fuller type - // system and handle type conversions for each supported type. - // Then we could determine that the haystack is always an "array" and - // implement the "contains" function on the array type. - // But that requires a lot more code to handle a lot of cases we don't - // actually care about. + // There is a more general way to solve this, but that requires a lot + // more code to cover a lot more cases that we do not care about. + // To handle this more generally we should implement "Array" as a type. + // Then have the `contains` function on the Array type. This would defer + // knowing the element type to the Array and cover more cases without + // having to add more "if" branches here. + // But until we need more cases, our basic type system is ok, and + // this is the only case we need to handle. if haystack.SQLType(cfg) == VarTypeJsonbTextArray { return fmt.Sprintf("%s ? %s", haystack.SQLString(cfg), t.Needle.SQLString(cfg)) } @@ -433,7 +435,7 @@ func (t termString) SQLString(_ SQLConfig) string { return "'" + t.Value + "'" } -func (t termString) SQLType(_ SQLConfig) TermType { +func (termString) SQLType(_ SQLConfig) TermType { return VarTypeText } From f10e9b70a823932eeb9dc1081769138db82af3f0 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 28 Sep 2022 21:20:35 -0400 Subject: [PATCH 14/18] Remove unused argument --- coderd/coderdtest/authorize.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index 2549f9179e5d3..8677b305c1e20 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -521,7 +521,7 @@ func (r *RecordingAuthorizer) ByRoleNameSQL(_ context.Context, _ string, _ []str return r.AlwaysReturn } -func (r *RecordingAuthorizer) ByRoleName(ctx context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error { +func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error { r.Called = &authCall{ SubjectID: subjectID, Roles: roleNames, From d89c4d22976ce3d968c009070d03b9f022edb202 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 29 Sep 2022 15:07:41 -0400 Subject: [PATCH 15/18] Add query name for tracking --- coderd/database/custom_queries.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/coderd/database/custom_queries.go b/coderd/database/custom_queries.go index bbcfc3e6ba813..b8c004ebb0c41 100644 --- a/coderd/database/custom_queries.go +++ b/coderd/database/custom_queries.go @@ -18,7 +18,8 @@ type customQuerier interface { // This code is copied from `GetWorkspaces` and adds the authorized filter WHERE // clause. func (q *sqlQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) { - query := fmt.Sprintf("%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.DefaultConfig())) + // The name comment is for metric tracking + query := fmt.Sprintf("-- name: AuthorizedGetWorkspaces :many\n%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.DefaultConfig())) rows, err := q.db.QueryContext(ctx, query, arg.Deleted, arg.OwnerID, From 3828c29e580ef4095f7ab60bfafb8117a3f17055 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 30 Sep 2022 10:14:26 -0400 Subject: [PATCH 16/18] Handle nested types This solves it without proper types in our AST. Might bite the bullet and implement some better types --- coderd/rbac/query.go | 98 ++++++++++++++++++++++++++---- coderd/rbac/query_internal_test.go | 13 ++++ 2 files changed, 99 insertions(+), 12 deletions(-) diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 40866d9c58fb6..0211a1ff538b0 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -56,12 +56,12 @@ func DefaultConfig() SQLConfig { return SQLConfig{ Variables: []SQLColumn{ { - RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.([^.]*)$`), + RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.?(.*)$`), ColumnSelect: "group_acl->$1", Type: VarTypeJsonbTextArray, }, { - RegoMatch: regexp.MustCompile(`^input\.object\.acl_user_list\.([^.]*)$`), + RegoMatch: regexp.MustCompile(`^input\.object\.acl_user_list\.?(.*)$`), ColumnSelect: "user_acl->$1", Type: VarTypeJsonbTextArray, }, @@ -226,19 +226,42 @@ func processTerm(term *ast.Term) (Term, error) { Value: bool(v), }, nil case ast.Ref: + obj := &termObject{ + base: base, + Variables: []termVariable{}, + } + var idx int // A ref is a set of terms. If the first term is a var, then the // following terms are the path to the value. - if v0, ok := v[0].Value.(ast.Var); ok { - name := trimQuotes(v0.String()) - for _, p := range v[1:] { - name += "." + trimQuotes(p.String()) + var builder strings.Builder + for _, term := range v { + if idx == 0 { + if _, ok := v[0].Value.(ast.Var); !ok { + return nil, xerrors.Errorf("invalid term (%s): ref must start with a var, started with %T", v[0].String(), v[0]) + } } - return &termVariable{ - base: base, - Name: name, - }, nil + + if _, ok := term.Value.(ast.Ref); ok { + // New obj + obj.Variables = append(obj.Variables, termVariable{ + base: base, + Name: builder.String(), + }) + builder.Reset() + idx = 0 + } + if builder.Len() != 0 { + builder.WriteString(".") + } + builder.WriteString(trimQuotes(term.String())) + idx++ } - return nil, xerrors.Errorf("invalid term: ref must start with a var, started with %T", v[0]) + + obj.Variables = append(obj.Variables, termVariable{ + base: base, + Name: builder.String(), + }) + return obj, nil case ast.Var: return &termVariable{ Name: trimQuotes(v.String()), @@ -392,7 +415,7 @@ func (t opInternalMember2) Eval(object Object) bool { } func (t opInternalMember2) SQLString(cfg SQLConfig) string { - if haystack, ok := t.Haystack.(*termVariable); ok { + if haystack, ok := t.Haystack.(*termObject); ok { // This is a special case where the haystack is a jsonb array. // There is a more general way to solve this, but that requires a lot // more code to cover a lot more cases that we do not care about. @@ -439,6 +462,57 @@ func (termString) SQLType(_ SQLConfig) TermType { return VarTypeText } +// termObject is a variable that can be dereferenced. We count some rego objects +// as single variables, eg: input.object.org_owner. In reality, it is a nested +// object. +// In rego, we can dereference the object with the "." operator, which we can +// handle with regex. +// Or we can dereference the object with the "[]", which we can handle with this +// term type. +type termObject struct { + base + Variables []termVariable +} + +func (t termObject) EvalTerm(obj Object) interface{} { + if len(t.Variables) == 0 { + return t.Variables[0].EvalTerm(obj) + } + panic("no nested structures are supported yet") +} + +func (t termObject) SQLType(cfg SQLConfig) TermType { + // Without a full type system, let's just assume the type of the first var + // is the resulting type. This is correct for our use case. + // Solving this more generally requires a full type system, which is + // excessive for our mostly static policy. + return t.Variables[0].SQLType(cfg) +} + +func (t termObject) SQLString(cfg SQLConfig) string { + if len(t.Variables) == 1 { + return t.Variables[0].SQLString(cfg) + } + // Combine the last 2 variables into 1 variable. + end := t.Variables[len(t.Variables)-1] + before := t.Variables[len(t.Variables)-2] + + return termObject{ + base: t.base, + Variables: append( + t.Variables[:len(t.Variables)-2], + termVariable{ + base: base{ + Rego: before.base.Rego + "[" + end.base.Rego + "]", + }, + // Convert the end to SQL string. We evaluate each term + // one at a time. + Name: before.Name + "." + end.SQLString(cfg), + }, + ), + }.SQLString(cfg) +} + type termVariable struct { base Name string diff --git a/coderd/rbac/query_internal_test.go b/coderd/rbac/query_internal_test.go index f7923062b18eb..92d8b91543953 100644 --- a/coderd/rbac/query_internal_test.go +++ b/coderd/rbac/query_internal_test.go @@ -76,4 +76,17 @@ func TestCompileQuery(t *testing.T) { `user_acl->me ? 'read')`, expression.SQLString(DefaultConfig()), "complex") }) + + t.Run("SetDereference", func(t *testing.T) { + t.Parallel() + expression, err := Compile(®o.PartialQueries{ + Queries: []ast.Body{ + ast.MustParseBodyWithOpts(`"*" in input.object.acl_group_list[input.object.org_owner]`, opts), + }, + Support: []*ast.Module{}, + }) + require.NoError(t, err, "compile") + require.Equal(t, `group_acl->organization_id :: text ? '*'`, + expression.SQLString(DefaultConfig()), "set dereference") + }) } From 913fb27c4b070e87bf3cf5f6901d00fe78731b50 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 30 Sep 2022 10:18:50 -0400 Subject: [PATCH 17/18] Add comment --- coderd/rbac/query.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 0211a1ff538b0..d8b1a140e9eb0 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -497,6 +497,8 @@ func (t termObject) SQLString(cfg SQLConfig) string { end := t.Variables[len(t.Variables)-1] before := t.Variables[len(t.Variables)-2] + // Recursively solve the SQLString by removing the last nested reference. + // This continues until we have a single variable. return termObject{ base: t.base, Variables: append( From 3e2fbb801831a84c2be6b449987b6e4f1f9a8bea Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 30 Sep 2022 12:19:36 -0400 Subject: [PATCH 18/18] Renaming function call to GetAuthorizedWorkspaces --- coderd/database/custom_queries.go | 8 ++++---- coderd/database/databasefake/databasefake.go | 4 ++-- coderd/workspaces.go | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/coderd/database/custom_queries.go b/coderd/database/custom_queries.go index b8c004ebb0c41..219a6cdb13c7b 100644 --- a/coderd/database/custom_queries.go +++ b/coderd/database/custom_queries.go @@ -11,15 +11,15 @@ import ( ) type customQuerier interface { - AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) + GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) } -// AuthorizedGetWorkspaces returns all workspaces that the user is authorized to access. +// GetAuthorizedWorkspaces returns all workspaces that the user is authorized to access. // This code is copied from `GetWorkspaces` and adds the authorized filter WHERE // clause. -func (q *sqlQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) { +func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) { // The name comment is for metric tracking - query := fmt.Sprintf("-- name: AuthorizedGetWorkspaces :many\n%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.DefaultConfig())) + query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.DefaultConfig())) rows, err := q.db.QueryContext(ctx, query, arg.Deleted, arg.OwnerID, diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index d66cb662a6636..33a33cc8c7496 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -522,11 +522,11 @@ func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams) ([]database.Workspace, error) { // A nil auth filter means no auth filter. - workspaces, err := q.AuthorizedGetWorkspaces(ctx, arg, nil) + workspaces, err := q.GetAuthorizedWorkspaces(ctx, arg, nil) return workspaces, err } -func (q *fakeQuerier) AuthorizedGetWorkspaces(_ context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.Workspace, error) { +func (q *fakeQuerier) GetAuthorizedWorkspaces(_ context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.Workspace, error) { q.mutex.RLock() defer q.mutex.RUnlock() diff --git a/coderd/workspaces.go b/coderd/workspaces.go index a20093130aef4..ff834aa6b246b 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -122,7 +122,7 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) { return } - workspaces, err := api.Database.AuthorizedGetWorkspaces(ctx, filter, sqlFilter) + workspaces, err := api.Database.GetAuthorizedWorkspaces(ctx, filter, sqlFilter) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching workspaces.", pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy