Commit 3b6c004

Wrap if/else into a function

1 parent 6ed106a · commit 3b6c004

3 files changed: +25 -40 lines


verl/utils/transformers_compat.py

19 additions & 12 deletions

@@ -22,6 +22,10 @@
 
 from packaging import version
 
+from verl.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
 # Handle version compatibility for flash_attn_supports_top_left_mask
 # This function was added in newer versions of transformers
 try:

@@ -57,22 +61,25 @@ def is_transformers_version_in_range(min_version: Optional[str] = None, max_vers
     return lower_bound_check and upper_bound_check
 
 
-def get_max_position_embeddings(hf_config: Any) -> Optional[int]:
-    """Best-effort resolution of model context length from HF configs.
-
-    Works for:
-    - text-only configs where max_position_embeddings is top-level
-    - multimodal wrapper configs (e.g., Qwen3-VL) where it lives in text_config
-    """
+def resolve_max_model_len_from_hf_config(hf_config: Any) -> int | None:
     mpe = getattr(hf_config, "max_position_embeddings", None)
     if isinstance(mpe, int):
         return mpe
-
-    # Common wrappers for VLMs / composite configs
     for subname in ("text_config", "language_config", "llm_config"):
-        subcfg = getattr(hf_config, subname, None)
-        mpe = getattr(subcfg, "max_position_embeddings", None) if subcfg is not None else None
+        sub = getattr(hf_config, subname, None)
+        mpe = getattr(sub, "max_position_embeddings", None) if sub is not None else None
         if isinstance(mpe, int):
             return mpe
-
     return None
+
+
+def maybe_set_max_model_len_from_hf_config(config: Any, hf_config: Any) -> None:
+    mpe = resolve_max_model_len_from_hf_config(hf_config)
+    if mpe is not None:
+        config.max_model_len = mpe
+    else:
+        logger.warning(
+            "Cannot infer max_model_len from hf_config=%s; keeping max_model_len=%s",
+            type(hf_config),
+            getattr(config, "max_model_len", None),
+        )
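A minimal usage sketch of the two helpers above, assuming only what the diff shows; the SimpleNamespace stand-ins for the HF config and the rollout config are illustrative, not part of the commit:

    from types import SimpleNamespace

    from verl.utils.transformers_compat import (
        maybe_set_max_model_len_from_hf_config,
        resolve_max_model_len_from_hf_config,
    )

    # Text-only config: max_position_embeddings is a top-level attribute.
    text_cfg = SimpleNamespace(max_position_embeddings=32768)
    assert resolve_max_model_len_from_hf_config(text_cfg) == 32768

    # Multimodal wrapper config (e.g., Qwen3-VL): the limit lives in text_config.
    vl_cfg = SimpleNamespace(text_config=SimpleNamespace(max_position_embeddings=131072))
    assert resolve_max_model_len_from_hf_config(vl_cfg) == 131072

    # maybe_set_* writes the resolved limit onto the caller's config.
    rollout_cfg = SimpleNamespace(max_model_len=None)
    maybe_set_max_model_len_from_hf_config(rollout_cfg, text_cfg)
    assert rollout_cfg.max_model_len == 32768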

verl/workers/rollout/sglang_rollout/async_sglang_server.py

3 additions & 14 deletions

@@ -41,7 +41,7 @@
 
 from verl.single_controller.ray import RayClassWithInitArgs
 from verl.utils.config import omega_conf_to_dataclass
-from verl.utils.transformers_compat import get_max_position_embeddings
+from verl.utils.transformers_compat import maybe_set_max_model_len_from_hf_config
 from verl.workers.config import HFModelConfig, RolloutConfig
 from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput
 from verl.workers.rollout.sglang_rollout.sglang_rollout import ServerAdapter, _set_envs_and_config

@@ -84,19 +84,8 @@ def __init__(
 
         self.config: RolloutConfig = omega_conf_to_dataclass(config)
         self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig)
-        mpe = get_max_position_embeddings(self.model_config.hf_config)
-        if mpe is not None:
-            # Don't accidentally exceed model limit; clamp if user set something smaller.
-            if getattr(self.config, "max_model_len", None) is not None:
-                self.config.max_model_len = min(self.config.max_model_len, mpe)
-            else:
-                self.config.max_model_len = mpe
-        else:
-            logger.warning(
-                "Cannot infer max_position_embeddings from hf_config=%s; keeping max_model_len=%s",
-                type(self.model_config.hf_config),
-                getattr(self.config, "max_model_len", None),
-            )
+        # Make sure config.max_model_len does not exceed hf_config's max_position_embeddings.
+        maybe_set_max_model_len_from_hf_config(self.config, self.model_config.hf_config)
         self.rollout_mode = rollout_mode
         self.workers = workers

verl/workers/rollout/vllm_rollout/vllm_async_server.py

3 additions & 14 deletions

@@ -45,7 +45,7 @@
 
 from verl.single_controller.ray import RayClassWithInitArgs
 from verl.utils.config import omega_conf_to_dataclass
-from verl.utils.transformers_compat import get_max_position_embeddings
+from verl.utils.transformers_compat import maybe_set_max_model_len_from_hf_config
 from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches
 from verl.workers.config import HFModelConfig, RolloutConfig
 from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput

@@ -196,19 +196,8 @@ def __init__(
 
         self.config: RolloutConfig = omega_conf_to_dataclass(config)
         self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig)
-        mpe = get_max_position_embeddings(self.model_config.hf_config)
-        if mpe is not None:
-            # Don't accidentally exceed model limit; clamp if user set something smaller.
-            if getattr(self.config, "max_model_len", None) is not None:
-                self.config.max_model_len = min(self.config.max_model_len, mpe)
-            else:
-                self.config.max_model_len = mpe
-        else:
-            logger.warning(
-                "Cannot infer max_position_embeddings from hf_config=%s; keeping max_model_len=%s",
-                type(self.model_config.hf_config),
-                getattr(self.config, "max_model_len", None),
-            )
+        # Make sure config.max_model_len does not exceed hf_config's max_position_embeddings.
+        maybe_set_max_model_len_from_hf_config(self.config, self.model_config.hf_config)
         self.rollout_mode = rollout_mode
         self.workers = workers
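Both rollout servers now share the helper's fallback path: when no max_position_embeddings can be resolved from the HF config, max_model_len is left as-is and a warning is logged. A sketch of that path, again with an illustrative SimpleNamespace in place of a real HF config:

    import logging
    from types import SimpleNamespace

    from verl.utils.transformers_compat import maybe_set_max_model_len_from_hf_config

    # Configure logging so the helper's warning is visible
    # (assuming default logger propagation to the root handler).
    logging.basicConfig(level=logging.WARNING)

    # No max_position_embeddings at the top level and no text_config/
    # language_config/llm_config wrapper: the helper logs a warning
    # and keeps the caller's existing value.
    rollout_cfg = SimpleNamespace(max_model_len=4096)
    maybe_set_max_model_len_from_hf_config(rollout_cfg, SimpleNamespace())
    assert rollout_cfg.max_model_len == 4096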
