diff --git a/enterprise/audit/diff.go b/enterprise/audit/diff.go index 05d46499b525a..afa6d88f494f8 100644 --- a/enterprise/audit/diff.go +++ b/enterprise/audit/diff.go @@ -6,6 +6,7 @@ import ( "reflect" "github.com/google/uuid" + "golang.org/x/xerrors" "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/database" @@ -18,11 +19,9 @@ func structName(t reflect.Type) string { func diffValues(left, right any, table Table) audit.Map { var ( baseDiff = audit.Map{} - - leftV = reflect.ValueOf(left) - - rightV = reflect.ValueOf(right) - rightT = reflect.TypeOf(right) + rightT = reflect.TypeOf(right) + leftV = reflect.ValueOf(left) + rightV = reflect.ValueOf(right) diffKey = table[structName(rightT)] ) @@ -31,19 +30,25 @@ func diffValues(left, right any, table Table) audit.Map { panic(fmt.Sprintf("dev error: type %q (type %T) attempted audit but not auditable", rightT.Name(), right)) } - for i := 0; i < rightT.NumField(); i++ { - if !rightT.Field(i).IsExported() { - continue - } + // allFields contains all top level fields of the struct. + allFields, err := flattenStructFields(leftV, rightV) + if err != nil { + // This should never happen. Only structs should be flattened. If an + // error occurs, an unsupported or non-struct type was passed in. + panic(fmt.Sprintf("dev error: failed to flatten struct fields: %v", err)) + } + for _, field := range allFields { var ( - leftF = leftV.Field(i) - rightF = rightV.Field(i) + leftF = field.LeftF + rightF = field.RightF leftI = leftF.Interface() rightI = rightF.Interface() + ) - diffName = rightT.Field(i).Tag.Get("json") + var ( + diffName = field.FieldType.Tag.Get("json") ) atype, ok := diffKey[diffName] @@ -145,6 +150,64 @@ func convertDiffType(left, right any) (newLeft, newRight any, changed bool) { } } +// fieldDiff has all the required information to return an audit diff for a +// given field. +type fieldDiff struct { + FieldType reflect.StructField + LeftF reflect.Value + RightF reflect.Value +} + +// flattenStructFields will return all top level fields for a given structure. +// Only anonymously embedded structs will be recursively flattened such that their +// fields are returned as top level fields. Named nested structs will be returned +// as a single field. +// Conflicting field names need to be handled by the caller. +func flattenStructFields(leftV, rightV reflect.Value) ([]fieldDiff, error) { + // Dereference pointers if the field is a pointer field. + if leftV.Kind() == reflect.Ptr { + leftV = derefPointer(leftV) + rightV = derefPointer(rightV) + } + + if leftV.Kind() != reflect.Struct { + return nil, xerrors.Errorf("%q is not a struct, kind=%s", leftV.String(), leftV.Kind()) + } + + var allFields []fieldDiff + rightT := rightV.Type() + + // Loop through all top level fields of the struct. + for i := 0; i < rightT.NumField(); i++ { + if !rightT.Field(i).IsExported() { + continue + } + + var ( + leftF = leftV.Field(i) + rightF = rightV.Field(i) + ) + + if rightT.Field(i).Anonymous { + // Anonymous fields are recursively flattened. + anonFields, err := flattenStructFields(leftF, rightF) + if err != nil { + return nil, xerrors.Errorf("flatten anonymous field %q: %w", rightT.Field(i).Name, err) + } + allFields = append(allFields, anonFields...) + continue + } + + // Single fields append as is. + allFields = append(allFields, fieldDiff{ + LeftF: leftF, + RightF: rightF, + FieldType: rightT.Field(i), + }) + } + return allFields, nil +} + // derefPointer deferences a reflect.Value that is a pointer to its underlying // value. It dereferences recursively until it finds a non-pointer value. If the // pointer is nil, it will be coerced to the zero value of the underlying type. diff --git a/enterprise/audit/diff_internal_test.go b/enterprise/audit/diff_internal_test.go index bf918a6f97c1d..5df4b3b893d73 100644 --- a/enterprise/audit/diff_internal_test.go +++ b/enterprise/audit/diff_internal_test.go @@ -66,7 +66,6 @@ func Test_diffValues(t *testing.T) { }) }) - //nolint:revive t.Run("PointerField", func(t *testing.T) { t.Parallel() @@ -98,6 +97,83 @@ func Test_diffValues(t *testing.T) { }) }) + //nolint:revive + t.Run("EmbeddedStruct", func(t *testing.T) { + t.Parallel() + + type Bar struct { + Baz int `json:"baz"` + Buzz string `json:"buzz"` + } + + type PtrBar struct { + Qux string `json:"qux"` + } + + type foo struct { + Bar + *PtrBar + TopLevel string `json:"top_level"` + } + + table := auditMap(map[any]map[string]Action{ + &foo{}: { + "baz": ActionTrack, + "buzz": ActionTrack, + "qux": ActionTrack, + "top_level": ActionTrack, + }, + }) + + runDiffValuesTests(t, table, []diffTest{ + { + name: "SingleFieldChange", + left: foo{TopLevel: "top-before", Bar: Bar{Baz: 1, Buzz: "before"}, PtrBar: &PtrBar{Qux: "qux-before"}}, + right: foo{TopLevel: "top-after", Bar: Bar{Baz: 0, Buzz: "after"}, PtrBar: &PtrBar{Qux: "qux-after"}}, + exp: audit.Map{ + "baz": audit.OldNew{Old: 1, New: 0}, + "buzz": audit.OldNew{Old: "before", New: "after"}, + "qux": audit.OldNew{Old: "qux-before", New: "qux-after"}, + "top_level": audit.OldNew{Old: "top-before", New: "top-after"}, + }, + }, + { + name: "Empty", + left: foo{}, + right: foo{}, + exp: audit.Map{}, + }, + { + name: "NoChange", + left: foo{TopLevel: "top-before", Bar: Bar{Baz: 1, Buzz: "before"}, PtrBar: &PtrBar{Qux: "qux-before"}}, + right: foo{TopLevel: "top-before", Bar: Bar{Baz: 1, Buzz: "before"}, PtrBar: &PtrBar{Qux: "qux-before"}}, + exp: audit.Map{}, + }, + { + name: "LeftEmpty", + left: foo{}, + right: foo{TopLevel: "top-after", Bar: Bar{Baz: 1, Buzz: "after"}, PtrBar: &PtrBar{Qux: "qux-after"}}, + exp: audit.Map{ + "baz": audit.OldNew{Old: 0, New: 1}, + "buzz": audit.OldNew{Old: "", New: "after"}, + "qux": audit.OldNew{Old: "", New: "qux-after"}, + "top_level": audit.OldNew{Old: "", New: "top-after"}, + }, + }, + { + name: "RightNil", + left: foo{TopLevel: "top-before", Bar: Bar{Baz: 1, Buzz: "before"}, PtrBar: &PtrBar{Qux: "qux-before"}}, + right: foo{}, + exp: audit.Map{ + "baz": audit.OldNew{Old: 1, New: 0}, + "buzz": audit.OldNew{Old: "before", New: ""}, + "qux": audit.OldNew{Old: "qux-before", New: ""}, + "top_level": audit.OldNew{Old: "top-before", New: ""}, + }, + }, + }) + }) + // We currently don't support nested structs. // t.Run("NestedStruct", func(t *testing.T) { // t.Parallel()
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: