File tree Expand file tree Collapse file tree 3 files changed +26
-4
lines changed
Expand file tree Collapse file tree 3 files changed +26
-4
lines changed Original file line number Diff line number Diff line change 4444from verl .workers .config import HFModelConfig , RolloutConfig
4545from verl .workers .rollout .replica import RolloutMode , RolloutReplica , TokenOutput
4646from verl .workers .rollout .sglang_rollout .sglang_rollout import ServerAdapter , _set_envs_and_config
47- from verl .workers .rollout .utils import get_free_port , is_valid_ipv6_address , run_unvicorn
47+ from verl .workers .rollout .utils import (
48+ get_free_port ,
49+ get_max_position_embeddings ,
50+ is_valid_ipv6_address ,
51+ run_unvicorn ,
52+ )
4853
4954logger = logging .getLogger (__file__ )
5055logger .setLevel (logging .INFO )
@@ -83,7 +88,7 @@ def __init__(
8388
8489 self .config : RolloutConfig = omega_conf_to_dataclass (config )
8590 self .model_config : HFModelConfig = omega_conf_to_dataclass (model_config , dataclass_type = HFModelConfig )
86- self .config .max_model_len = self .model_config .hf_config . max_position_embeddings
91+ self .config .max_model_len = get_max_position_embeddings ( self .model_config .hf_config )
8792 self .rollout_mode = rollout_mode
8893 self .workers = workers
8994
Original file line number Diff line number Diff line change 2323logger = logging .getLogger (__file__ )
2424
2525
def get_max_position_embeddings(hf_config) -> int:
    """Return the maximum sequence length declared by a HuggingFace model config.

    The value is read from ``hf_config.max_position_embeddings``. When that
    attribute is missing or ``None``, the nested ``hf_config.text_config`` is
    consulted instead; the nested config may be either an attribute-style
    object or a plain ``dict``.

    Args:
        hf_config: Config object exposing ``max_position_embeddings`` directly
            or through a ``text_config`` sub-config.

    Returns:
        The maximum position embeddings, coerced to ``int``.

    Raises:
        ValueError: If the value is found neither on ``hf_config`` nor on its
            ``text_config``.
    """
    key = "max_position_embeddings"

    max_len = getattr(hf_config, key, None)
    if max_len is None:
        text_config = getattr(hf_config, "text_config", None)
        if isinstance(text_config, dict):
            # A sub-config stored as a raw dict has no attributes to probe;
            # getattr() would silently return None even when the key exists.
            max_len = text_config.get(key)
        elif text_config is not None:
            max_len = getattr(text_config, key, None)

    if max_len is None:
        # Keep the original wording and append the config type so the failing
        # model family can be identified from the traceback alone.
        raise ValueError(
            "max_position_embeddings not found in HFModelConfig! "
            f"(hf_config type: {type(hf_config).__name__})"
        )
    return int(max_len)
36+
37+
2638def is_valid_ipv6_address (address : str ) -> bool :
2739 try :
2840 ipaddress .IPv6Address (address )
Original file line number Diff line number Diff line change 4848from verl .utils .vllm .vllm_fp8_utils import apply_vllm_fp8_patches
4949from verl .workers .config import HFModelConfig , RolloutConfig
5050from verl .workers .rollout .replica import RolloutMode , RolloutReplica , TokenOutput
51- from verl .workers .rollout .utils import get_free_port , is_valid_ipv6_address , run_unvicorn
51+ from verl .workers .rollout .utils import (
52+ get_free_port ,
53+ get_max_position_embeddings ,
54+ is_valid_ipv6_address ,
55+ run_unvicorn ,
56+ )
5257from verl .workers .rollout .vllm_rollout import vLLMAsyncRollout
5358from verl .workers .rollout .vllm_rollout .utils import (
5459 VLLM_LORA_INT_ID ,
@@ -195,7 +200,7 @@ def __init__(
195200
196201 self .config : RolloutConfig = omega_conf_to_dataclass (config )
197202 self .model_config : HFModelConfig = omega_conf_to_dataclass (model_config , dataclass_type = HFModelConfig )
198- self .config .max_model_len = self .model_config .hf_config . max_position_embeddings
203+ self .config .max_model_len = get_max_position_embeddings ( self .model_config .hf_config )
199204 self .rollout_mode = rollout_mode
200205 self .workers = workers
201206
You can’t perform that action at this time.
0 commit comments