|
57 | 57 | _DEFAULT_BUFFER_SIZE = 200
58 | 58 | _MIN_DP_BUFFER_SIZE = 50
59 | 59 | _IS_MOE_MODEL = None
| 60 | +_IS_VL_MODEL = None |
60 | 61 | _ENABLE_SP = None |
61 | 62 | _HAS_LAYER_IDX = None |
62 | 63 | _ENABLE_NZ = None |
@@ -319,6 +320,53 @@ def _rec_find(d):
319 | 320 | return max(layer_counts) |
320 | 321 |
|
321 | 322 |
|
| 323 | +def _is_default_capture_sizes(vllm_config: VllmConfig) -> bool:
| 324 | + """ |
| 325 | +    Check whether the current capture sizes are vLLM's defaults.
| 326 | + """ |
| 327 | + |
| 328 | + cuda_graph_sizes = vllm_config.scheduler_config.cuda_graph_sizes |
| 329 | + if len(cuda_graph_sizes) == 1: |
| 330 | +        default_size_capture_list = [1, 2, 4] + list(
| 331 | +            range(8, cuda_graph_sizes[0] + 1, 8)
| 332 | +        )
| 333 | + |
| 334 | + if sorted(default_size_capture_list, reverse=True) == \ |
| 335 | + vllm_config.compilation_config.cudagraph_capture_sizes: |
| 336 | + return True |
| 337 | + |
| 338 | + return False |
| 339 | + |
| 340 | + |
| 341 | +def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None:
| 342 | + """ |
| 343 | +    Update the default ACL graph capture sizes so that the new
| 344 | +    sizes are friendlier to Ascend ops and hardware.
| 345 | + """ |
| 346 | + |
| 347 | + if vllm_config.model_config is None or \ |
| 348 | + vllm_config.model_config.enforce_eager or \ |
| 349 | + not _is_default_capture_sizes(vllm_config): |
| 350 | + return |
| 351 | + |
| 352 | +    # Modify the default capture sizes for Qwen3-MoE models under DP
| 353 | +    # settings, mainly because the performance of _npu_paged_attention
| 354 | +    # might degrade on certain shapes.
| 355 | +    # TODO(Angazenn): remove this once _npu_paged_attention is fully replaced
| 356 | +    # by npu_fused_infer_attention_score, which does not suffer from this issue.
| 357 | +    if vllm_config.model_config.hf_config.model_type == "qwen3_moe" \
| 358 | +        and vllm_config.parallel_config.tensor_parallel_size == 1 \
| 359 | +        and vllm_config.parallel_config.data_parallel_size > 1:
| 360 | + max_capture_size = vllm_config.scheduler_config.cuda_graph_sizes[0] |
| 361 | +        new_cudagraph_capture_sizes = [1, 2, 5, 10, 15, 20] + list(
| 362 | +            range(24, max_capture_size + 1, 8)
| 363 | +        )
| 364 | + |
| 365 | + vllm_config.compilation_config.cudagraph_capture_sizes = new_cudagraph_capture_sizes |
| 366 | + vllm_config.compilation_config.init_with_cudagraph_sizes( |
| 367 | + new_cudagraph_capture_sizes) |
| 368 | + |
| 369 | + |
322 | 370 | def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
323 | 371 | """Update ACL graph capture sizes based on hardware limitations""" |
324 | 372 | # NOTE: Currently, we can only capture 1800 graphs at most, |
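
To make the two new helpers concrete, here is a minimal runnable sketch of the size lists they operate on. The concrete numbers are assumptions for illustration, not values from this patch: vLLM's stock `cuda_graph_sizes` default of `[512]`, and a `max_capture_size` of 48.

```python
# Sketch only: 512 and 48 are assumed values, not taken from this patch.

# The default list that _is_default_capture_sizes reconstructs when
# scheduler_config.cuda_graph_sizes == [512]:
cuda_graph_sizes = [512]
default_sizes = [1, 2, 4] + list(range(8, cuda_graph_sizes[0] + 1, 8))
# vLLM keeps cudagraph_capture_sizes sorted in descending order, hence
# the reverse-sorted comparison in the helper.
print(sorted(default_sizes, reverse=True)[:4])  # [512, 504, 496, 488]

# The replacement list that update_default_aclgraph_sizes installs for
# Qwen3-MoE with TP=1 and DP>1, assuming max_capture_size == 48:
max_capture_size = 48
new_sizes = [1, 2, 5, 10, 15, 20] + list(range(24, max_capture_size + 1, 8))
print(new_sizes)  # [1, 2, 5, 10, 15, 20, 24, 32, 40, 48]
```

The small sizes (1, 2, 5, 10, 15, 20) replace the default 4- and 8-aligned ones, presumably steering decode batches away from the shapes on which `_npu_paged_attention` degrades, while the stride-8 tail keeps the total capture count low.
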
@@ -649,6 +697,15 @@ def _is_contain_expert(config: Any):
649 | 697 | return False |
650 | 698 |
|
651 | 699 |
|
| 700 | +def is_vl_model(vllm_config: VllmConfig):
| 701 | +    """Check whether the model is a vision-language (VL) model by its config."""
| 702 | + global _IS_VL_MODEL |
| 703 | + if _IS_VL_MODEL is None: |
| 704 | +        architectures = vllm_config.model_config.hf_config.architectures or [""]
| 705 | +        _IS_VL_MODEL = "VL" in architectures[0]
| 706 | + return _IS_VL_MODEL |
| 707 | + |
| 708 | + |
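
As a quick illustration of the substring check that `is_vl_model` performs on `hf_config.architectures` (the architecture names below are merely examples):

```python
# is_vl_model caches its result in the module-level _IS_VL_MODEL and simply
# looks for the "VL" substring in the first architecture name.
for arch in ["Qwen2_5_VLForConditionalGeneration", "Qwen3MoeForCausalLM"]:
    print(arch, "->", "VL" in arch)
# Qwen2_5_VLForConditionalGeneration -> True
# Qwen3MoeForCausalLM -> False
```

Note that the substring test will also match any future architecture whose name contains "VL", which is presumably the intent.
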
652 | 709 | def weak_ref_tensor(tensor: Any) -> Any:
653 | 710 | """ |
654 | 711 | Create a weak reference to a tensor. |
|