Skip to content

Commit 563b670

Browse files
committed
fix error2
1 parent 3a9bbe3 commit 563b670

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

.github/workflows/e2e_ppo_trainer_megatron_vllm_2_ascend.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ jobs:
213213
run: |
214214
pip install -r requirements-npu.txt
215215
pip install --no-deps -e .
216+
pip install trl
216217
- name: Check final pip list
217218
run: |
218219
pip list

tests/special_e2e/run_transferqueue.sh

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,15 @@ print(get_device_name())
167167
EOF
168168
)
169169

170+
extra_flash_args=()
171+
172+
if [ "$device_name" == "npu" ]; then
173+
echo "Detect NPU device, enabling FlashAttention..."
174+
extra_flash_args+=(
175+
++actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True
176+
)
177+
fi
178+
170179
# For Ascend NPU, please add:
171180
#++actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True \
172181
#++actor_rollout_ref.ref.megatron.override_transformer_config.use_flash_attn=True \

0 commit comments

Comments
 (0)