We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent a59c139 · commit 7e06ab2 · Copy full SHA for 7e06ab2
verl/trainer/distillation/utils.py
@@ -38,7 +38,7 @@ def prepare_student_distillation_inputs(
38
logits: torch.Tensor, batch: TensorDict, cu_seqlens: torch.Tensor, config: Optional[DistillationConfig]
39
) -> dict[str, torch.Tensor]:
40
"""Prepare student distillation inputs."""
41
- stage = batch.get("stage", None)
+ stage = None if "stage" not in batch else batch["stage"]
42
if not is_distillation_enabled(config) or stage != Stage.ACTOR_UPDATE:
43
return {}
44
loss_config: DistillationLossConfig = config.distillation_loss
0 commit comments