diff --git a/docs/perf/nsight_profiling.md b/docs/perf/nsight_profiling.md index 490de5e7e4f..be3befffc6f 100644 --- a/docs/perf/nsight_profiling.md +++ b/docs/perf/nsight_profiling.md @@ -27,14 +27,22 @@ Nsys options in controller nodes and worker nodes are configured in `global_prof * **`global_profiler.global_tool_config.nsys.controller_nsight_options`**. This config group is for the single controller. All fields in this config group will be just sent to Nsight Systems when Ray starts the controller process. `ppo_trainer.yaml` provides a workable example. Users can reference [Nsight Systems manual](https://docs.nvidia.com/nsight-systems/UserGuide/index.html) and [Ray user guide](https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html) for more details. * **`global_profiler.global_tool_config.nsys.worker_nsight_options`**. This config group is for the worker processes. Similarly all fields in this config group will be just sent to Nsight Systems when Ray starts the controller process. Capture range is used to control the profiler when to start and stop. So `capture-range: "cudaProfilerApi"` is fixed and does not change it. Users can change `capture-range-end` with some accurate calculation or just leave it `null`. -### Worker process profiling +### Actor_rollout_ref (SPMD) Worker process profiling Verl manages mulitiple RL roles, _Actor_, _Ref_, _Rollout_, _Critic_, _Reward_, which are implemented in different Worker classes. And these workers can be combined into one Ray Actor, running in a process group. Each RL role has its own profiling config group, `profiler`, which consists of three fields: -* **`all_ranks` and `ranks`**. When `all_ranks` is set `True` then all ranks will be profiled; when set `False`, `ranks` will be profiled. By default, verl profiles the whole training process in a series ` worker_process_..nsys-rep` files for each process rank. PID is the process ID; RID is the capture range ID. +* **`all_ranks` and `ranks`**. When `all_ranks` is set `True` then all ranks will be profiled; when set `False`, `ranks` will be profiled. By default, verl profiles the whole training process in a series `worker_process_..nsys-rep` files for each process rank. PID is the process ID; RID is the capture range ID. * **`discrete`**. When set `False`, all the roles actions in one training step will be dumped in one database. When set `True`, the actions annotated by `DistProfiler.annotate` will be dumped into a discrete database. In this case, each role's action occupies one ``. * **Verl collocate mode**. Verl can combine two Worker sub classes to one Worker Actor. In this case, the user should take care that the combined Workers have consistent `discrete`. The Nsight Systems profiler uses a `torch.cuda.profiler.start()` and `stop()` pair to dump a `` database anyway. +### Rollout server worker process profiling +Verl now use rollout server mode. AgentLoopManger mangages a list of rollout replicas; one repica manages a list of servers (in most cases, list length is 1); one server manages a list ranks of workers. +In current config interface, `actor_rollout_ref.rollout.profiler` is a standalone config, and not is shared with Actor/Ref. +`all_replicas=True` means all replicas are profiled, otherwise `replicas=[...]` are profiled. +`all_ranks=True` means all ranks are profiled, otherwise `ranks=[...]` are profiled. +Since a replica usually has one server, there is no control knobs for servers in a replica. +An example is here `verl/examples/grpo_trainer/run_qwen2-7b_math_trtllm_nsys.sh` + ### where to find the profiling data By default the `*.nsys-rep` files are saved in the directory `/tmp/ray/session_latest/logs/nsight/` at each node. According to the Ray manual, this default directory is not changeable. ["however, Ray preserves the `--output` option of the default config"](https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html). @@ -64,6 +72,13 @@ To enable profiling for specific components and steps, modify your ppo_trainer.y enable: True all_ranks: True # rollout & ref follow actor settings + rollout: + profiler: + enable: True + all_replicas: True + #replicas: [0,2] + all_ranks:False + ranks: [0,2] critic: profiler: enable: True diff --git a/examples/grpo_trainer/run_qwen2-7b_math_trtllm_nsys.sh b/examples/grpo_trainer/run_qwen2-7b_math_trtllm_nsys.sh new file mode 100644 index 00000000000..d5b598cb012 --- /dev/null +++ b/examples/grpo_trainer/run_qwen2-7b_math_trtllm_nsys.sh @@ -0,0 +1,104 @@ +set -x + +# Clean all slurm / MPI / PMIx env to avoid pmix mismatch error +for v in $(env | awk -F= '/^(PMI|PMIX|MPI|OMPI|SLURM)_/{print $1}'); do + unset "$v" +done + +export RAY_DEDUP_LOGS=0 + +# ----- +# Config +# ----- +TP=${1:-4} +PROJECT_NAME=${PROJECT_NAME:-"verl_grpo_example_gsm8k_math"} +EXP_NAME=trtllm-qwen2-7b-tp${TP}-8gpus${EXP_NAME_SUFFIX:+"-"}${EXP_NAME_SUFFIX} + +if [ $TP -eq 4 ]; then + MAX_BATCH_SIZE=1024 +else + MAX_BATCH_SIZE=384 +fi + +# ----- +# Data +# ----- +DATADIR=${DATADIR:-$PWD/data} +MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen2-7B-Instruct"} + +GSM8K_TRAIN_PATH=${DATADIR}/gsm8k/train.parquet +GSM8K_TEST_PATH=${DATADIR}/gsm8k/test.parquet +MATH_TRAIN_PATH=${DATADIR}/math/train.parquet +MATH_TEST_PATH=${DATADIR}/math/test.parquet + +TRAIN_FILES="['$GSM8K_TRAIN_PATH', '$MATH_TRAIN_PATH']" +TEST_FILES="['$GSM8K_TEST_PATH', '$MATH_TEST_PATH']" + +# ----- +# Launch +# ----- +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + algorithm.rollout_correction.rollout_is_threshold=2.0 \ + data.train_files="$TRAIN_FILES" \ + data.val_files="$TEST_FILES" \ + data.train_batch_size=1024 \ + data.max_prompt_length=2048 \ + data.max_response_length=1024 \ + data.return_raw_chat=True \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.hybrid_engine=True \ + actor_rollout_ref.model.path=${MODEL_PATH} \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${TP} \ + actor_rollout_ref.rollout.name=trtllm \ + actor_rollout_ref.rollout.mode="async" \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.max_num_seqs=${MAX_BATCH_SIZE} \ + actor_rollout_ref.rollout.max_num_batched_tokens=32768 \ + +actor_rollout_ref.rollout.engine_kwargs.trtllm.batch_wait_timeout_iters=32 \ + +actor_rollout_ref.rollout.engine_kwargs.trtllm.batch_wait_max_tokens_ratio=0.5 \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.rollout.checkpoint_engine.update_weights_bucket_megabytes=4096 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console"]' \ + trainer.project_name="${PROJECT_NAME}" \ + trainer.experiment_name=${EXP_NAME} \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=2 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.resume_mode=disable \ + trainer.total_epochs=15 \ + trainer.val_before_train=False \ + trainer.total_training_steps=6 \ + global_profiler.tool=nsys \ + global_profiler.steps='[2,3,5]' \ + global_profiler.profile_continuous_steps=True \ + global_profiler.global_tool_config.nsys.discrete=False \ + global_profiler.global_tool_config.nsys.worker_nsight_options.capture-range-end='repeat-shutdown:2' \ + actor_rollout_ref.actor.profiler.enable=True \ + actor_rollout_ref.actor.profiler.all_ranks=False \ + actor_rollout_ref.actor.profiler.ranks=[0,2] \ + actor_rollout_ref.rollout.profiler.enable=True \ + actor_rollout_ref.rollout.profiler.all_replicas=False \ + actor_rollout_ref.rollout.profiler.replicas=[0,2] \ + actor_rollout_ref.rollout.profiler.all_ranks=False \ + actor_rollout_ref.rollout.profiler.ranks=[0,2] \ + "${@:2}" diff --git a/recipe b/recipe index 3490a22a0a3..21892b92769 160000 --- a/recipe +++ b/recipe @@ -1 +1 @@ -Subproject commit 3490a22a0a3adeb7e4787fe70b1060b642efbae4 +Subproject commit 21892b9276936efab5375c3f6b8415e472ef7118 diff --git a/tests/workers/rollout/rollout_trtllm/test_adapter.py b/tests/workers/rollout/rollout_trtllm/test_adapter.py index 004df83d0eb..a0dbc9a6006 100644 --- a/tests/workers/rollout/rollout_trtllm/test_adapter.py +++ b/tests/workers/rollout/rollout_trtllm/test_adapter.py @@ -13,7 +13,6 @@ # limitations under the License. import asyncio import os -import subprocess from unittest.mock import AsyncMock, Mock, patch import aiohttp @@ -142,7 +141,17 @@ def test_init_without_device_mesh(self): try: os.environ.setdefault("TLLM_RAY_FORCE_LOCAL_CLUSTER", "1") - ray.init(address="local", ignore_reinit_error=True, include_dashboard=False) + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + }, + ignore_reinit_error=True, + ) config_dir = os.path.abspath("verl/verl/trainer/config") if not os.path.exists(config_dir): @@ -187,5 +196,5 @@ def test_init_without_device_mesh(self): os.environ.pop("RANK", None) else: os.environ["RANK"] = prev_rank + print("\nShutting down Ray...") ray.shutdown() - subprocess.run(["ray", "stop"], capture_output=True) diff --git a/tests/workers/rollout/rollout_trtllm/test_async_server.py b/tests/workers/rollout/rollout_trtllm/test_async_server.py index 3224a8ce13f..068f88cc0aa 100644 --- a/tests/workers/rollout/rollout_trtllm/test_async_server.py +++ b/tests/workers/rollout/rollout_trtllm/test_async_server.py @@ -13,7 +13,6 @@ # limitations under the License. import os -import subprocess import time from unittest.mock import MagicMock, patch @@ -170,7 +169,17 @@ def test_async_generate(self): """Test TRT-LLM generate method with real model.""" try: os.environ.setdefault("TLLM_RAY_FORCE_LOCAL_CLUSTER", "1") - ray.init(address="local", ignore_reinit_error=True, include_dashboard=False) + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + }, + ignore_reinit_error=True, + ) rollout_config, model_config = self._build_rollout_config(response_length=50) @@ -209,14 +218,24 @@ def test_async_generate(self): print(f"Log probs: {result.log_probs[:10]}...") # Print first 10 log probs finally: + print("\nShutting down Ray...") ray.shutdown() - subprocess.run(["ray", "stop"], capture_output=True) def test_async_memory_management(self): """Test TRT-LLM async memory management (sleep) reduces memory usage.""" try: os.environ.setdefault("TLLM_RAY_FORCE_LOCAL_CLUSTER", "1") - ray.init(address="local", ignore_reinit_error=True, include_dashboard=False) + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + }, + ignore_reinit_error=True, + ) rollout_config, model_config = self._build_rollout_config(free_cache_engine=True) @@ -271,5 +290,5 @@ def get_gpu_memory_mb_for_device(device_uuid: str) -> float: ) finally: + print("\nShutting down Ray...") ray.shutdown() - subprocess.run(["ray", "stop"], capture_output=True) diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py index b591d093696..6cb9f4d1755 100644 --- a/verl/experimental/agent_loop/agent_loop.py +++ b/verl/experimental/agent_loop/agent_loop.py @@ -866,7 +866,6 @@ def __init__( self.config = config self.worker_group = worker_group self.reward_loop_worker_handles = reward_loop_worker_handles - # for recipe to change if not hasattr(self, "rollout_replica_class"): self.rollout_replica_class = get_rollout_replica_class(self.config.actor_rollout_ref.rollout.name) @@ -900,6 +899,15 @@ def _initialize_llm_servers(self, rollout_resource_pool: RayResourcePool): ) for replica_rank in range(num_replicas) ] + profiling_all_replicas = OmegaConf.select(self.config.actor_rollout_ref.rollout.profiler, "all_replicas") + profiling_replica_ranks = OmegaConf.select(self.config.actor_rollout_ref.rollout.profiler, "replicas") + self.profiling_replicas = ( + self.rollout_replicas + if profiling_all_replicas + else [self.rollout_replicas[replica_rank] for replica_rank in profiling_replica_ranks] + if profiling_replica_ranks + else [] + ) if self.worker_group and rollout_config.name != "trtllm": self._run_all([server.init_hybrid(self.worker_group) for server in self.rollout_replicas]) @@ -1000,14 +1008,18 @@ def clear_kv_cache(self): def start_profile(self, **kwargs): """Start profiling on all rollout replicas.""" - self._run_all([replica.start_profile(**kwargs) for replica in self.rollout_replicas]) + self._run_all([replica.start_profile(**kwargs) for replica in self.profiling_replicas]) def stop_profile(self): """Stop profiling on all rollout replicas.""" - self._run_all([replica.stop_profile() for replica in self.rollout_replicas]) + self._run_all([replica.stop_profile() for replica in self.profiling_replicas]) def _run_all(self, tasks: list[asyncio.Task]): async def run_all(): await asyncio.gather(*tasks) asyncio.run(run_all()) + + def shutdown(self): + """Shutdown all rollout replicas.""" + self._run_all([replica.shutdown() for replica in self.rollout_replicas]) diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index 2f6ee47064f..39282780048 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -457,8 +457,6 @@ def __init__( self.profile_steps = kwargs.get("profile_steps", None) self.worker_nsight_options = kwargs.get("worker_nsight_options", None) self.customized_worker_env = kwargs.get("worker_env", {}) - if self.worker_nsight_options is not None and self.worker_nsight_options["capture-range-end"] is None: - self.worker_nsight_options["capture-range-end"] = f"repeat-shutdown:{6 * len(self.profile_steps)}" if worker_names is not None and (not self.fused_worker_used): assert self._is_init_with_detached_workers diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index ea60c881619..f503dbe63fb 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -113,7 +113,7 @@ actor_rollout_ref: _target_: verl.utils.profiler.ProfilerConfig tool: ${oc.select:global_profiler.tool,null} enable: false - all_ranks: false + all_ranks: true ranks: [] save_path: ${oc.select:global_profiler.save_path,null} tool_config: @@ -300,7 +300,10 @@ actor_rollout_ref: profiler: _target_: verl.utils.profiler.ProfilerConfig tool: ${oc.select:global_profiler.tool,null} + global_tool_config: ${oc.select:global_profiler.global_tool_config,null} enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false} + all_replicas: true + replicas: [] all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false} ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]} save_path: ${oc.select:global_profiler.save_path,null} @@ -724,6 +727,7 @@ global_profiler: save_path: outputs/profile global_tool_config: nsys: + _target_: verl.utils.profiler.config.NsightToolConfig discrete: false controller_nsight_options: trace: cuda,nvtx,cublas,ucx @@ -734,7 +738,7 @@ global_profiler: cuda-memory-usage: 'true' cuda-graph-trace: graph capture-range: cudaProfilerApi - capture-range-end: null + capture-range-end: repeat-shutdown:6 kill: none torch_memory: trace_alloc_max_entries: 100000 diff --git a/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml b/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml index b9a8b3aaf84..64ed79370da 100644 --- a/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml @@ -95,7 +95,7 @@ actor_rollout_ref: _target_: verl.utils.profiler.ProfilerConfig tool: ${oc.select:global_profiler.tool,null} enable: false - all_ranks: false + all_ranks: true ranks: [] save_path: ${oc.select:global_profiler.save_path,null} tool_config: @@ -289,7 +289,10 @@ actor_rollout_ref: profiler: _target_: verl.utils.profiler.ProfilerConfig tool: ${oc.select:global_profiler.tool,null} + global_tool_config: ${oc.select:global_profiler.global_tool_config,null} enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false} + all_replicas: true + replicas: [] all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false} ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]} save_path: ${oc.select:global_profiler.save_path,null} @@ -657,7 +660,7 @@ global_profiler: cuda-memory-usage: 'true' cuda-graph-trace: graph capture-range: cudaProfilerApi - capture-range-end: null + capture-range-end: repeat-shutdown:6 kill: none torch_memory: trace_alloc_max_entries: 100000 diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index 6b97103ae9f..143ee39fd45 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -94,7 +94,7 @@ actor_rollout_ref: _target_: verl.utils.profiler.ProfilerConfig tool: ${oc.select:global_profiler.tool,null} enable: false - all_ranks: false + all_ranks: true ranks: [] save_path: ${oc.select:global_profiler.save_path,null} tool_config: @@ -288,7 +288,10 @@ actor_rollout_ref: profiler: _target_: verl.utils.profiler.ProfilerConfig tool: ${oc.select:global_profiler.tool,null} + global_tool_config: ${oc.select:global_profiler.global_tool_config,null} enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false} + all_replicas: true + replicas: [] all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false} ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]} save_path: ${oc.select:global_profiler.save_path,null} @@ -669,7 +672,7 @@ global_profiler: cuda-memory-usage: 'true' cuda-graph-trace: graph capture-range: cudaProfilerApi - capture-range-end: null + capture-range-end: repeat-shutdown:6 kill: none torch_memory: trace_alloc_max_entries: 100000 diff --git a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml index 4528e0d667d..a2d0a046b92 100644 --- a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml @@ -94,7 +94,7 @@ actor_rollout_ref: _target_: verl.utils.profiler.ProfilerConfig tool: ${oc.select:global_profiler.tool,null} enable: false - all_ranks: false + all_ranks: true ranks: [] save_path: ${oc.select:global_profiler.save_path,null} tool_config: @@ -270,7 +270,10 @@ actor_rollout_ref: profiler: _target_: verl.utils.profiler.ProfilerConfig tool: ${oc.select:global_profiler.tool,null} + global_tool_config: ${oc.select:global_profiler.global_tool_config,null} enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false} + all_replicas: true + replicas: [] all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false} ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]} save_path: ${oc.select:global_profiler.save_path,null} @@ -637,7 +640,7 @@ global_profiler: cuda-memory-usage: 'true' cuda-graph-trace: graph capture-range: cudaProfilerApi - capture-range-end: null + capture-range-end: repeat-shutdown:6 kill: none torch_memory: trace_alloc_max_entries: 100000 diff --git a/verl/trainer/config/actor/actor.yaml b/verl/trainer/config/actor/actor.yaml index bffe8aec484..43c8aa6d1de 100644 --- a/verl/trainer/config/actor/actor.yaml +++ b/verl/trainer/config/actor/actor.yaml @@ -132,11 +132,11 @@ checkpoint: # Whether to save checkpoints asynchronously. Only effective for Megatron as of now. async_save: False - + # Mbridge config extension. # when vanilla_mbridge=True, and your filesystem is a distributed filesystem,(which means you write a file in node A # and you can read the file in node B immediately) - # set `mbridge_config.distributed_filesystem=True` and `mbridge_config.memory_efficient=True` to + # set `mbridge_config.distributed_filesystem=True` and `mbridge_config.memory_efficient=True` to # speed up the checkpoint saving by 10x speed. mbridge_config: {} @@ -162,7 +162,7 @@ optim: # Whether to use custom fused kernels (e.g., FlashAttention, fused MLP) use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false} -# profile the actor model in `update_policy` +# profile the actor model in `update_policy` profiler: # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs @@ -174,9 +174,9 @@ profiler: # whether enable profile on Actor enable: False - + # Whether to profile all ranks. - all_ranks: False + all_ranks: True # The ranks that will be profiled. [] or [0,1,...] ranks: [] @@ -192,10 +192,10 @@ profiler: # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs _target_: verl.utils.profiler.config.NsightToolConfig - + # True for each task has its own database, False for all tasks in one training step share one database. discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} - + # npu config npu: @@ -214,7 +214,7 @@ profiler: # True for each task has its own database, False for all tasks in one training step share one database. discrete: False - + # torch profiler config torch: diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 17dddd60dc6..e4f9456add6 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -51,13 +51,13 @@ actor_rollout_ref: # LoRA rank (Dimension of the low-rank projection space.). Set to 0 to disable LoRA rank: 0 # typical values: 8, 16, 32, 64 - + # Weighting factor for the low-rank projection. Defaults to 32 alpha: 32 - + # Dropout rate for the low-rank projection. Defaults to 0.0 dropout: 0.0 - + # A list of module names to apply LoRA to. # For fused LoRA, Defaults to all linear layers ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2']. # For canonical LoRA: ["linear_q", "linear_k", "linear_v", "linear_proj", "linear_fc1_up", "linear_fc1_gate", "linear_fc2"] @@ -67,7 +67,7 @@ actor_rollout_ref: # - 'linear_fc2': Apply LoRA to the second fully-connected layer in MLP # Target modules can also contain wildcards. For example, you can specify # target_modules=['*.layers.0.*.linear_qkv', '*.layers.1.*.linear_qkv'] to add LoRA to only linear_qkv on the first two layers - # + # # Note: # For MLA (e.g., DeepSeek), you should use ["linear_kv_down_proj","linear_kv_up_proj","linear_q_down_proj","linear_q_up_proj","linear_q_proj"] # Instead of "linear_qkv" or ["linear_q","linear_k","linear_v"] @@ -77,7 +77,7 @@ actor_rollout_ref: - linear_proj - linear_fc1 - linear_fc2 - + # A list of module names not to apply LoRa to. It will match all nn.Linear & nn.Linear-adjacent modules whose name # does not match any string in exclude_modules. If used, will require target_modules to be empty list or None exclude_modules: [] @@ -179,6 +179,10 @@ global_profiler: global_tool_config: # nsys config nsys: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NsightToolConfig + # True for each task has its own database, False for all tasks in one training step share one database. discrete: False @@ -214,8 +218,8 @@ global_profiler: # valid values are "repeat-shutdown:n" or null. # For normal whole step profiling, n = len(profile_steps); # but for discrete profiling, n = len(profile_steps) * Number(subtasks). - # Or you can just leave it null and the program will use n = len(profile_steps) * 6; - capture-range-end: null + # To be simple, it's set to 6 for default, and you can change it if you want. + capture-range-end: "repeat-shutdown:6" # Send signal to the target application's process group. We let the program to exit by itself. kill: none diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index fd9b59862ae..68011d98e58 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -269,8 +269,8 @@ global_profiler: # valid values are "repeat-shutdown:n" or null. # For normal whole step profiling, n = len(profile_steps); # but for discrete profiling, n = len(profile_steps) * Number(subtasks). - # Or you can just leave it null and the program will use n = len(profile_steps) * 6; - capture-range-end: null + # To be simple, it's set to 6 for default, and you can change it if you want. + capture-range-end: "repeat-shutdown:6" # Send signal to the target application's process group. We let the program to exit by itself. kill: none diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml index e1a4d2dad6d..3c304d19ce3 100644 --- a/verl/trainer/config/rollout/rollout.yaml +++ b/verl/trainer/config/rollout/rollout.yaml @@ -313,9 +313,18 @@ profiler: # choices: npu, torch tool: ${oc.select:global_profiler.tool,null} + # global tool config + global_tool_config: ${oc.select:global_profiler.global_tool_config,null} + # whether enable profile on rollout enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false} + # Whether to profile all replicas. + all_replicas: true + + # The replicas that will be profiled. [] or [0,1,...] + replicas: [] + # Whether to profile all ranks. all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false} diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 3b10492e4d8..9b7ad1ad29d 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -1215,6 +1215,10 @@ def _update_critic(self, batch: DataProto) -> DataProto: critic_output = self.critic_wg.update_critic(batch) return critic_output + def shutdown(self): + """Shutdown the Ray PPO trainer""" + self.async_rollout_manager.shutdown() + def fit(self): """ The training loop of PPO. @@ -1278,12 +1282,27 @@ def fit(self): metrics = {} timing_raw = {} + start_profiling_flag = ( + not prev_step_profile and curr_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + next_step_profile = ( + self.global_steps + 1 in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + stop_profiling_flag = ( + curr_step_profile and not next_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + prev_step_profile = curr_step_profile + curr_step_profile = next_step_profile + with marked_timer("start_profile", timing_raw): - self._start_profiling( - not prev_step_profile and curr_step_profile - if self.config.global_profiler.profile_continuous_steps - else curr_step_profile - ) + self._start_profiling(start_profiling_flag) + batch: DataProto = DataProto.from_single_dict(batch_dict) batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature @@ -1304,11 +1323,11 @@ def fit(self): with marked_timer("step", timing_raw): # generate a batch with marked_timer("gen", timing_raw, color="red"): - if curr_step_profile: + if start_profiling_flag: self.async_rollout_manager.start_profile() gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) self.checkpoint_manager.sleep_replicas() - if curr_step_profile: + if stop_profiling_flag: self.async_rollout_manager.stop_profile() timing_raw.update(gen_batch_output.meta_info["timing"]) @@ -1318,11 +1337,11 @@ def fit(self): with marked_timer("gen_max", timing_raw, color="purple"): gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch.meta_info["do_sample"] = False - if curr_step_profile: + if start_profiling_flag: self.async_rollout_manager.start_profile() gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) self.checkpoint_manager.sleep_replicas() - if curr_step_profile: + if stop_profiling_flag: self.async_rollout_manager.stop_profile() batch = batch.union(gen_baseline_output) # compute reward model score on batch @@ -1539,18 +1558,7 @@ def fit(self): metrics.update(val_metrics) with marked_timer("stop_profile", timing_raw): - next_step_profile = ( - self.global_steps + 1 in self.config.global_profiler.steps - if self.config.global_profiler.steps is not None - else False - ) - self._stop_profiling( - curr_step_profile and not next_step_profile - if self.config.global_profiler.profile_continuous_steps - else curr_step_profile - ) - prev_step_profile = curr_step_profile - curr_step_profile = next_step_profile + self._stop_profiling(stop_profiling_flag) steps_duration = timing_raw["step"] self.max_steps_duration = max(self.max_steps_duration, steps_duration) @@ -1596,6 +1604,7 @@ def fit(self): self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=True) pprint(f"Final validation metrics: {last_val_metrics}") progress_bar.close() + self.shutdown() return # this is experimental and may be changed/removed in the future diff --git a/verl/utils/profiler/config.py b/verl/utils/profiler/config.py index e31cf4c8929..ae4845de32c 100644 --- a/verl/utils/profiler/config.py +++ b/verl/utils/profiler/config.py @@ -29,6 +29,8 @@ class NsightToolConfig(BaseConfig): "True for each task has its own database, False for all tasks in one training step share one database." discrete: bool = False + controller_nsight_options: Optional[dict] = None + worker_nsight_options: Optional[dict] = None name: str = "nsight" def __post_init__(self) -> None: @@ -124,6 +126,8 @@ class ProfilerConfig(BaseConfig): tool: Optional[str] = MISSING enable: bool = False + all_replicas: bool = False + replicas: list[int] = field(default_factory=list) all_ranks: bool = False ranks: list[int] = field(default_factory=list) save_path: Optional[str] = MISSING @@ -135,6 +139,8 @@ def union(self, other: "ProfilerConfig") -> "ProfilerConfig": return ProfilerConfig( tool=self.tool, enable=self.enable or other.enable, + all_replicas=self.all_replicas or other.all_replicas, + replicas=list(set(self.replicas or []) | set(other.replicas or [])), all_ranks=self.all_ranks or other.all_ranks, ranks=list(set(self.ranks or []) | set(other.ranks or [])), save_path=self.save_path, @@ -149,6 +155,8 @@ def intersect(self, other: "ProfilerConfig") -> "ProfilerConfig": return ProfilerConfig( tool=self.tool, enable=self.enable and other.enable, + all_replicas=self.all_replicas and other.all_replicas, + replicas=list(set(self.replicas or []) & set(other.replicas or [])), all_ranks=self.all_ranks and other.all_ranks, ranks=list(set(self.ranks or []) & set(other.ranks or [])), save_path=self.save_path, diff --git a/verl/workers/rollout/replica.py b/verl/workers/rollout/replica.py index bf83ac7d05f..a4f85cc7d77 100644 --- a/verl/workers/rollout/replica.py +++ b/verl/workers/rollout/replica.py @@ -25,6 +25,7 @@ from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, ResourcePoolManager from verl.utils.config import omega_conf_to_dataclass from verl.utils.device import is_torch_npu_available +from verl.utils.profiler.config import ProfilerConfig from verl.workers.config import HFModelConfig, RolloutConfig logger = logging.getLogger(__file__) @@ -115,6 +116,7 @@ def __init__( self.servers: list[ActorHandle] = [] self._server_address: str = None self._server_handle: ActorHandle = None + self.profiler_config: ProfilerConfig = omega_conf_to_dataclass(self.config.profiler) async def init_hybrid(self, worker_group: RayWorkerGroup): """Init hybrid rollout server, rollout engine and training engine(fsdp/megatron) fused in same process. @@ -255,6 +257,10 @@ async def stop_profile(self): """Stop profiling on the replica.""" await asyncio.gather(*[server.stop_profile.remote() for server in self.servers]) + async def shutdown(self): + """Shutdown the replica.""" + await asyncio.gather(*[server.shutdown.remote() for server in self.servers]) + class RolloutReplicaRegistry: """Factory for managing rollout replica implementations.""" diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index 4190ad98f82..256937feece 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -421,6 +421,9 @@ async def stop_profile(self): ): await self.tokenizer_manager.stop_profile() + async def shutdown(self): + pass + _rollout_worker_actor_cls = ray.remote(ServerAdapter) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index 86a90cd0287..f6e7075fae2 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -62,6 +62,8 @@ def __init__( max_colocate_count: int, pgs: list[PlacementGroup] = None, bundle_indices: list[list[int]] = None, + nsight_options: dict = None, + profiling_ranks: list[int] | None = None, ): os.environ["TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL"] = "1" assert torch.cuda.is_available(), "TRTLLM http server should run on GPU node" @@ -84,6 +86,8 @@ def __init__( self.max_colocate_count = max_colocate_count self.pgs = pgs self.bundle_indices = bundle_indices + self.nsight_options = nsight_options + self.profiling_ranks = profiling_ranks if self.rollout_mode != RolloutMode.HYBRID and self.config.load_format == "dummy": logger.warning(f"rollout mode is {self.rollout_mode}, load_format is dummy, set to auto") @@ -145,6 +149,8 @@ async def launch_server(self): "enable_sleep": self.config.enable_sleep_mode, "allreduce_strategy": "NCCL", "sampler_type": "TRTLLMSampler", + # TODO: add nsight options back, mute it for CI + # "ray_worker_nsight_options": self.nsight_options, **engine_kwargs, } @@ -172,6 +178,8 @@ async def launch_server(self): self.llm = await AsyncLLM(**llm_kwargs) trtllm_server = OpenAIServer( + # TODO: update to generator in future + # generator=self.llm, llm=self.llm, model=self.model_config.local_path, tool_parser=None, @@ -237,9 +245,22 @@ async def report_device_ids(self) -> list[str]: """Report GPU device UUIDs from TRT-LLM workers.""" return await self.llm.collective_rpc( "report_device_id", + # TODO: mute target_ranks for CI + # target_ranks=[0], unique_reply_rank=0, ) + async def start_profile(self, **kwargs): + await self.llm.collective_rpc("start_profile", target_ranks=self.profiling_ranks) + + async def stop_profile(self, **kwargs): + await self.llm.collective_rpc("stop_profile", target_ranks=self.profiling_ranks) + + async def shutdown(self): + """Shutdown the server.""" + self.llm.shutdown() + pass + _rollout_worker_actor_cls = ray.remote(ServerAdapter) @@ -346,6 +367,11 @@ async def launch_servers(self): else f"trtllm_server_reward_{self.replica_rank}" ) + if "nsys" in self.profiler_config.global_tool_config: + nsight_options = self.profiler_config.global_tool_config["nsys"].worker_nsight_options + else: + nsight_options = None + server = TRTLLMHttpServer.options( scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( node_id=node_id, @@ -363,6 +389,8 @@ async def launch_servers(self): max_colocate_count=self.resource_pool.max_colocate_count, pgs=pgs, bundle_indices=bundle_indices, + nsight_options=nsight_options, + profiling_ranks=(None if self.profiler_config.all_ranks else self.profiler_config.ranks), ) self.servers.append(server) diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 638b00f0877..5c2942beaeb 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -766,6 +766,9 @@ async def abort_request(self, request_id: str, reset_prefix_cache: bool = True) logger.error(f"Error aborting request {request_id}: {e}") return {"aborted": False, "request_id": request_id, "error": str(e)} + async def shutdown(self): + pass + _rollout_worker_actor_cls = ray.remote(ServerAdapter)