Skip to content

Commit 7e06ab2

Browse files
committed
Fix stage
1 parent a59c139 commit 7e06ab2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

verl/trainer/distillation/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def prepare_student_distillation_inputs(
3838
logits: torch.Tensor, batch: TensorDict, cu_seqlens: torch.Tensor, config: Optional[DistillationConfig]
3939
) -> dict[str, torch.Tensor]:
4040
"""Prepare student distillation inputs."""
41-
stage = batch.get("stage", None)
41+
stage = None if "stage" not in batch else batch["stage"]
4242
if not is_distillation_enabled(config) or stage != Stage.ACTOR_UPDATE:
4343
return {}
4444
loss_config: DistillationLossConfig = config.distillation_loss

0 commit comments

Comments
 (0)