Skip to content

Commit f0eddba

Browse files
Kira-PilotEmyrk
andauthored
chore: Support anonymously embedded fields for audit diffs (#5746)
- Anonymously embedded structs are expanded as top level fields. - Unit tests for anonymously embedded structs Co-authored-by: Steven Masley <stevenmasley@coder.com>
1 parent e37bff6 commit f0eddba

File tree

2 files changed

+152
-13
lines changed

2 files changed

+152
-13
lines changed

enterprise/audit/diff.go

Lines changed: 75 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"reflect"
77

88
"github.com/google/uuid"
9+
"golang.org/x/xerrors"
910

1011
"github.com/coder/coder/coderd/audit"
1112
"github.com/coder/coder/coderd/database"
@@ -18,11 +19,9 @@ func structName(t reflect.Type) string {
1819
func diffValues(left, right any, table Table) audit.Map {
1920
var (
2021
baseDiff = audit.Map{}
21-
22-
leftV = reflect.ValueOf(left)
23-
24-
rightV = reflect.ValueOf(right)
25-
rightT = reflect.TypeOf(right)
22+
rightT = reflect.TypeOf(right)
23+
leftV = reflect.ValueOf(left)
24+
rightV = reflect.ValueOf(right)
2625

2726
diffKey = table[structName(rightT)]
2827
)
@@ -31,19 +30,25 @@ func diffValues(left, right any, table Table) audit.Map {
3130
panic(fmt.Sprintf("dev error: type %q (type %T) attempted audit but not auditable", rightT.Name(), right))
3231
}
3332

34-
for i := 0; i < rightT.NumField(); i++ {
35-
if !rightT.Field(i).IsExported() {
36-
continue
37-
}
33+
// allFields contains all top level fields of the struct.
34+
allFields, err := flattenStructFields(leftV, rightV)
35+
if err != nil {
36+
// This should never happen. Only structs should be flattened. If an
37+
// error occurs, an unsupported or non-struct type was passed in.
38+
panic(fmt.Sprintf("dev error: failed to flatten struct fields: %v", err))
39+
}
3840

41+
for _, field := range allFields {
3942
var (
40-
leftF = leftV.Field(i)
41-
rightF = rightV.Field(i)
43+
leftF = field.LeftF
44+
rightF = field.RightF
4245

4346
leftI = leftF.Interface()
4447
rightI = rightF.Interface()
48+
)
4549

46-
diffName = rightT.Field(i).Tag.Get("json")
50+
var (
51+
diffName = field.FieldType.Tag.Get("json")
4752
)
4853

4954
atype, ok := diffKey[diffName]
@@ -145,6 +150,64 @@ func convertDiffType(left, right any) (newLeft, newRight any, changed bool) {
145150
}
146151
}
147152

153+
// fieldDiff has all the required information to return an audit diff for a
154+
// given field.
155+
type fieldDiff struct {
156+
FieldType reflect.StructField
157+
LeftF reflect.Value
158+
RightF reflect.Value
159+
}
160+
161+
// flattenStructFields will return all top level fields for a given structure.
162+
// Only anonymously embedded structs will be recursively flattened such that their
163+
// fields are returned as top level fields. Named nested structs will be returned
164+
// as a single field.
165+
// Conflicting field names need to be handled by the caller.
166+
func flattenStructFields(leftV, rightV reflect.Value) ([]fieldDiff, error) {
167+
// Dereference pointers if the field is a pointer field.
168+
if leftV.Kind() == reflect.Ptr {
169+
leftV = derefPointer(leftV)
170+
rightV = derefPointer(rightV)
171+
}
172+
173+
if leftV.Kind() != reflect.Struct {
174+
return nil, xerrors.Errorf("%q is not a struct, kind=%s", leftV.String(), leftV.Kind())
175+
}
176+
177+
var allFields []fieldDiff
178+
rightT := rightV.Type()
179+
180+
// Loop through all top level fields of the struct.
181+
for i := 0; i < rightT.NumField(); i++ {
182+
if !rightT.Field(i).IsExported() {
183+
continue
184+
}
185+
186+
var (
187+
leftF = leftV.Field(i)
188+
rightF = rightV.Field(i)
189+
)
190+
191+
if rightT.Field(i).Anonymous {
192+
// Anonymous fields are recursively flattened.
193+
anonFields, err := flattenStructFields(leftF, rightF)
194+
if err != nil {
195+
return nil, xerrors.Errorf("flatten anonymous field %q: %w", rightT.Field(i).Name, err)
196+
}
197+
allFields = append(allFields, anonFields...)
198+
continue
199+
}
200+
201+
// Single fields append as is.
202+
allFields = append(allFields, fieldDiff{
203+
LeftF: leftF,
204+
RightF: rightF,
205+
FieldType: rightT.Field(i),
206+
})
207+
}
208+
return allFields, nil
209+
}
210+
148211
// derefPointer deferences a reflect.Value that is a pointer to its underlying
149212
// value. It dereferences recursively until it finds a non-pointer value. If the
150213
// pointer is nil, it will be coerced to the zero value of the underlying type.

enterprise/audit/diff_internal_test.go

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ func Test_diffValues(t *testing.T) {
6666
})
6767
})
6868

69-
//nolint:revive
7069
t.Run("PointerField", func(t *testing.T) {
7170
t.Parallel()
7271

@@ -98,6 +97,83 @@ func Test_diffValues(t *testing.T) {
9897
})
9998
})
10099

