Support decoder block-level sequential calibration#924
Conversation
📝 Walkthrough

Added sequential layer-by-layer calibration functionality to the quantization pipeline. Introduces a configuration flag to enable this mode, implements the calibration orchestration logic, defines sequential calibration operations, and provides utilities for layer extraction and activation collection.
Sequence Diagram

```mermaid
sequenceDiagram
    actor User
    participant Config as QuantizeAlgorithmConfig
    participant Mode as mode.py<br/>(Orchestration)
    participant ModelCalib as sequential_calibrate
    participant Collector as LayerActivationCollector
    participant Network as get_decoder_layers
    participant Model as Model
    User->>Config: Create config with<br/>use_sequential=True
    User->>Mode: Call with config
    Mode->>Mode: Check use_sequential flag
    Mode->>ModelCalib: Call sequential_calibrate()
    ModelCalib->>Network: get_decoder_layers(model)
    Network-->>ModelCalib: Return decoder layers
    loop For each layer
        ModelCalib->>Collector: Initialize collector<br/>for layer
        Collector->>Model: Patch layer forward
        Collector->>Model: Run forward pass
        Collector-->>ModelCalib: Collect layer inputs
        Collector->>Model: Unpatch layer
        ModelCalib->>ModelCalib: Call calib_func<br/>on layer inputs
    end
    ModelCalib-->>Mode: Calibration complete
    Mode-->>User: Return result
```
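The flow in the diagram can be sketched in a framework-free way. Everything below uses plain-Python stand-ins for the model, layers, and calibrator; names like `get_input_activations` mirror the PR, but the bodies are illustrative, not ModelOpt's implementation:

```python
class _EarlyStop(Exception):
    """Raised by a patched layer to abort the forward pass after collection."""

class Layer:
    def __init__(self, scale):
        self.scale = scale
    def forward(self, x):
        return x * self.scale

class Model:
    def __init__(self):
        self.layers = [Layer(2), Layer(3)]
    def forward(self, x):
        for layer in self.layers:
            x = layer.forward(x)
        return x

def get_input_activations(model, layer, forward_loop):
    """Patch `layer` to record its inputs, run the loop, then unpatch."""
    collected = []
    original = layer.forward
    def patched(x):
        collected.append(x)
        raise _EarlyStop()  # stop after collection; earlier layers already ran
    layer.forward = patched
    try:
        forward_loop(model)
    except _EarlyStop:
        pass
    finally:
        layer.forward = original
    return collected

def sequential_calibrate(model, forward_loop, calib_func):
    for layer in model.layers:
        inputs = get_input_activations(model, layer, forward_loop)
        # Per-layer forward loop replays only the collected inputs
        def _layer_forward_loop(m, _inputs=inputs):
            for x in _inputs:
                m.forward(x)
        calib_func(layer, _layer_forward_loop)

# A toy "max"-style calibrator that just records what each layer saw.
observed = {}
def record_inputs(layer, layer_forward_loop):
    seen = []
    original = layer.forward
    def spy(x):
        seen.append(x)
        return original(x)
    layer.forward = spy
    layer_forward_loop(layer)
    layer.forward = original
    observed[layer.scale] = seen

sequential_calibrate(Model(), lambda m: m.forward(5), record_inputs)
print(observed)  # {2: [5], 3: [10]}
```

Note how the second layer's recorded input (10) is the first layer's output, matching the "outputs of one layer become inputs to the next" behavior described in the config docstring.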
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/quantization/mode.py`:
- Around line 225-243: When use_sequential (sequential) is enabled, validate
that forward_loop is provided and callable before calling sequential_calibrate;
if forward_loop is None or not callable raise a clear ValueError explaining that
sequential calibration requires a callable forward_loop. Update the branch where
sequential is True (around the sequential_calibrate call in mode.py) to perform
this check and raise the explicit error instead of letting sequential_calibrate
fail later.
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 1836-1867: The sequential_calibrate function calls calib_func with
inputs as a second positional argument which collides with calibrator signatures
(causing TypeError); change the call in sequential_calibrate to pass only
forward_loop as the positional arg and supply the activations via a named
keyword (e.g., inputs=inputs) if the calibrator expects them; locate the call to
calib_func in sequential_calibrate (and the local _layer_forward_loop which uses
get_input_activations from LayerActivationCollector) and replace
calib_func(layer, inputs, forward_loop=_layer_forward_loop, **calib_kwargs) with
a keyword-argument style call (for example calib_func(layer,
forward_loop=_layer_forward_loop, inputs=inputs, **calib_kwargs)) so no
positional collision occurs, then keep the existing cleanup (del inputs;
torch.cuda.empty_cache()).
In `@modelopt/torch/quantization/utils.py`:
- Around line 816-872: The patched layer forward (_forward_w_data_collection
inside _patch_and_initialize_layer) currently only appends inputs and never
calls the original forward, so when stop_after_collection is False the layer
returns None and breaks the model; modify _forward_w_data_collection to, after
appending to self.inputs, call and return the original forward (e.g. call
self._original_forward(*args, **kwargs) if present) when stop_after_collection
is False (and retain the early raise when True), ensuring you reference
bind_forward_method/_original_forward so the original method is invoked
correctly.
In `@modelopt/torch/utils/network.py`:
- Around line 639-673: get_decoder_layers currently inspects attributes on the
passed module and misses wrapped models (DataParallel/FSDP/DeepSpeed), so first
call unwrap_model(model, force_unwrap=True) and reassign the result to model at
the start of get_decoder_layers; then proceed to check the usual attributes
(model.model.layers, model.decoder.layers, model.layers, model.transformer.h,
model.backbone.layers) on the unwrapped model to correctly locate and return the
decoder ModuleList or None.
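The suggested `get_decoder_layers` fix can be sketched as follows. `unwrap_model` and the attribute paths here are simplified stand-ins patterned on the review comment, not the actual ModelOpt helpers:

```python
def unwrap_model(model):
    # DataParallel/FSDP/DeepSpeed wrappers commonly expose the inner model
    # via a `.module` attribute; unwrap repeatedly until none remains.
    while hasattr(model, "module"):
        model = model.module
    return model

# Attribute paths mirrored from the review comment.
_CANDIDATE_PATHS = [
    ("model", "layers"),
    ("decoder", "layers"),
    ("layers",),
    ("transformer", "h"),
    ("backbone", "layers"),
]

def get_decoder_layers(model):
    model = unwrap_model(model)  # the suggested fix: probe the unwrapped model
    for path in _CANDIDATE_PATHS:
        obj = model
        for attr in path:
            obj = getattr(obj, attr, None)
            if obj is None:
                break
        if obj is not None:
            return obj
    return None

# Demo with dummy wrapper/inner objects standing in for a DDP-wrapped model.
class _Inner:
    pass

class _Wrapper:
    pass

inner = _Inner()
inner.layers = ["block0", "block1"]
wrapper = _Wrapper()
wrapper.module = inner
layers = get_decoder_layers(wrapper)
print(layers)  # ['block0', 'block1']
```

Without the unwrap step, the wrapper object has none of the candidate attributes and the lookup silently returns None.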
📒 Files selected for processing (5)
- modelopt/torch/quantization/config.py
- modelopt/torch/quantization/mode.py
- modelopt/torch/quantization/model_calib.py
- modelopt/torch/quantization/utils.py
- modelopt/torch/utils/network.py
modelopt/torch/quantization/mode.py
Outdated
```diff
 sequential = kwargs.pop("use_sequential", False)
 if method is not None and "awq" in method:
     # For backward compatibility
     kwargs["algorithm"] = method

 if func is not None:
-    # Call the function with forward_loop as a separate argument
-    func(model, forward_loop=forward_loop, **kwargs)
+    if sequential:
+        # Wrap with sequential processing
+        sequential_calibrate(
+            model,
+            forward_loop=forward_loop,
+            calib_func=func,
+            **kwargs,
+        )
+    else:
+        # Direct calibration (existing behavior)
+        func(model, forward_loop=forward_loop, **kwargs)
 else:
     raise ValueError(f"No calibration function provided for method: {method}")
```
Validate forward_loop when use_sequential is enabled.
sequential_calibrate assumes a callable forward_loop; if it's None, the error shows up later and is harder to diagnose. Add an explicit check and clear message before calling.
💡 Suggested fix

```diff
 if func is not None:
     if sequential:
+        if forward_loop is None:
+            raise ValueError("forward_loop must be provided when use_sequential=True")
         # Wrap with sequential processing
         sequential_calibrate(
             model,
             forward_loop=forward_loop,
             calib_func=func,
             **kwargs,
         )
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
sequential = kwargs.pop("use_sequential", False)
if method is not None and "awq" in method:
    # For backward compatibility
    kwargs["algorithm"] = method
if func is not None:
    if sequential:
        if forward_loop is None:
            raise ValueError("forward_loop must be provided when use_sequential=True")
        # Wrap with sequential processing
        sequential_calibrate(
            model,
            forward_loop=forward_loop,
            calib_func=func,
            **kwargs,
        )
    else:
        # Direct calibration (existing behavior)
        func(model, forward_loop=forward_loop, **kwargs)
else:
    raise ValueError(f"No calibration function provided for method: {method}")
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/quantization/mode.py` around lines 225 - 243, When
use_sequential (sequential) is enabled, validate that forward_loop is provided
and callable before calling sequential_calibrate; if forward_loop is None or not
callable raise a clear ValueError explaining that sequential calibration
requires a callable forward_loop. Update the branch where sequential is True
(around the sequential_calibrate call in mode.py) to perform this check and
raise the explicit error instead of letting sequential_calibrate fail later.
```python
class _EarlyStopForwardError(Exception):
    """Error to stop the forward pass after collection."""


class LayerActivationCollector:
    """Helper class for collecting layer activations during forward passes.

    This class allows for sequential layer calibration by
    patching layers to capture inputs/outputs during forward passes.
    """

    def __init__(self, model: nn.Module):
        self.model = model

    @staticmethod
    def _patch_and_initialize_layer(layer: torch.nn.Module, stop_after_collection: bool = False):
        """Patch a layer to collect inputs during forward passes."""

        def _forward_w_data_collection(self, *args, **kwargs):
            # Note: 'self' refers to the patched layer.
            assert len(args) >= 1, (
                f"Expected at least 1 positional arg, got {len(args)} args and {list(kwargs.keys())} kwargs"
            )
            # Only collect the inputs to the layer
            self.inputs.append((args, kwargs))
            if stop_after_collection:
                raise _EarlyStopForwardError()  # Stop the forward pass after collection

        bind_forward_method(layer, _forward_w_data_collection, "_original_forward")
        layer.inputs = []

    @staticmethod
    def _unpatch_and_cleanup_layer(layer: torch.nn.Module):
        if hasattr(layer, "_original_forward"):
            unpatch_forward_method(layer, "_original_forward")
        if hasattr(layer, "inputs"):
            del layer.inputs

    @torch.no_grad()
    def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list:
        # Wrap model forward to catch _EarlyStopForwardError per-batch
        def _early_stop_forward(self, *args, **kwargs):
            try:
                return self._original_forward(*args, **kwargs)
            except _EarlyStopForwardError:
                return None  # Stop propagation but allow next batch

        try:
            bind_forward_method(self.model, _early_stop_forward, "_original_forward")
            self._patch_and_initialize_layer(layer, stop_after_collection=True)
            forward_loop(self.model)
            inputs = layer.inputs.copy()
        finally:
            self._unpatch_and_cleanup_layer(layer)
            unpatch_forward_method(self.model, "_original_forward")

        return inputs
```
Preserve the original forward when not early-stopping.
_forward_w_data_collection never calls the original forward, so stop_after_collection=False makes the patched layer return None and breaks downstream execution. Either enforce early-stop or forward to _original_forward.
🐛 Proposed fix

```diff
 def _forward_w_data_collection(self, *args, **kwargs):
     # Note: 'self' refers to the patched layer.
     assert len(args) >= 1, (
         f"Expected at least 1 positional arg, got {len(args)} args and {list(kwargs.keys())} kwargs"
     )
     # Only collect the inputs to the layer
     self.inputs.append((args, kwargs))
     if stop_after_collection:
         raise _EarlyStopForwardError()  # Stop the forward pass after collection
+    return self._original_forward(*args, **kwargs)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/quantization/utils.py` around lines 816 - 872, The patched
layer forward (_forward_w_data_collection inside _patch_and_initialize_layer)
currently only appends inputs and never calls the original forward, so when
stop_after_collection is False the layer returns None and breaks the model;
modify _forward_w_data_collection to, after appending to self.inputs, call and
return the original forward (e.g. call self._original_forward(*args, **kwargs)
if present) when stop_after_collection is False (and retain the early raise when
True), ensuring you reference bind_forward_method/_original_forward so the
original method is invoked correctly.
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #924      +/-   ##
==========================================
- Coverage   72.03%   72.01%   -0.03%
==========================================
  Files         207      210       +3
  Lines       22718    23594     +876
==========================================
+ Hits        16365    16991     +626
- Misses       6353     6603     +250
```

☔ View full report in Codecov by Sentry.
```diff
-            # Call the function with forward_loop as a separate argument
-            func(model, forward_loop=forward_loop, **kwargs)
+            if sequential:
+                assert method in ["max"], (
```
Is this True? How can we use this for GPTQ?
This PR is just targeting the sequential calibration flow. I plan on adding gptq in this assertion in the GPTQ support PR
I think this sequential calibration will work OOTB for mse_calibrate and local_hessian_calibrate. But I can double check after this PR lands
```python
for _, layer in enumerate(transformer_layers):
    # Get updated input activations to the current layer
    inputs = gettr.get_input_activations(layer, forward_loop)
```
For a new layer - do we run forward through all past layers? We had agreed that we would use this API only for the first layer, right?
This was the feedback given in the design review meeting: to support a more generic module forward definition. For example, some model forward definitions might contain a residual connection that is not captured in the decoder list.
This is the simplest implementation; we can add caching and other optimizations to the LayerActivationCollector class in future PRs.
This is a significantly slower implementation.
To make the solution layer-agnostic, can we add plugin-based support for HF models? Let's discuss more.
In addition,

```python
transformer_layers = get_decoder_layers(model)
```

already makes `sequential_calibrate` model-dependent.
realAsma
left a comment
Can we add unittests please?
Overall: This PR introduces a useful feature for memory-constrained calibration of large models. However, there are critical correctness issues and design concerns that should be addressed before merging.

🔴 Critical Issues (Blocking)

1. Bug: Incorrect
| Issue | Severity | Status |
|---|---|---|
| calib_func call bug | 🔴 Critical | Must fix |
| _forward_w_data_collection bug | 🔴 Critical | Must fix |
| forward_loop validation | 🔴 Critical | Must fix |
| O(n²) complexity | 🟡 Design | Document or address |
| Hardcoded model detection | 🟡 Design | Consider refactoring |
| No wrapper unwrapping | 🟡 Design | Should fix |
| Method assertion | 🟡 Design | Consider config refactor |
| Code organization | 🟢 Polish | Optional |
| Exception documentation | 🟢 Polish | Optional |
Recommendation: The feature is valuable, but I would request changes for the critical bugs before merge. The O(n²) complexity should at minimum be clearly documented as a known limitation.
```python
if func is not None:
    # Call the function with forward_loop as a separate argument
    func(model, forward_loop=forward_loop, **kwargs)
    if sequential:
```
Design feedback: Consider adding an explicit validation for forward_loop when use_sequential=True. Without it, the error from sequential_calibrate is harder to diagnose:

```python
if sequential and forward_loop is None:
    raise ValueError("forward_loop must be provided when use_sequential=True")
```

This is a small addition that improves the developer experience.
```diff
@@ -1819,3 +1831,37 @@ def hessian_hook(module, input, output):
         torch.cuda.empty_cache()

+    print_rank_0("GPTQ-lite quantization completed successfully")
```
Documentation suggestion: Consider adding a note about the computational complexity in the docstring. Users should understand that this implementation runs O(n) forward passes where n is the number of layers:

```python
"""Sequential calibration - a sequential layer-by-layer calibration algorithm.

Note: This implementation runs O(n) full forward passes where n is the number of
transformer layers. This is the simplest approach that handles arbitrary model
architectures (including those with residual connections). Future optimizations
may include activation caching.

Args:
    model: Model to be calibrated (must have identifiable transformer layers).
    ...
"""
```

This sets clear expectations about the trade-off: memory efficiency vs. computation.
```diff
 )

-__all__ = ["awq", "local_hessian_calibrate", "max_calibrate", "smoothquant", "svdquant"]
+__all__ = [
```
Nit: The __all__ export now includes "sequential_calibrate" which is good since it is now part of the public API. 👍
modelopt-bot
left a comment
Design & Simplicity Review - Second Pass
Thanks for the clarifications on the PR! I acknowledge the following updates from your comments:
- ✅ calib_func call - Fixed in latest revision
- ✅ **Early-stop-forward behavior** - Intentional design to only capture inputs
- ✅ O(n²) complexity - Simplest implementation per design review; caching in followup PR
- ✅ FSDP/wrapper support - Acceptable as followup PR
Inline Comments Posted

I've left 4 inline review comments focusing on:

- `mode.py:231` - Suggest adding explicit `forward_loop` validation when `use_sequential=True` for better error messages
- `network.py:648` - Suggest adding `unwrap_model()` support (2-line fix) to enable FSDP/DDP models. Without this, users get confusing errors when models are wrapped
- `model_calib.py:1833` - Request adding docstring documentation about O(n) forward passes so users understand the memory/compute trade-off
- `model_calib.py:57` - Acknowledged the `__all__` update includes `sequential_calibrate` 👍
Overall Assessment
The design is clean and follows the existing calibration pattern. The simple O(n) approach is pragmatic given the complexity of supporting arbitrary architectures with residual connections. Approving with minor suggestions that can be addressed here or in followups based on your timeline.
LGTM with the documentation suggestion strongly recommended before merge.
```python
title="Enable sequential layer-by-layer calibration.",
description=(
    "If True, the calibration algorithm is applied sequentially to each decoder block. "
    "Outputs from one layer become inputs to the next, reducing memory usage for large models."
```
I guess we can comment about the O(n^2) complexity here.
| """Error to stop the forward pass after collection.""" | ||
|
|
||
|
|
||
| class LayerActivationCollector: |
can we make it in a separate file?
Since it's a generic helper that can help you get and cache (in the future) input activations wouldn't it make sense to have it here? Otherwise we would be creating a separate file just for this class.
I think creating a separate file for this class does not hurt and improves the readability
```python
        delattr(module, orig_forward_cache_name)


def get_decoder_layers(model: nn.Module, granularity: str = "decoder") -> nn.ModuleList | None:
```
This breaks our modular plugin abstractions. Can we have a plugin based implementation for this?
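One hypothetical shape for such a plugin abstraction (the names here are illustrative, not an actual ModelOpt API): model integrations register resolver callbacks instead of hardcoding attribute paths in one function:

```python
_DECODER_LAYER_RESOLVERS = []

def register_decoder_layer_resolver(fn):
    """Integrations register a callback that returns the decoder layers or None."""
    _DECODER_LAYER_RESOLVERS.append(fn)
    return fn

def resolve_decoder_layers(model):
    for resolver in _DECODER_LAYER_RESOLVERS:
        layers = resolver(model)
        if layers is not None:
            return layers
    return None

@register_decoder_layer_resolver
def _llama_style(model):
    # e.g. HF Llama-style models keep decoder blocks at model.model.layers
    inner = getattr(model, "model", None)
    return getattr(inner, "layers", None) if inner is not None else None

# Demo with a dummy object shaped like an HF causal LM.
class _Dummy:
    pass

lm = _Dummy()
lm.model = _Dummy()
lm.model.layers = ["layer0", "layer1"]
found = resolve_decoder_layers(lm)
print(found)  # ['layer0', 'layer1']
```

New architectures would then add a resolver in their own plugin module rather than editing `get_decoder_layers`.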
@sugunav14 @cjluo-nv I created a PR to make the sequential calib plugin/modular based - #930. Could you please take a look?

Thanks @realAsma. If it is not urgent, how about we land this PR first and then work on #930?
```python
for layer in transformer_layers:
    # Get updated input activations to the current layer
    layer_inputs = gettr.get_input_activations(layer, forward_loop)

    # Define a forward loop for the current layer
    def _layer_forward_loop(m, _inputs=layer_inputs):
        for args, kwargs_input in _inputs:
            m(*args, **kwargs_input)

    # Call calibration function
    calib_func(layer, _layer_forward_loop, **calib_kwargs)
```
Is my understanding correct that for layer n, we will rerun the previous n-1 layers? So basically there is duplicated compute?
yeah. That's fixed in #930 which will be merged after this
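A minimal sketch of the caching idea (an assumed direction for the follow-up, not the actual #930 implementation): calibrate each layer on cached activations, then propagate that layer's outputs forward as the next layer's inputs, so earlier layers are never re-run. Note this simple form does not handle cross-layer residual connections, which is why the PR's full-forward approach was chosen first:

```python
def sequential_calibrate_cached(layers, first_layer_inputs, calib_func):
    inputs = list(first_layer_inputs)  # collected once from the real model
    for layer in layers:
        calib_func(layer, inputs)  # calibrate on cached activations
        # Outputs of the (now-calibrated) layer become the next layer's inputs,
        # replacing the O(n^2) full-model reruns with one pass per layer.
        inputs = [layer(x) for x in inputs]
    return inputs

# Toy demo: two "layers" as plain functions, a calibrator that records inputs.
calibrated = []
layers = [lambda x: x + 1, lambda x: x * 2]
final = sequential_calibrate_cached(layers, [1, 2], lambda l, ins: calibrated.append(list(ins)))
print(calibrated)  # [[1, 2], [2, 3]]
print(final)       # [4, 6]
```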
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
What does this PR do?
Type of change: New feature
Overview: Add support for sequential calibration of layers (at decoder level granularity) in ModelOpt.
Calibration flow
functions added
Usage
Set use_sequential=True in QUANT_CFG's "algorithm" section.
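A hedged usage sketch based on this description; the config keys other than `use_sequential` and the `mtq.quantize` entry point are assumptions patterned on typical ModelOpt usage, not verified against this PR:

```python
# Example config dict; only `use_sequential` comes from this PR's description.
quant_cfg = {
    "quant_cfg": {"*weight_quantizer": {"num_bits": 8}},  # assumed quantizer settings
    "algorithm": {"method": "max", "use_sequential": True},
}

calib_dataloader = []  # placeholder; supply real calibration batches here

def forward_loop(model):
    # Sequential calibration requires a callable forward_loop (see review above)
    for batch in calib_dataloader:
        model(batch)

# With ModelOpt installed, calibration would then be invoked roughly as:
#   import modelopt.torch.quantization as mtq
#   model = mtq.quantize(model, quant_cfg, forward_loop=forward_loop)
print(quant_cfg["algorithm"]["use_sequential"])  # True
```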
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit