Skip to content

Commit e03e62c

Browse files
committed
Don't include dependencies of pre-computed properties in parallel graph
1 parent 6af98de commit e03e62c

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

phiml/parallel/_parallel.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def parallel_compute(instance, properties: Sequence, parallel_dims=batch,
6060
keep_intermediate: Whether the outputs of cached properties required to compute `properties` but not contained in `properties` should be kept in memory.
6161
If `False`, these values will not be cached on `instance` after this call.
6262
"""
63+
assert hasattr(instance, '__dict__'), f"parallel_compute requires instance to have __dict__. Slots are not supported."
6364
if memory_limit is not None:
6465
assert cache_dir is not None, "cache_dir must be specified if memory_limit is set"
6566
dims = shape(instance).only(parallel_dims)
@@ -215,6 +216,11 @@ def recursive_add_node(obj, cls, name: str, prop: Optional, dims: Shape, nodes:
215216
# --- Determine shape ---
216217
spec_out = prop.out if isinstance(prop, ParallelProperty) else INFER
217218
needs_trace = spec_out is INFER or (isinstance(prop, ParallelProperty) and prop.requires in {MIXED, INFER})
219+
already_computed = name in obj.__dict__
220+
if already_computed:
221+
precomputed_value = obj.__dict__[name]
222+
node = nodes[name] = PGraphNode(name, precomputed_value, EMPTY_SHAPE, None, True, set(), [], True, stage=-1)
223+
return node
218224
if needs_trace:
219225
out, trace = trace_cached_property(obj, cls, name, prop, dims, {d.name: expand_tracers(d.out, d.distributed) for d in dependencies})
220226
if isinstance(prop, ParallelProperty):

0 commit comments

Comments
 (0)