Skip to content

Commit f7d43c2

Browse files
topdown: eliminate closure allocations in Set and virtual doc enumeration (#8242)
Replace closure allocations in evalTree.enumerate with method values for Set iteration and virtual document traversal. Set enumeration now uses Slice() instead of Iter(callback), and virtual doc enumeration uses enumerateNext helper instead of inline closures. Add BenchmarkEnumerateComprehensions to measure memory impact of closure optimizations in evalTree.enumerate with set/array comprehensions over large datasets. Signed-off-by: alex60217101990 <[email protected]>
1 parent b6bc37a commit f7d43c2

File tree

2 files changed

+268
-16
lines changed

2 files changed

+268
-16
lines changed

v1/topdown/enumerate_bench_test.go

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
// Copyright 2026 The OPA Authors. All rights reserved.
2+
// Use of this source code is governed by an Apache2
3+
// license that can be found in the LICENSE file.
4+
5+
package topdown
6+
7+
import (
8+
"context"
9+
"fmt"
10+
"math/rand"
11+
"testing"
12+
13+
"github.com/open-policy-agent/opa/v1/ast"
14+
"github.com/open-policy-agent/opa/v1/storage"
15+
inmem "github.com/open-policy-agent/opa/v1/storage/inmem/test"
16+
)
17+
18+
// BenchmarkEnumerateComprehensions benchmarks policy evaluation with
19+
// comprehensions over large datasets. This specifically targets the
20+
// enumerate optimization that eliminates closure allocations.
21+
func BenchmarkEnumerateComprehensions(b *testing.B) {
22+
sizes := []int{1000, 5000, 10000}
23+
24+
for _, size := range sizes {
25+
b.Run(fmt.Sprintf("size_%d", size), func(b *testing.B) {
26+
ctx := context.Background()
27+
28+
// Generate mock dataset with nested objects
29+
data := generateNestedDataset(size)
30+
store := inmem.NewFromObject(data)
31+
32+
// Policy with multiple comprehensions that exercise enumerate
33+
module := `package test
34+
35+
import rego.v1
36+
37+
# Set comprehension over users
38+
active_users contains user.id if {
39+
some user in data.users
40+
user.profile.active == true
41+
}
42+
43+
# Array comprehension with nested access
44+
premium_users := [user |
45+
some user in data.users
46+
user.profile.settings.subscription.tier == "premium"
47+
]
48+
49+
# Object comprehension with filtering
50+
users_by_age contains age_group if {
51+
age_group := "20-30"
52+
some u in data.users
53+
u.profile.age >= 20
54+
u.profile.age < 30
55+
}
56+
57+
users_by_age contains age_group if {
58+
age_group := "30-40"
59+
some u in data.users
60+
u.profile.age >= 30
61+
u.profile.age < 40
62+
}
63+
64+
# Nested comprehension
65+
high_value_users contains user.email if {
66+
some user in data.users
67+
user.profile.active == true
68+
count([p | some p in user.permissions; p.level > 5]) > 0
69+
}
70+
71+
# Random access pattern
72+
user_lookup[id] := user if {
73+
some user in data.users
74+
id := user.id
75+
}
76+
`
77+
78+
compiler := ast.MustCompileModules(map[string]string{
79+
"test.rego": module,
80+
})
81+
82+
// Query that exercises all comprehensions
83+
query := ast.MustParseBody(`
84+
data.test.active_users
85+
data.test.premium_users
86+
data.test.users_by_age
87+
data.test.high_value_users
88+
`)
89+
90+
b.ReportAllocs()
91+
b.ResetTimer()
92+
93+
for b.Loop() {
94+
err := storage.Txn(ctx, store, storage.TransactionParams{}, func(txn storage.Transaction) error {
95+
q := NewQuery(query).
96+
WithCompiler(compiler).
97+
WithStore(store).
98+
WithTransaction(txn)
99+
100+
_, err := q.Run(ctx)
101+
return err
102+
})
103+
104+
if err != nil {
105+
b.Fatal(err)
106+
}
107+
}
108+
})
109+
}
110+
}
111+
112+
// generateNestedDataset creates a dataset with nested objects of varying depth
113+
func generateNestedDataset(size int) map[string]any {
114+
users := make([]any, size)
115+
rng := rand.New(rand.NewSource(42)) // Fixed seed for reproducibility
116+
117+
tiers := []string{"free", "basic", "premium", "enterprise"}
118+
departments := []string{"engineering", "sales", "marketing", "support", "hr"}
119+
120+
for i := range size {
121+
// Random nested object with 3-5 levels of nesting
122+
permissions := make([]any, rng.Intn(10)+1)
123+
for j := 0; j < len(permissions); j++ {
124+
permissions[j] = map[string]any{
125+
"name": fmt.Sprintf("perm_%d", j),
126+
"level": rng.Intn(10),
127+
"scope": map[string]any{
128+
"resource": fmt.Sprintf("res_%d", rng.Intn(100)),
129+
"actions": []string{"read", "write", "delete"}[rng.Intn(3)],
130+
},
131+
}
132+
}
133+
134+
users[i] = map[string]any{
135+
"id": fmt.Sprintf("user_%d", i),
136+
"name": fmt.Sprintf("User %d", i),
137+
"email": fmt.Sprintf("user%[email protected]", i),
138+
"profile": map[string]any{
139+
"active": rng.Float64() > 0.3, // 70% active
140+
"age": 20 + rng.Intn(40), // Age 20-59
141+
"settings": map[string]any{
142+
"subscription": map[string]any{
143+
"tier": tiers[rng.Intn(len(tiers))],
144+
"start_date": "2024-01-01",
145+
"features": map[string]any{
146+
"api_access": rng.Float64() > 0.5,
147+
"custom_domain": rng.Float64() > 0.7,
148+
"priority_support": map[string]any{
149+
"enabled": rng.Float64() > 0.8,
150+
"level": rng.Intn(5) + 1,
151+
},
152+
},
153+
},
154+
"notifications": map[string]any{
155+
"email": rng.Float64() > 0.4,
156+
"sms": rng.Float64() > 0.8,
157+
},
158+
},
159+
"department": departments[rng.Intn(len(departments))],
160+
},
161+
"permissions": permissions,
162+
"metadata": map[string]any{
163+
"created_at": "2024-01-01T00:00:00Z",
164+
"updated_at": "2024-01-15T00:00:00Z",
165+
"tags": map[string]any{
166+
"region": []string{"us-west", "us-east", "eu-central"}[rng.Intn(3)],
167+
"environment": []string{"prod", "staging", "dev"}[rng.Intn(3)],
168+
"cost_center": fmt.Sprintf("CC%04d", rng.Intn(1000)),
169+
},
170+
},
171+
}
172+
}
173+
174+
return map[string]any{
175+
"users": users,
176+
}
177+
}
178+
179+
// BenchmarkEnumerateRandomAccess benchmarks random access patterns
180+
// that exercise virtual document enumeration
181+
func BenchmarkEnumerateRandomAccess(b *testing.B) {
182+
ctx := context.Background()
183+
184+
data := generateNestedDataset(10000)
185+
store := inmem.NewFromObject(data)
186+
187+
module := `package test
188+
189+
import rego.v1
190+
191+
# Virtual document with random access
192+
user_by_id[id] := user if {
193+
some user in data.users
194+
id := user.id
195+
}
196+
197+
# Nested virtual document access
198+
premium_by_dept[dept] := users if {
199+
dept := data.users[_].profile.department
200+
users := [u |
201+
some u in data.users
202+
u.profile.department == dept
203+
u.profile.settings.subscription.tier == "premium"
204+
]
205+
}
206+
`
207+
208+
compiler := ast.MustCompileModules(map[string]string{
209+
"test.rego": module,
210+
})
211+
212+
// Access random users
213+
query := ast.MustParseBody(`
214+
data.test.user_by_id["user_1234"]
215+
data.test.user_by_id["user_5678"]
216+
data.test.premium_by_dept.engineering
217+
`)
218+
219+
b.ReportAllocs()
220+
221+
for b.Loop() {
222+
err := storage.Txn(ctx, store, storage.TransactionParams{}, func(txn storage.Transaction) error {
223+
q := NewQuery(query).
224+
WithCompiler(compiler).
225+
WithStore(store).
226+
WithTransaction(txn)
227+
228+
_, err := q.Run(ctx)
229+
return err
230+
})
231+
232+
if err != nil {
233+
b.Fatal(err)
234+
}
235+
}
236+
}

v1/topdown/eval.go

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2486,6 +2486,20 @@ func (e evalTree) next(iter unifyIterator, plugged *ast.Term) error {
24862486
return cpy.eval(iter)
24872487
}
24882488

2489+
// enumerateNext is a helper to avoid closure allocation in enumerate loops.
2490+
// Method values don't allocate, unlike explicit closures.
2491+
// Using a pointer to evalTree avoids copying the 96-byte structure.
2492+
// Fields are ordered by size for optimal memory alignment (16 > 8 > 8 bytes).
2493+
type enumerateNext struct {
2494+
iter unifyIterator // 16 bytes (interface)
2495+
e *evalTree // 8 bytes (pointer)
2496+
key *ast.Term // 8 bytes (pointer)
2497+
}
2498+
2499+
func (en *enumerateNext) call() error {
2500+
return en.e.next(en.iter, en.key)
2501+
}
2502+
24892503
func (e evalTree) enumerate(iter unifyIterator) error {
24902504

24912505
if e.e.inliningControl.Disabled(e.plugged[:e.pos], true) {
@@ -2501,14 +2515,17 @@ func (e evalTree) enumerate(iter unifyIterator) error {
25012515
dc.deferred = nil
25022516
defer deecPool.Put(dc)
25032517

2518+
// Use method value to avoid closure allocation.
2519+
// Create once and reuse for both doc and virtual doc enumeration.
2520+
en := enumerateNext{iter: iter, e: &e, key: nil}
2521+
25042522
if doc != nil {
25052523
switch doc := doc.(type) {
25062524
case *ast.Array:
25072525
for i := range doc.Len() {
25082526
k := ast.InternedTerm(i)
2509-
err := e.e.biunify(k, e.ref[e.pos], e.bindings, e.bindings, func() error {
2510-
return e.next(iter, k)
2511-
})
2527+
en.key = k
2528+
err := e.e.biunify(k, e.ref[e.pos], e.bindings, e.bindings, en.call)
25122529

25132530
if err := dc.handleErr(err); err != nil {
25142531
return err
@@ -2517,21 +2534,20 @@ func (e evalTree) enumerate(iter unifyIterator) error {
25172534
case ast.Object:
25182535
ki := doc.KeysIterator()
25192536
for k, more := ki.Next(); more; k, more = ki.Next() {
2520-
err := e.e.biunify(k, e.ref[e.pos], e.bindings, e.bindings, func() error {
2521-
return e.next(iter, k)
2522-
})
2537+
en.key = k
2538+
err := e.e.biunify(k, e.ref[e.pos], e.bindings, e.bindings, en.call)
25232539
if err := dc.handleErr(err); err != nil {
25242540
return err
25252541
}
25262542
}
25272543
case ast.Set:
2528-
if err := doc.Iter(func(elem *ast.Term) error {
2529-
err := e.e.biunify(elem, e.ref[e.pos], e.bindings, e.bindings, func() error {
2530-
return e.next(iter, elem)
2531-
})
2532-
return dc.handleErr(err)
2533-
}); err != nil {
2534-
return err
2544+
// Use Slice() to avoid closure allocation in Iter()
2545+
for _, elem := range doc.Slice() {
2546+
en.key = elem
2547+
err := e.e.biunify(elem, e.ref[e.pos], e.bindings, e.bindings, en.call)
2548+
if err := dc.handleErr(err); err != nil {
2549+
return err
2550+
}
25352551
}
25362552
}
25372553
}
@@ -2544,11 +2560,11 @@ func (e evalTree) enumerate(iter unifyIterator) error {
25442560
return nil
25452561
}
25462562

2563+
// Reuse the same enumerateNext for virtual documents
25472564
for _, k := range e.node.Sorted {
25482565
key := ast.NewTerm(k)
2549-
if err := e.e.biunify(key, e.ref[e.pos], e.bindings, e.bindings, func() error {
2550-
return e.next(iter, key)
2551-
}); err != nil {
2566+
en.key = key
2567+
if err := e.e.biunify(key, e.ref[e.pos], e.bindings, e.bindings, en.call); err != nil {
25522568
return err
25532569
}
25542570
}

0 commit comments

Comments
 (0)