Skip to content

Commit ae17468

Browse files
authored
Add optional batch size in option args for replay policy (#504)
* add extra batch size in option for replay * mark obs optional
1 parent d331b68 commit ae17468

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

gr00t/policy/replay_policy.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -300,16 +300,24 @@ def _get_action(
300300
"""Replay the next action chunk from the dataset.
301301
302302
Args:
303-
observation: Batched observation dictionary (used for validation, not inference)
304-
options: Optional parameters (currently unused)
303+
observation: Optional batched observation dictionary (used for validation, not inference)
304+
options: Optional parameters
305+
- batch_size: int - Batch size to use for the action chunk
305306
306307
Returns:
307308
Tuple of (actions_dict, info_dict) where actions_dict contains action chunks
308309
with shape (B, action_horizon, D) for each action key
309310
"""
310311
# Infer batch size from observation
311-
first_video_key = self.modality_configs["video"].modality_keys[0]
312-
batch_size = observation["video"][first_video_key].shape[0]
312+
if observation is not None:
313+
first_video_key = self.modality_configs["video"].modality_keys[0]
314+
batch_size = observation["video"][first_video_key].shape[0]
315+
# If batch size is not provided in observation, check if it's provided in options
316+
elif "batch_size" in options:
317+
batch_size = options["batch_size"]
318+
else:
319+
batch_size = 1
320+
print("No batch size provided, using default batch size of 1")
313321
# Note that this can differ form the execution horizon, as the policy can predict more steps than what's actually executed.
314322
action_horizon = (
315323
self.modality_configs["action"].delta_indices[-1]

0 commit comments

Comments
 (0)