Skip to content

Commit aa163bc

Browse files
committed
Missing stage
1 parent 086de48 commit aa163bc

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

verl/trainer/distillation/utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,9 @@ def prepare_student_distillation_inputs(
3838
logits: torch.Tensor, batch: TensorDict, cu_seqlens: torch.Tensor, config: Optional[DistillationConfig]
3939
) -> dict[str, torch.Tensor]:
4040
"""Prepare student distillation inputs."""
41-
stage = batch["stage"]
42-
if not is_distillation_enabled(config) or stage in {Stage.OLD_LOG_PROB, Stage.REF_LOG_PROB}:
41+
stage = batch.get("stage", None)
42+
if not is_distillation_enabled(config) or stage != Stage.ACTOR_UPDATE:
4343
return {}
44-
assert stage == Stage.ACTOR_UPDATE, f"Unexpected stage: {stage}"
4544
loss_config: DistillationLossConfig = config.distillation_loss
4645
distillation_settings: DistillationLossSettings = loss_config.loss_settings
4746
if distillation_settings.use_estimator:

0 commit comments

Comments
 (0)