@@ -71,7 +71,7 @@ def parallel_compute(instance, properties: Sequence, parallel_dims=batch,
7171 else :
7272 class_example = cls .__new__ (cls )
7373 nodes : Dict [str , PGraphNode ] = {}
74- output_user = PGraphNode ('<output>' , EMPTY_SHAPE , parallel_dims , None , set (), None , False , [], 999999 )
74+ output_user = PGraphNode ('<output>' , EMPTY_SHAPE , parallel_dims , None , False , set (), None , False , [], 999999 )
7575 for p in properties :
7676 recursive_add_node (instance , cls , property_name (p ), p , dims , nodes ).users .append (output_user )
7777 # nodes = merge_duplicate_nodes(nodes.values())
@@ -125,7 +125,7 @@ def parallel_compute(instance, properties: Sequence, parallel_dims=batch,
125125
126126def delete_intermediate_caches (instance , stages : list , stage_idx : int ):
127127 for node in sum (stages [:stage_idx + 1 ], []):
128- if not node .has_users_after (stage_idx ) and node .name in instance .__dict__ :
128+ if not node .persistent and not node . has_users_after (stage_idx ) and node .name in instance .__dict__ :
129129 ML_LOGGER .debug (f"Deleting host cache of { node .name } as it's not needed after stage { stage_idx } " )
130130 del instance .__dict__ [node .name ]
131131
@@ -210,6 +210,7 @@ def recursive_add_node(obj, cls, name: str, prop: Optional, dims: Shape, nodes:
210210 if name in nodes :
211211 return nodes [name ]
212212 prop = getattr (cls , name ) if prop is None else prop
213+ persistent = isinstance (prop , ParallelProperty ) and prop .persistent
213214 dep_names = cache_deps (cls , prop )
214215 dependencies = [recursive_add_node (obj , cls , n , None , dims , nodes ) for n in dep_names ]
215216 # --- Determine shape ---
@@ -219,7 +220,7 @@ def recursive_add_node(obj, cls, name: str, prop: Optional, dims: Shape, nodes:
219220 out , trace = trace_cached_property (obj , cls , name , prop , dims , {d .name : expand_tracers (d .out , d .distributed ) for d in dependencies })
220221 if isinstance (prop , ParallelProperty ):
221222 if prop .requires is MIXED :
222- last_node = split_mixed_prop (name , out , trace , dims , nodes )
223+ last_node = split_mixed_prop (name , out , trace , dims , persistent , nodes )
223224 if last_node is not NotImplemented :
224225 return last_node
225226 prop .requires = INFER
@@ -241,7 +242,7 @@ def recursive_add_node(obj, cls, name: str, prop: Optional, dims: Shape, nodes:
241242 # --- Add node ---
242243 distributed = dims - requires # ToDo does not take input shape into account
243244 field_dep_names = field_deps (cls , prop )
244- node = nodes [name ] = PGraphNode (name , out , distributed , None , field_dep_names , dependencies )
245+ node = nodes [name ] = PGraphNode (name , out , distributed , None , persistent , field_dep_names , dependencies )
245246 for dep in node .dependencies :
246247 dep .users .append (node )
247248 return node
@@ -267,7 +268,7 @@ def trace_cached_property(obj, cls, p_name: str, prop: cached_property, distribu
267268 return out_tracer , trace
268269
269270
270- def split_mixed_prop (name : str , out : Any , trace : Trace , parallel_dims : Shape , nodes : Dict [str , PGraphNode ]):
271+ def split_mixed_prop (name : str , out : Any , trace : Trace , parallel_dims : Shape , persistent : bool , nodes : Dict [str , PGraphNode ]):
271272 if len (trace .all_ops ) == 0 :
272273 return NotImplemented
273274 op_names = {op : f"{ name } _{ i } _{ op .name } " for i , op in enumerate (trace .all_ops )}
@@ -308,11 +309,9 @@ def {op_name}(self):
308309 replacement [t ] = op_trace .add_input (dep_expr [t ][5 :], t )
309310 single_op = op .replace_input_tracers (op_trace , replacement )
310311 program = TraceProgram (single_op )
311- if op .name == 'assemble_tree' or (isinstance (out , Tracer ) and op_name == name ): # final property output
312- out_tracers = out
313- else : # intermediate result, return list of op out tensors
314- out_tracers = op .outputs
315- node = PGraphNode (op_name , out_tracers , distributed , program , ext_field_base , ext_deps + int_deps )
312+ is_output_node = op .name == 'assemble_tree' or (isinstance (out , Tracer ) and op_name == name )
313+ out_tracers = out if is_output_node else op .outputs
314+ node = PGraphNode (op_name , out_tracers , distributed , program , persistent and is_output_node , ext_field_base , ext_deps + int_deps )
316315 nodes [op_name ] = node
317316 for dep in node .dependencies :
318317 dep .users .append (node )
@@ -342,6 +341,7 @@ def __hash__(self):
342341def parallel_property (func : Callable = None , / ,
343342 requires : Union [DimFilter , object ] = None ,
344343 out : Any = INFER ,
344+ persistent : bool = False ,
345345 on_direct_eval = 'raise' ):
346346 """
347347 Similar to `@cached_property` but with additional controls over parallelization.
@@ -355,6 +355,8 @@ def parallel_property(func: Callable = None, /,
355355 out: Declare output shapes and dtypes in the same tree structure as the output of `func`.
356356 Placeholders for `shape` and `dtype` can be created using `shape * dtype`.
357357 `Shape` instances will be assumed to be of floating-point type.
358+ persistent: If `True` the output of this property will be available after `parallel_compute` even if it was not specified as a property to be computed,
359+ as long as its computation is necessary to compute any of the requested properties.
358360 on_direct_eval: What to do when the property is accessed normally (outside `parallel_compute`) before it has been computed.
359361 Option:
360362
@@ -365,18 +367,19 @@ def parallel_property(func: Callable = None, /,
365367 if out is not INFER :
366368 out = to_tracers (out )
367369 if func is None :
368- return partial (parallel_property , requires = requires , out = out , on_direct_eval = on_direct_eval )
369- return ParallelProperty (func , requires = requires , out = out , on_direct_eval = on_direct_eval )
370+ return partial (parallel_property , requires = requires , out = out , persistent = persistent , on_direct_eval = on_direct_eval )
371+ return ParallelProperty (func , requires , out , persistent , on_direct_eval )
370372
371373
372374_NOT_CACHED = object ()
373375
374376
375377class ParallelProperty (cached_property ):
376- def __init__ (self , func , requires : DimFilter , out : Any , on_direct_eval : str ):
378+ def __init__ (self , func , requires : DimFilter , out : Any , persistent : bool , on_direct_eval : str ):
377379 super ().__init__ (func )
378380 self .requires = requires
379381 self .out = out
382+ self .persistent = persistent
380383 self .on_direct_eval = on_direct_eval
381384
382385 def __set_name__ (self , owner , name ):
0 commit comments