
Commit 7779db1

Commit message: stash
Parent: a59c139

2 files changed: 10 additions, 8 deletions


examples/on_policy_distillation_trainer/run_qwen_gsmk8k.sh

Lines changed: 9 additions & 8 deletions
@@ -5,6 +5,7 @@ export PATH=$CONDA_PREFIX/bin:$PATH
 export NCCL_P2P_DISABLE=1
 export CUDA_DEVICE_ORDER=PCI_BUS_ID
 export CUDA_VISIBLE_DEVICES=5,6,7,8
+# export CUDA_VISIBLE_DEVICES=7,8
 export DATA_PATH=$PWD/../verlData
 export HF_HOME=$DATA_PATH
 export VLLM_CACHE_DIR=$DATA_PATH/vllm_cache
@@ -17,14 +18,14 @@ ROLLOUT_NAME="vllm" # sglang or vllm
 
 FAMILY="Qwen"
 STUDENT_MODEL=Qwen2.5-0.5B
-TEACHER_MODEL=Qwen2.5-3B-Instruct
+TEACHER_MODEL=Qwen2.5-0.5B-Instruct
 
 USE_POLICY_GRADIENT=False
-DISTILLATION_LOSS_MODE="k3"
+# DISTILLATION_LOSS_MODE="k3"
 DISTILLATION_LOSS_MODE="forward_kl_topk"
 
-DISTILLATION_LOSS_MODE="k1"
-USE_POLICY_GRADIENT=True
+# USE_POLICY_GRADIENT=True
+# DISTILLATION_LOSS_MODE="k1"
 
 DISTILLATION_LOSS_MAX_CLAMP=10.0
 DISTILLATION_LOG_PROB_MIN_CLAMP=null
@@ -34,7 +35,7 @@ EXP_NAME="${FAMILY}/student-${STUDENT_MODEL}/teacher-${TEACHER_MODEL}/loss-${DIS
 
 MAX_PROMPT=256
 MAX_RESPONSE_LENGTH=512
-TRAIN_PROMPT_BSZ=128
+TRAIN_PROMPT_BSZ=8
 STUDENT_MICRO_BATCH_SIZE_PER_GPU=2
 STUDENT_MAX_TOKEN_LEN_PER_GPU=$(( STUDENT_MICRO_BATCH_SIZE_PER_GPU * (MAX_PROMPT + MAX_RESPONSE_LENGTH) ))
 USE_DYNAMIC_BSZ=False
@@ -44,7 +45,7 @@ STUDENT_WORLD_SIZE=2
 TEACHER_RESOURCE_POOL=True
 TEACHER_WORLD_SIZE=2
 
-ENFORCE_EAGER=False # true for faster debugging
+ENFORCE_EAGER=True # true for faster debugging
 
 ############################ Paths ############################
 
@@ -77,14 +78,14 @@ MODEL=(
 
 DISTILLATION=(
 distillation.enabled=True
-distillation.num_workers=8
+distillation.num_workers=1
 distillation.teacher_model.enable_resource_pool=$TEACHER_RESOURCE_POOL
 distillation.teacher_model.n_gpus_per_node=$TEACHER_WORLD_SIZE
 distillation.teacher_model.nnodes=1
 distillation.teacher_model.model_path="${FAMILY}/${TEACHER_MODEL}"
 distillation.teacher_model.inference.tensor_model_parallel_size=1
 distillation.teacher_model.inference.name=$ROLLOUT_NAME
-distillation.teacher_model.inference.gpu_memory_utilization=0.3
+distillation.teacher_model.inference.gpu_memory_utilization=0.6
 distillation.teacher_model.inference.enforce_eager=$ENFORCE_EAGER
 distillation.distillation_loss.loss_mode=$DISTILLATION_LOSS_MODE
 distillation.distillation_loss.topk=64
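For reference, the unchanged STUDENT_MAX_TOKEN_LEN_PER_GPU arithmetic above works out to 2 × (256 + 512) = 1536 tokens per GPU. The script also switches the loss to DISTILLATION_LOSS_MODE="forward_kl_topk" with distillation.distillation_loss.topk=64, i.e. a forward KL computed only over the teacher's top-k tokens. Below is a minimal PyTorch sketch of that idea, not verl's actual implementation; the tensor names (student_logits, teacher_topk_log_probs, teacher_topk_ids) are hypothetical:

import torch
import torch.nn.functional as F

def forward_kl_topk_loss(
    student_logits: torch.Tensor,          # [batch, seq, vocab]
    teacher_topk_log_probs: torch.Tensor,  # [batch, seq, k] teacher log-probs at its top-k
    teacher_topk_ids: torch.Tensor,        # [batch, seq, k] matching vocabulary ids
) -> torch.Tensor:
    # Student log-probs over the full vocabulary, gathered at the
    # teacher's top-k token ids.
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    student_topk = student_log_probs.gather(-1, teacher_topk_ids)
    # Forward KL restricted to the top-k support:
    # sum_k p_T(v) * (log p_T(v) - log p_S(v)), averaged over tokens.
    teacher_probs = teacher_topk_log_probs.exp()
    kl_per_token = (teacher_probs * (teacher_topk_log_probs - student_topk)).sum(-1)
    return kl_per_token.mean()

The script's DISTILLATION_LOSS_MAX_CLAMP=10.0 suggests per-token values are additionally clamped before averaging, though that step is not shown here.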

verl/trainer/distillation/utils.py

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ def prepare_student_distillation_inputs(
 ) -> dict[str, torch.Tensor]:
     """Prepare student distillation inputs."""
     stage = batch.get("stage", None)
+    breakpoint()
     if not is_distillation_enabled(config) or stage != Stage.ACTOR_UPDATE:
         return {}
     loss_config: DistillationLossConfig = config.distillation_loss
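The added breakpoint() is Python's built-in debugger hook (PEP 553, Python 3.7+): when reached it drops into pdb, which will stall any non-interactive training run. It can be silenced without editing the file:

# breakpoint() consults the PYTHONBREAKPOINT environment variable at each
# call (PEP 553); setting it to "0" turns every breakpoint() into a no-op:
#
#   PYTHONBREAKPOINT=0 python my_train_script.py   # hypothetical entry point
#
# or, equivalently, from inside the process before the call is reached:
import os
os.environ["PYTHONBREAKPOINT"] = "0"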
