[rollout, vllm, sglang] fix: forward max_tokens/max_new_tokens from rollout config to vllm/sglang backends #5028
base: main
Conversation
Code Review
This pull request correctly forwards max_tokens and max_new_tokens from the rollout configuration to the vLLM and SGLang backends, allowing for better control over generation length. The changes are applied consistently across agent_loop.py and its fully_async_policy counterpart.
I've identified a potential configuration ambiguity: if a user sets both max_tokens and max_new_tokens, the effective token limit depends on which backend is in use, which could be confusing. I've suggested adding a warning to alert users to this situation, which will improve the robustness of the configuration.
```python
for param_name in ["max_tokens", "max_new_tokens"]:
    param_value = getattr(config, param_name, None)
    if param_value is not None:
        sampling_params[param_name] = param_value
```
If both max_tokens and max_new_tokens are set in the configuration, the behavior becomes dependent on the backend being used (vllm or sglang), as they have different parameter preferences. This can lead to unexpected generation lengths and hard-to-debug issues when switching backends. Adding a warning will alert users to this potential configuration issue, improving robustness.
Suggested change:

```diff
-for param_name in ["max_tokens", "max_new_tokens"]:
-    param_value = getattr(config, param_name, None)
-    if param_value is not None:
-        sampling_params[param_name] = param_value
+params_to_check = ["max_tokens", "max_new_tokens"]
+found_params = []
+for param_name in params_to_check:
+    param_value = getattr(config, param_name, None)
+    if param_value is not None:
+        sampling_params[param_name] = param_value
+        found_params.append(param_name)
+if len(found_params) > 1:
+    logger.warning(
+        "Both 'max_tokens' and 'max_new_tokens' are set in the rollout config. "
+        "The effective value will depend on the backend ('vllm' prefers 'max_tokens', "
+        "'sglang' prefers 'max_new_tokens'). It is recommended to set only one."
+    )
```
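As a self-contained illustration, the suggested logic can be exercised outside the codebase. Note that `build_sampling_params`, the `SimpleNamespace` config, and the module-level `logger` here are stand-ins for this sketch, not the project's actual objects:

```python
import logging
from types import SimpleNamespace

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("rollout")


def build_sampling_params(config):
    """Forward max_tokens/max_new_tokens from the rollout config into
    sampling_params, warning when both are set (precedence is backend-dependent)."""
    sampling_params = {}
    found_params = []
    for param_name in ["max_tokens", "max_new_tokens"]:
        param_value = getattr(config, param_name, None)
        if param_value is not None:
            sampling_params[param_name] = param_value
            found_params.append(param_name)
    if len(found_params) > 1:
        logger.warning(
            "Both 'max_tokens' and 'max_new_tokens' are set in the rollout config. "
            "The effective value will depend on the backend ('vllm' prefers 'max_tokens', "
            "'sglang' prefers 'max_new_tokens'). It is recommended to set only one."
        )
    return sampling_params


# Only one parameter set: no warning, value forwarded as-is.
print(build_sampling_params(SimpleNamespace(max_tokens=512, max_new_tokens=None)))
# → {'max_tokens': 512}
```

With both parameters set, the dict contains both keys and the warning fires once, leaving the backend to apply its own preference.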
What does this PR do?
In the current codebase, when max_tokens is not specified in sampling_params, its default value is dynamically calculated from prompt_ids. This can produce generations that exceed the intended length, wasting computation and reducing inference performance. Configuring max_tokens in the configuration file sets a fixed upper bound that prevents uncontrolled text generation, ensuring predictable resource utilization and inference efficiency.
However, the max_tokens/max_new_tokens parameters are not passed into sampling_params in agent_loop.py, so max_tokens in vllm_async_server.py only ever takes the dynamically calculated default, which does not meet expectations in some scenarios.
Therefore, we add the two parameters max_tokens/max_new_tokens to the rollout configuration, with a default value of null. When these parameters are set in the launch script, the corresponding values in vllm_async_server.py/sglang_async_server.py are updated accordingly. When they are not set, the original default logic applies, so existing functionality is unaffected.
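The precedence described above can be sketched as follows. The function name `effective_max_tokens`, the `max_model_len` parameter, and the subtraction rule are assumptions about how the dynamic default might be computed, not the exact code in vllm_async_server.py:

```python
def effective_max_tokens(configured_max_tokens, max_model_len, prompt_ids):
    """Hypothetical sketch: use the configured cap when present,
    otherwise fall back to a dynamically computed remaining budget."""
    if configured_max_tokens is not None:
        return configured_max_tokens
    # Assumed dynamic default: whatever room is left after the prompt.
    return max_model_len - len(prompt_ids)


# Configured cap wins; otherwise the prompt length determines the budget.
print(effective_max_tokens(256, 4096, list(range(1000))))   # → 256
print(effective_max_tokens(None, 4096, list(range(1000))))  # → 3096
```

Under this sketch, leaving the config value at null reproduces the original behavior exactly, which is why the change is backward compatible.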
Checklist Before Starting
API and Usage Example
actor_rollout_ref.rollout.max_tokens=${max_response_length}
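For context, the override above would typically be appended to an existing training launch command; the entry-point module and the surrounding flags in this fragment are illustrative, not taken from this PR:

```shell
# Hypothetical launch fragment: cap rollout generation length explicitly.
max_response_length=1024
python -m verl.trainer.main_ppo \
    actor_rollout_ref.rollout.max_tokens=${max_response_length}
```

Leaving max_new_tokens unset here avoids the both-set ambiguity discussed in the review comments.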