skip generate option for large models and mxfp8 #942
arendu wants to merge 4 commits into NVIDIA:main
Conversation
Important: Review skipped. Automatic incremental reviews are disabled on this repository; check the settings in the CodeRabbit UI.
📝 Walkthrough: A new CLI flag, `--skip_generate`, is added to `hf_ptq.py` to skip the pre/post-quantization generation previews.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~8 minutes
🚥 Pre-merge checks: ❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/llm_ptq/hf_ptq.py`:
- Around line 1083-1092: The help text for the --skip_generate argparse flag
overstates its effect; update the parser.add_argument call for "--skip_generate"
so the description accurately says it only skips the pre/post-quantization
model.generate() preview calls (not forward passes, calibration, or the
batch-size probe path used when --batch_size 0). Edit the string passed to
parser.add_argument in hf_ptq.py to replace “cannot run forward passes” with
wording like “cannot run model.generate() previews” and explicitly note that
calibration and batch-size probing are unaffected.
- Around line 689-691: The preview input extraction (preview_input_ids) is still
executed even when args.skip_generate is true, causing an unnecessary dataloader
fetch that can fail; update the logic around args.skip_generate and
generated_ids_before_ptq so that when args.skip_generate is set you
short-circuit before any preview_input_ids or dataloader reads (set
generated_ids_before_ptq and preview_input_ids to None or skip their
assignment), i.e., move or guard the preview_input_ids extraction behind the `if
not args.skip_generate` path (the branch that currently tests model_type ==
"deepseek" and subsequent generation logic) so no preview/dataloader work runs
when generation is disabled.
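The first review comment's help-text fix might look like the sketch below; the exact wording is an assumption, not the merged code, and only the `--skip_generate` flag itself is taken from the PR.

```python
import argparse

# Hypothetical help wording along the lines the review asks for; the exact
# string in hf_ptq.py may differ.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--skip_generate",
    action="store_true",
    help=(
        "Skip the pre/post-quantization model.generate() previews, e.g. for "
        "very large models that cannot run model.generate() previews when "
        "split across GPU and CPU. Calibration and the batch-size probe used "
        "with --batch_size 0 are unaffected."
    ),
)

args = parser.parse_args(["--skip_generate"])
print(args.skip_generate)  # True
```

With `action="store_true"` the flag defaults to `False` when omitted, so existing invocations keep the preview behavior unchanged.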
```python
if args.skip_generate:
    generated_ids_before_ptq = None
elif model_type == "deepseek":
```
Short-circuit preview input extraction when --skip_generate is set.
Even with --skip_generate, the code still fetches preview_input_ids at Line 684 before this branch. That extra batch fetch is unnecessary and can still fail on edge dataloader schemas while generation is intentionally disabled.
Proposed fix
```diff
 def pre_quantize(
@@
-    # Only run single sample for preview
-    preview_input_ids = next(iter(calib_dataloader))[
-        "input_features" if model_type == "whisper" else "input_ids"
-    ][0:1]
-
-    # Generate preview before quantization
-    if args.skip_generate:
-        generated_ids_before_ptq = None
+    preview_input_ids = None
+    # Generate preview before quantization
+    if args.skip_generate:
+        generated_ids_before_ptq = None
     elif model_type == "deepseek":
+        preview_input_ids = next(iter(calib_dataloader))[
+            "input_features" if model_type == "whisper" else "input_ids"
+        ][0:1]
         # DeepSeek generation may go OOM, so we skip it
         generated_ids_before_ptq = None
     elif is_nemotron_vl_model and tokenizer is not None:
+        preview_input_ids = next(iter(calib_dataloader))[
+            "input_features" if model_type == "whisper" else "input_ids"
+        ][0:1]
         generated_ids_before_ptq = run_nemotron_vl_preview(
@@
     else:
+        preview_input_ids = next(iter(calib_dataloader))[
+            "input_features" if model_type == "whisper" else "input_ids"
+        ][0:1]
         generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=100)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/llm_ptq/hf_ptq.py` around lines 689-691: the preview input
extraction (preview_input_ids) is still executed even when args.skip_generate is
true, causing an unnecessary dataloader fetch that can fail; update the logic
around args.skip_generate and generated_ids_before_ptq so that when
args.skip_generate is set you short-circuit before any preview_input_ids or
dataloader reads (set generated_ids_before_ptq and preview_input_ids to None or
skip their assignment), i.e., move or guard the preview_input_ids extraction
behind the `if not args.skip_generate` path (the branch that currently tests
model_type == "deepseek" and subsequent generation logic) so no
preview/dataloader work runs when generation is disabled.
cjluo-nv left a comment
Could you update the PR title?
Codecov Report
✅ All modified and coverable lines are covered by tests.

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main     #942      +/-   ##
==========================================
- Coverage   72.15%   72.12%   -0.03%
==========================================
  Files         210      210
  Lines       23515    23515
==========================================
- Hits        16967    16961       -6
- Misses       6548     6554       +6
```

☔ View full report in Codecov by Sentry.
Force-pushed c53b00d to 172f485 (Compare)
/ok to test 172f485
Signed-off-by: adithyare <adithyare@nvidia.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Adi Renduchintala <adithya.r@gmail.com>
Head branch was pushed to by a user without write access
Force-pushed 172f485 to 76c8a62 (Compare)
/ok to test 76c8a62
What does this PR do?
Type of change: New feature
Overview: Adds a --skip_generate flag to hf_ptq.py that skips the pre/post-quantization generation preview calls. These calls run model.generate(), which crashes for very large models (500B+) that are split across GPU and CPU via device_map="auto" (e.g., models with Mamba/Triton kernels that cannot handle CPU-offloaded tensors).

Usage
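The usage snippet did not survive the page scrape; an invocation might look like the following, where every flag other than --skip_generate (checkpoint path, quantization format) is a placeholder, not taken from the PR:

```shell
# Illustrative only: paths and quantization flags are placeholders.
python hf_ptq.py \
    --pyt_ckpt_path <model_dir> \
    --qformat nvfp4 \
    --batch_size 1 \
    --skip_generate
```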
Testing
Tested with a 500B parameter NemotronH hybrid Mamba/attention model on 4x GB200 GPUs. Without --skip_generate, the script crashes at model.generate() due to Mamba Triton kernels failing on CPU-offloaded tensors. With --skip_generate, the generation preview is skipped and quantization proceeds normally.
Before your PR is "Ready for review"
Additional Information
The --skip_generate flag sets generated_ids_before_ptq = None early, which also causes the post-quantization generate to be skipped via the existing if generated_ids_before_ptq is None: pass guard. Combined with --batch_size 1 (to skip the get_max_batch_size forward-pass probe), this eliminates all forward passes that can crash for device-map-split models.
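The interplay between the early `None` assignment and the existing post-quantization guard can be sketched as follows; the stub classes and the `run_previews` helper are illustrative stand-ins, not the actual hf_ptq.py code:

```python
class StubArgs:
    def __init__(self, skip_generate):
        self.skip_generate = skip_generate

class StubModel:
    def generate(self, input_ids, max_new_tokens=100):
        return ["<tokens>"]

def run_previews(args, model, preview_input_ids):
    # Pre-quantization preview: --skip_generate sets the result to None early.
    if args.skip_generate:
        generated_ids_before_ptq = None
    else:
        generated_ids_before_ptq = model.generate(preview_input_ids, max_new_tokens=100)

    # ... quantization would happen here ...

    # Post-quantization preview: the existing `is None` guard skips it too,
    # because generated_ids_before_ptq was never produced.
    if generated_ids_before_ptq is None:
        generated_ids_after_ptq = None
    else:
        generated_ids_after_ptq = model.generate(preview_input_ids, max_new_tokens=100)
    return generated_ids_before_ptq, generated_ids_after_ptq

before, after = run_previews(StubArgs(skip_generate=True), StubModel(), [[1, 2, 3]])
print(before, after)  # None None
```

One early assignment thus disables both previews without touching the post-quantization branch, which is why the PR only needs the single guard.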
Summary by CodeRabbit
--skip_generate CLI option to skip pre-quantization text and image generation, reducing processing time for very large models. Useful when generation previews are computationally expensive.