Skip to content

Commit 496b6e9

Browse files
p-hollPhilipp Holl
authored andcommitted
Persistent parallel properties
1 parent 118ae5b commit 496b6e9

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

phiml/parallel/_parallel.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def parallel_compute(instance, properties: Sequence, parallel_dims=batch,
7171
else:
7272
class_example = cls.__new__(cls)
7373
nodes: Dict[str, PGraphNode] = {}
74-
output_user = PGraphNode('<output>', EMPTY_SHAPE, parallel_dims, None, set(), None, False, [], 999999)
74+
output_user = PGraphNode('<output>', EMPTY_SHAPE, parallel_dims, None, False, set(), None, False, [], 999999)
7575
for p in properties:
7676
recursive_add_node(instance, cls, property_name(p), p, dims, nodes).users.append(output_user)
7777
# nodes = merge_duplicate_nodes(nodes.values())
@@ -125,7 +125,7 @@ def parallel_compute(instance, properties: Sequence, parallel_dims=batch,
125125

126126
def delete_intermediate_caches(instance, stages: list, stage_idx: int):
127127
for node in sum(stages[:stage_idx+1], []):
128-
if not node.has_users_after(stage_idx) and node.name in instance.__dict__:
128+
if not node.persistent and not node.has_users_after(stage_idx) and node.name in instance.__dict__:
129129
ML_LOGGER.debug(f"Deleting host cache of {node.name} as it's not needed after stage {stage_idx}")
130130
del instance.__dict__[node.name]
131131

@@ -210,6 +210,7 @@ def recursive_add_node(obj, cls, name: str, prop: Optional, dims: Shape, nodes:
210210
if name in nodes:
211211
return nodes[name]
212212
prop = getattr(cls, name) if prop is None else prop
213+
persistent = isinstance(prop, ParallelProperty) and prop.persistent
213214
dep_names = cache_deps(cls, prop)
214215
dependencies = [recursive_add_node(obj, cls, n, None, dims, nodes) for n in dep_names]
215216
# --- Determine shape ---
@@ -219,7 +220,7 @@ def recursive_add_node(obj, cls, name: str, prop: Optional, dims: Shape, nodes:
219220
out, trace = trace_cached_property(obj, cls, name, prop, dims, {d.name: expand_tracers(d.out, d.distributed) for d in dependencies})
220221
if isinstance(prop, ParallelProperty):
221222
if prop.requires is MIXED:
222-
last_node = split_mixed_prop(name, out, trace, dims, nodes)
223+
last_node = split_mixed_prop(name, out, trace, dims, persistent, nodes)
223224
if last_node is not NotImplemented:
224225
return last_node
225226
prop.requires = INFER
@@ -241,7 +242,7 @@ def recursive_add_node(obj, cls, name: str, prop: Optional, dims: Shape, nodes:
241242
# --- Add node ---
242243
distributed = dims - requires # ToDo does not take input shape into account
243244
field_dep_names = field_deps(cls, prop)
244-
node = nodes[name] = PGraphNode(name, out, distributed, None, field_dep_names, dependencies)
245+
node = nodes[name] = PGraphNode(name, out, distributed, None, persistent, field_dep_names, dependencies)
245246
for dep in node.dependencies:
246247
dep.users.append(node)
247248
return node
@@ -267,7 +268,7 @@ def trace_cached_property(obj, cls, p_name: str, prop: cached_property, distribu
267268
return out_tracer, trace
268269

269270

270-
def split_mixed_prop(name: str, out: Any, trace: Trace, parallel_dims: Shape, nodes: Dict[str, PGraphNode]):
271+
def split_mixed_prop(name: str, out: Any, trace: Trace, parallel_dims: Shape, persistent: bool, nodes: Dict[str, PGraphNode]):
271272
if len(trace.all_ops) == 0:
272273
return NotImplemented
273274
op_names = {op: f"{name}_{i}_{op.name}" for i, op in enumerate(trace.all_ops)}
@@ -308,11 +309,9 @@ def {op_name}(self):
308309
replacement[t] = op_trace.add_input(dep_expr[t][5:], t)
309310
single_op = op.replace_input_tracers(op_trace, replacement)
310311
program = TraceProgram(single_op)
311-
if op.name == 'assemble_tree' or (isinstance(out, Tracer) and op_name == name): # final property output
312-
out_tracers = out
313-
else: # intermediate result, return list of op out tensors
314-
out_tracers = op.outputs
315-
node = PGraphNode(op_name, out_tracers, distributed, program, ext_field_base, ext_deps + int_deps)
312+
is_output_node = op.name == 'assemble_tree' or (isinstance(out, Tracer) and op_name == name)
313+
out_tracers = out if is_output_node else op.outputs
314+
node = PGraphNode(op_name, out_tracers, distributed, program, persistent and is_output_node, ext_field_base, ext_deps + int_deps)
316315
nodes[op_name] = node
317316
for dep in node.dependencies:
318317
dep.users.append(node)
@@ -342,6 +341,7 @@ def __hash__(self):
342341
def parallel_property(func: Callable = None, /,
343342
requires: Union[DimFilter, object] = None,
344343
out: Any = INFER,
344+
persistent: bool = False,
345345
on_direct_eval='raise'):
346346
"""
347347
Similar to `@cached_property` but with additional controls over parallelization.
@@ -355,6 +355,8 @@ def parallel_property(func: Callable = None, /,
355355
out: Declare output shapes and dtypes in the same tree structure as the output of `func`.
356356
Placeholders for `shape` and `dtype` can be created using `shape * dtype`.
357357
`Shape` instances will be assumed to be of floating-point type.
358+
persistent: If `True` the output of this property will be available after `parallel_compute` even if it was not specified as a property to be computed,
359+
as long as its computation is necessary to compute any of the requested properties.
358360
on_direct_eval: What to do when the property is accessed normally (outside `parallel_compute`) before it has been computed.
359361
Option:
360362
@@ -365,18 +367,19 @@ def parallel_property(func: Callable = None, /,
365367
if out is not INFER:
366368
out = to_tracers(out)
367369
if func is None:
368-
return partial(parallel_property, requires=requires, out=out, on_direct_eval=on_direct_eval)
369-
return ParallelProperty(func, requires=requires, out=out, on_direct_eval=on_direct_eval)
370+
return partial(parallel_property, requires=requires, out=out, persistent=persistent, on_direct_eval=on_direct_eval)
371+
return ParallelProperty(func, requires, out, persistent, on_direct_eval)
370372

371373

372374
_NOT_CACHED = object()
373375

374376

375377
class ParallelProperty(cached_property):
376-
def __init__(self, func, requires: DimFilter, out: Any, on_direct_eval: str):
378+
def __init__(self, func, requires: DimFilter, out: Any, persistent: bool, on_direct_eval: str):
377379
super().__init__(func)
378380
self.requires = requires
379381
self.out = out
382+
self.persistent = persistent
380383
self.on_direct_eval = on_direct_eval
381384

382385
def __set_name__(self, owner, name):

phiml/parallel/_pgraph.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class PGraphNode:
1313
out: Union[Tracer, Any]
1414
distributed: Shape
1515
program: Optional[Any] # code as str or Tracer objects?
16+
persistent: bool
1617
field_dep_names: Set[str]
1718
dependencies: Sequence['PGraphNode'] = None
1819
done: bool = False

0 commit comments

Comments
 (0)