100+
//nolint:revive
101+
t.Run("EmbeddedStruct", func(t *testing.T) {
102+
t.Parallel()
103+
104+
type Bar struct {
105+
Baz int `json:"baz"`
106+
Buzz string `json:"buzz"`
107+
}
108+
109+
type PtrBar struct {
110+
Qux string `json:"qux"`
111+
}
112+
113+
type foo struct {
114+
Bar
115+
*PtrBar
116+
TopLevel string `json:"top_level"`
117+
}
118+
119+
table := auditMap(map[any]map[string]Action{
120+
&foo{}: {
121+
"baz": ActionTrack,
122+
"buzz": ActionTrack,
123+
"qux": ActionTrack,
124+
"top_level": ActionTrack,
125+
},
126+
})
127+
128+
runDiffValuesTests(t, table, []diffTest{
129+
{
130+
name: "SingleFieldChange",
131+
left: foo{TopLevel: "top-before", Bar: Bar{Baz: 1, Buzz: "before"}, PtrBar: &PtrBar{Qux: "qux-before"}},
132+
right: foo{TopLevel: "top-after", Bar: Bar{Baz: 0, Buzz: "after"}, PtrBar: &PtrBar{Qux: "qux-after"}},
133+
exp: audit.Map{
134+
"baz": audit.OldNew{Old: 1, New: 0},
135+
"buzz": audit.OldNew{Old: "before", New: "after"},
136+
"qux": audit.OldNew{Old: "qux-before", New: "qux-after"},
137+
"top_level": audit.OldNew{Old: "top-before", New: "top-after"},
138+
},
139+
},
140+
{
141+
name: "Empty",
142+
left: foo{},
143+
right: foo{},
144+
exp: audit.Map{},
145+
},
146+
{
147+
name: "NoChange",
148+
left: foo{TopLevel: "top-before", Bar: Bar{Baz: 1, Buzz: "before"}, PtrBar: &PtrBar{Qux: "qux-before"}},
149+
right: foo{TopLevel: "top-before", Bar: Bar{Baz: 1, Buzz: "before"}, PtrBar: &PtrBar{Qux: "qux-before"}},
150+
exp: audit.Map{},
151+
},
152+
{
153+
name: "LeftEmpty",
154+
left: foo{},
155+
right: foo{TopLevel: "top-after", Bar: Bar{Baz: 1, Buzz: "after"}, PtrBar: &PtrBar{Qux: "qux-after"}},
156+
exp: audit.Map{
157+
"baz": audit.OldNew{Old: 0, New: 1},
158+
"buzz": audit.OldNew{Old: "", New: "after"},
159+
"qux": audit.OldNew{Old: "", New: "qux-after"},
160+
"top_level": audit.OldNew{Old: "", New: "top-after"},
161+
},
162+
},
163+
{
164+
name: "RightNil",
165+
left: foo{TopLevel: "top-before", Bar: Bar{Baz: 1, Buzz: "before"}, PtrBar: &PtrBar{Qux: "qux-before"}},
166+
right: foo{},
167+
exp: audit.Map{
168+
"baz": audit.OldNew{Old: 1, New: 0},
169+
"buzz": audit.OldNew{Old: "before", New: ""},
170+
"qux": audit.OldNew{Old: "qux-before", New: ""},
171+
"top_level": audit.OldNew{Old: "top-before", New: ""},
172+
},
173+
},
174+
})
175+
})
176+
101177
// We currently don't support nested structs.
102178
// t.Run("NestedStruct", func(t *testing.T) {
103179
// t.Parallel()

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy