@@ -300,16 +300,24 @@ def _get_action(
300300 """Replay the next action chunk from the dataset.
301301
302302 Args:
303- observation: Batched observation dictionary (used for validation, not inference)
304- options: Optional parameters (currently unused)
303+ observation: Optional batched observation dictionary (used for validation, not inference)
304+ options: Optional parameters
305+ - batch_size: int - Batch size to use for the action chunk
305306
306307 Returns:
307308 Tuple of (actions_dict, info_dict) where actions_dict contains action chunks
308309 with shape (B, action_horizon, D) for each action key
309310 """
         # Infer batch size from observation
-        first_video_key = self.modality_configs["video"].modality_keys[0]
-        batch_size = observation["video"][first_video_key].shape[0]
+        if observation is not None:
+            first_video_key = self.modality_configs["video"].modality_keys[0]
+            batch_size = observation["video"][first_video_key].shape[0]
+        # If no observation is given, fall back to an explicit batch size from options
+        elif options is not None and "batch_size" in options:
+            batch_size = options["batch_size"]
+        else:
+            batch_size = 1
+            print("No batch size provided, using default batch size of 1")
         # Note that this can differ from the execution horizon, as the policy can predict more steps than what's actually executed.
         action_horizon = (
             self.modality_configs["action"].delta_indices[-1]
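
For context, a minimal sketch of how the new fallback resolves the batch size. The `policy` object, the `obs_batch` variable, and the keyword-argument call style are assumptions for illustration; only the observation/options precedence comes from the diff above.

    # Hypothetical usage sketch; `policy` stands in for whatever object defines _get_action.

    # 1) Batch size inferred from the first video tensor in the observation.
    actions, info = policy._get_action(observation=obs_batch, options=None)

    # 2) No observation: batch size taken from options.
    actions, info = policy._get_action(observation=None, options={"batch_size": 4})

    # 3) Neither provided: defaults to batch_size = 1 and prints a notice.
    actions, info = policy._get_action(observation=None, options=None)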