
Support decoder block-level sequential calibration #924

Open

sugunav14 wants to merge 9 commits into main from svelury/sequential-calibrate

Conversation

@sugunav14 sugunav14 commented Feb 24, 2026

What does this PR do?

Type of change: New feature

Overview: Add support for sequential calibration of layers (at decoder level granularity) in ModelOpt.

Calibration flow

  1. Get list of decoder blocks
  2. For the current block, collect its input activations (with weight and activation QDQ from all previous blocks applied), then invoke the specified calibration function on that block.

Functions added

  1. get_decoder_layers(): detects and returns the list of decoder blocks to iterate over
  2. LayerActivationCollector class: collects the input activations to a given layer
  3. sequential_calibrate(): performs the calibration flow described above
  4. use_sequential field in QuantizeAlgorithmConfig
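The flow above can be sketched with a framework-free toy (hedged: `ToyBlock` and its doubling forward are illustrative stand-ins, not the ModelOpt implementation, which operates on nn.Module decoder blocks with QDQ applied):

```python
# Toy sketch of the sequential calibration flow described above.
class ToyBlock:
    def __init__(self):
        self.seen = []  # input activations observed during calibration

    def __call__(self, x):
        return 2 * x


def get_decoder_layers(model):
    # Step 1: get the list of decoder blocks (here the model *is* the list).
    return model


def sequential_calibrate(model, samples, calib_func):
    blocks = get_decoder_layers(model)
    for i, layer in enumerate(blocks):
        # Step 2: recompute this layer's inputs by running all previous
        # blocks, then hand them to the chosen calibration function.
        inputs = []
        for x in samples:
            for prev in blocks[:i]:
                x = prev(x)
            inputs.append(x)
        calib_func(layer, inputs)


model = [ToyBlock(), ToyBlock(), ToyBlock()]
sequential_calibrate(model, [1, 3], lambda layer, inp: layer.seen.extend(inp))
print(model[2].seen)  # [4, 12]: each sample doubled by the two earlier blocks
```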

Usage

# Sample config
NVFP4_DEFAULT_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": {
        "method": "max",
        "use_sequential": True,
    },
}

Set use_sequential=True in QUANT_CFG's "algorithm" section.

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features
    • Sequential layer-by-layer calibration: Quantization now supports processing decoder layers sequentially to improve memory efficiency on large models.

copy-pr-bot bot commented Feb 24, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

coderabbitai bot commented Feb 24, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

Added sequential layer-by-layer calibration functionality to quantization pipeline. Introduces a configuration flag to enable this mode, implements the calibration orchestration logic, defines sequential calibration operations, and provides utilities for layer extraction and activation collection.

Changes

Cohort / File(s) Summary
Configuration & Orchestration
modelopt/torch/quantization/config.py, modelopt/torch/quantization/mode.py
Added use_sequential boolean field to QuantizeAlgorithmConfig. Updated mode.py to conditionally route to sequential_calibrate when flag is enabled, with backward compatibility for existing direct function calls.
Sequential Calibration Implementation
modelopt/torch/quantization/model_calib.py
Implemented sequential_calibrate() function that performs layer-by-layer calibration on transformer decoder layers by collecting per-layer activations and invoking calibration logic per layer.
Activation & Layer Utilities
modelopt/torch/quantization/utils.py, modelopt/torch/utils/network.py
Added LayerActivationCollector class to capture layer inputs via forward patching, introduced _EarlyStopForwardError exception for control flow, and implemented get_decoder_layers() utility to extract decoder layers from various model architectures.

Sequence Diagram

sequenceDiagram
    actor User
    participant Config as QuantizeAlgorithmConfig
    participant Mode as mode.py<br/>(Orchestration)
    participant ModelCalib as sequential_calibrate
    participant Collector as LayerActivationCollector
    participant Network as get_decoder_layers
    participant Model as Model

    User->>Config: Create config with<br/>use_sequential=True
    User->>Mode: Call with config
    Mode->>Mode: Check use_sequential flag
    Mode->>ModelCalib: Call sequential_calibrate()
    ModelCalib->>Network: get_decoder_layers(model)
    Network-->>ModelCalib: Return decoder layers
    loop For each layer
        ModelCalib->>Collector: Initialize collector<br/>for layer
        Collector->>Model: Patch layer forward
        Collector->>Model: Run forward pass
        Collector-->>ModelCalib: Collect layer inputs
        Collector->>Model: Unpatch layer
        ModelCalib->>ModelCalib: Call calib_func<br/>on layer inputs
    end
    ModelCalib-->>Mode: Calibration complete
    Mode-->>User: Return result

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 57.14%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Description Check ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title accurately describes the main feature introduced: sequential calibration at the decoder block level. It is concise, specific, and directly reflects the primary changes across multiple files.


@sugunav14 sugunav14 marked this pull request as ready for review February 24, 2026 01:49
@sugunav14 sugunav14 requested review from a team as code owners February 24, 2026 01:49

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/quantization/mode.py`:
- Around line 225-243: When use_sequential (sequential) is enabled, validate
that forward_loop is provided and callable before calling sequential_calibrate;
if forward_loop is None or not callable raise a clear ValueError explaining that
sequential calibration requires a callable forward_loop. Update the branch where
sequential is True (around the sequential_calibrate call in mode.py) to perform
this check and raise the explicit error instead of letting sequential_calibrate
fail later.

In `@modelopt/torch/quantization/model_calib.py`:
- Around line 1836-1867: The sequential_calibrate function calls calib_func with
inputs as a second positional argument which collides with calibrator signatures
(causing TypeError); change the call in sequential_calibrate to pass only
forward_loop as the positional arg and supply the activations via a named
keyword (e.g., inputs=inputs) if the calibrator expects them; locate the call to
calib_func in sequential_calibrate (and the local _layer_forward_loop which uses
get_input_activations from LayerActivationCollector) and replace
calib_func(layer, inputs, forward_loop=_layer_forward_loop, **calib_kwargs) with
a keyword-argument style call (for example calib_func(layer,
forward_loop=_layer_forward_loop, inputs=inputs, **calib_kwargs)) so no
positional collision occurs, then keep the existing cleanup (del inputs;
torch.cuda.empty_cache()).

In `@modelopt/torch/quantization/utils.py`:
- Around line 816-872: The patched layer forward (_forward_w_data_collection
inside _patch_and_initialize_layer) currently only appends inputs and never
calls the original forward, so when stop_after_collection is False the layer
returns None and breaks the model; modify _forward_w_data_collection to, after
appending to self.inputs, call and return the original forward (e.g. call
self._original_forward(*args, **kwargs) if present) when stop_after_collection
is False (and retain the early raise when True), ensuring you reference
bind_forward_method/_original_forward so the original method is invoked
correctly.

In `@modelopt/torch/utils/network.py`:
- Around line 639-673: get_decoder_layers currently inspects attributes on the
passed module and misses wrapped models (DataParallel/FSDP/DeepSpeed), so first
call unwrap_model(model, force_unwrap=True) and reassign the result to model at
the start of get_decoder_layers; then proceed to check the usual attributes
(model.model.layers, model.decoder.layers, model.layers, model.transformer.h,
model.backbone.layers) on the unwrapped model to correctly locate and return the
decoder ModuleList or None.

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 52e662d and a938963.

📒 Files selected for processing (5)
  • modelopt/torch/quantization/config.py
  • modelopt/torch/quantization/mode.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/utils.py
  • modelopt/torch/utils/network.py

Comment on lines 225 to 243
sequential = kwargs.pop("use_sequential", False)
if method is not None and "awq" in method:
    # For backward compatibility
    kwargs["algorithm"] = method

if func is not None:
    if sequential:
        # Wrap with sequential processing
        sequential_calibrate(
            model,
            forward_loop=forward_loop,
            calib_func=func,
            **kwargs,
        )
    else:
        # Direct calibration (existing behavior)
        func(model, forward_loop=forward_loop, **kwargs)
else:
    raise ValueError(f"No calibration function provided for method: {method}")

⚠️ Potential issue | 🟡 Minor

Validate forward_loop when use_sequential is enabled.

sequential_calibrate assumes a callable forward_loop; if it's None, the error shows up later and is harder to diagnose. Add an explicit check and clear message before calling.

💡 Suggested fix
     if func is not None:
         if sequential:
+            if forward_loop is None:
+                raise ValueError("forward_loop must be provided when use_sequential=True")
             # Wrap with sequential processing
             sequential_calibrate(
                 model,
                 forward_loop=forward_loop,
                 calib_func=func,
                 **kwargs,
             )

Comment on lines 816 to 872
class _EarlyStopForwardError(Exception):
    """Error to stop the forward pass after collection."""


class LayerActivationCollector:
    """Helper class for collecting layer activations during forward passes.

    This class allows for sequential layer calibration by
    patching layers to capture inputs/outputs during forward passes
    """

    def __init__(self, model: nn.Module):
        self.model = model

    @staticmethod
    def _patch_and_initialize_layer(layer: torch.nn.Module, stop_after_collection: bool = False):
        """Patch a layer to collect inputs during forward passes."""

        def _forward_w_data_collection(self, *args, **kwargs):
            # Note: 'self' refers to the patched layer.
            assert len(args) >= 1, (
                f"Expected at least 1 positional arg, got {len(args)} args and {list(kwargs.keys())} kwargs"
            )
            # Only collect the inputs to the layer
            self.inputs.append((args, kwargs))
            if stop_after_collection:
                raise _EarlyStopForwardError()  # Stop the forward pass after collection

        bind_forward_method(layer, _forward_w_data_collection, "_original_forward")
        layer.inputs = []

    @staticmethod
    def _unpatch_and_cleanup_layer(layer: torch.nn.Module):
        if hasattr(layer, "_original_forward"):
            unpatch_forward_method(layer, "_original_forward")
        if hasattr(layer, "inputs"):
            del layer.inputs

    @torch.no_grad()
    def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list:
        # Wrap model forward to catch _EarlyStopForwardError per-batch
        def _early_stop_forward(self, *args, **kwargs):
            try:
                return self._original_forward(*args, **kwargs)
            except _EarlyStopForwardError:
                return None  # Stop propagation but allow next batch

        try:
            bind_forward_method(self.model, _early_stop_forward, "_original_forward")
            self._patch_and_initialize_layer(layer, stop_after_collection=True)
            forward_loop(self.model)
            inputs = layer.inputs.copy()
        finally:
            self._unpatch_and_cleanup_layer(layer)
            unpatch_forward_method(self.model, "_original_forward")

        return inputs

⚠️ Potential issue | 🟡 Minor

Preserve the original forward when not early-stopping.

_forward_w_data_collection never calls the original forward, so stop_after_collection=False makes the patched layer return None and breaks downstream execution. Either enforce early-stop or forward to _original_forward.

🐛 Proposed fix
         def _forward_w_data_collection(self, *args, **kwargs):
             # Note: 'self' refers to the patched layer.
             assert len(args) >= 1, (
                 f"Expected at least 1 positional arg, got {len(args)} args and {list(kwargs.keys())} kwargs"
             )
             # Only collect the inputs to the layer
             self.inputs.append((args, kwargs))
             if stop_after_collection:
                 raise _EarlyStopForwardError()  # Stop the forward pass after collection
+            return self._original_forward(*args, **kwargs)
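The patch-and-early-stop mechanism under discussion can be illustrated with a plain-Python toy (hedged: `Layer`, `Model`, and `collect_inputs` are illustrative stand-ins for the nn.Module patching done by LayerActivationCollector):

```python
# Toy illustration of the early-stop collection pattern: patch a layer's
# forward to record inputs and abort the rest of the pass, then catch the
# abort at the model level so the next batch can still run.
class _EarlyStop(Exception):
    pass


class Layer:
    def forward(self, x):
        return x + 1


class Model:
    def __init__(self):
        self.layers = [Layer(), Layer(), Layer()]

    def forward(self, x):
        for layer in self.layers:
            x = layer.forward(x)
        return x


def collect_inputs(model, target, batches):
    captured = []
    original = target.forward

    def patched(x):
        captured.append(x)
        raise _EarlyStop()  # skip everything after the target layer

    target.forward = patched
    try:
        for b in batches:
            try:
                model.forward(b)
            except _EarlyStop:
                pass  # stop propagation but allow the next batch
    finally:
        target.forward = original  # always restore the original forward
    return captured


m = Model()
inputs = collect_inputs(m, m.layers[2], [0, 10])
print(inputs)  # [2, 12]: each batch passed through the first two layers (+1 each)
```

Restoring `target.forward` in a `finally` block mirrors the try/finally cleanup in `get_input_activations` above.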


codecov bot commented Feb 24, 2026

Codecov Report

❌ Patch coverage is 26.66667% with 55 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.01%. Comparing base (a6cbcba) to head (1f1baae).
⚠️ Report is 3 commits behind head on main.

Files with missing lines (patch % / lines missing):
  • modelopt/torch/quantization/utils.py: 27.77% (26 missing) ⚠️
  • modelopt/torch/quantization/model_calib.py: 27.77% (13 missing) ⚠️
  • modelopt/torch/utils/network.py: 7.14% (13 missing) ⚠️
  • modelopt/torch/quantization/mode.py: 50.00% (3 missing) ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #924      +/-   ##
==========================================
- Coverage   72.03%   72.01%   -0.03%     
==========================================
  Files         207      210       +3     
  Lines       22718    23594     +876     
==========================================
+ Hits        16365    16991     +626     
- Misses       6353     6603     +250     

if sequential:
    assert method in ["max"], (
        f"Sequential calibration currently only supports max calibration, got {method}"
    )
Is this True? How can we use this for GPTQ?

Contributor (Author):

This PR is just targeting the sequential calibration flow. I plan on adding gptq in this assertion in the GPTQ support PR

Contributor:
I think this sequential calibration will work OOTB for mse_calibrate and local_hessian_calibrate. But I can double check after this PR lands


for _, layer in enumerate(transformer_layers):
    # Get updated input activations to the current layer
    inputs = gettr.get_input_activations(layer, forward_loop)
Contributor:

For a new layer - do we run forward though all past layers? We had agreed that we will use this API only for the first layer right?

Contributor (Author):

This was the feedback given in the design review meeting, to support a more generic module forward definition. For example some model forward definitions might contain a residual connection that is not captured in the decoder list.

This is the simplest implementation, we can add caching and other optimizations the LayerActivationCollection class in future PRs.

Contributor:

This is a significantly slower implementation.

To make the solution layer agnostic - can we add a plugin based support for HF models? Lets discuss more.

Contributor:

In addition,

transformer_layers = get_decoder_layers(model)

Already makes sequential calibrate model dependent.

@realAsma (Contributor) left a comment:

Can we add unittests please?

@cjluo-nv (Collaborator) commented:

Overall: This PR introduces a useful feature for memory-constrained calibration of large models. However, there are critical correctness issues and design concerns that should be addressed before merging.


🔴 Critical Issues (Blocking)

1. Bug: Incorrect calib_func Call in sequential_calibrate

Location: modelopt/torch/quantization/model_calib.py:1859

# Current (broken):
calib_func(layer, inputs, forward_loop=_layer_forward_loop, **calib_kwargs)

Problem: calib_func (e.g., max_calibrate) expects forward_loop as the second positional argument, but inputs is being passed. This will raise TypeError when max_calibrate tries to call inputs(model).

Fix:

# The _layer_forward_loop already iterates over inputs internally
calib_func(layer, forward_loop=_layer_forward_loop, **calib_kwargs)

2. Bug: _forward_w_data_collection Never Calls Original Forward

Location: modelopt/torch/quantization/utils.py:847-856

def _forward_w_data_collection(self, *args, **kwargs):
    assert len(args) >= 1, ...
    self.inputs.append((args, kwargs))
    if stop_after_collection:
        raise _EarlyStopForwardError()
    # Missing: return self._original_forward(*args, **kwargs) when not stopping!

Problem: When stop_after_collection=False, the function returns None instead of the actual forward result. This breaks normal forward execution.

Fix:

if stop_after_collection:
    raise _EarlyStopForwardError()
return self._original_forward(*args, **kwargs)

3. Missing Validation for forward_loop

Location: modelopt/torch/quantization/mode.py:225-243

When use_sequential=True, forward_loop is required but not validated upfront.

Suggested addition:

if sequential and forward_loop is None:
    raise ValueError("forward_loop must be provided when use_sequential=True")

🟡 Design Concerns (Should be addressed)

4. O(n²) Complexity in Input Collection

Location: modelopt/torch/quantization/model_calib.py:1851-1855

for _, layer in enumerate(transformer_layers):
    inputs = gettr.get_input_activations(layer, forward_loop)  # Runs full model forward!
    ...

Problem: get_input_activations runs the full model forward pass for each layer. For a 100-layer model, this is 100 full forward passes.

Reviewer's Note: The comment exchange shows realAsma raised this concern, and the author acknowledged it's slower but simpler. While caching can be added later, this is a significant performance regression that should be clearly documented or ideally addressed now.

Minimal improvement: Add a note in the docstring:

"""Sequential calibration ... 
Note: This implementation runs O(n) forward passes where n is the number of layers.
Future optimizations may include activation caching.
"""

5. Hardcoded Model Architecture Detection

Location: modelopt/torch/utils/network.py:637-665

def get_decoder_layers(model, granularity="decoder"):
    # HuggingFace transformers pattern: model.model.layers
    if hasattr(model, "model") and hasattr(model.model, "layers"):
        return model.model.layers
    # Megatron/MCore pattern: model.decoder.layers
    if hasattr(model, "decoder") and hasattr(model.decoder, "layers"):
        return model.decoder.layers
    # ... more hardcoded patterns

Problem: This is brittle and model-specific. What about:

  • Models with different attribute names?
  • Custom transformer architectures?
  • Wrapped models (FSDP/DDP) — see point 6

Suggestion: Consider making this configurable or accepting a callable layer_selector instead of hardcoding patterns.


6. No Unwrapping of Distributed Wrappers

Location: modelopt/torch/utils/network.py:637+

get_decoder_layers doesn't unwrap DataParallel, DistributedDataParallel, FSDP, or DeepSpeed wrappers before inspecting attributes.

Fix:

from .network import unwrap_model  # Already available

def get_decoder_layers(model, granularity="decoder"):
    model = unwrap_model(model, force_unwrap=True)  # Add this
    # ... rest of patterns


7. Restrictive Method Assertion

Location: modelopt/torch/quantization/mode.py:231-235

if sequential:
    assert method in ["max"], (
        f"Sequential calibration currently only supports max calibration, got {method}"
    )

Problem: This runtime restriction isn't reflected in the type system or config validation. Users will only discover this at runtime.

Suggestion: Consider making use_sequential a config field only on MaxCalibConfig initially, or add a validator to QuantizeAlgorithmConfig that checks compatibility.


🟢 Simplicity Improvements

8. Code Organization: Separate Sequential Logic

The wrapped_calib_func now has two distinct code paths with different calling conventions:

if sequential:
    # Path 1: wraps calib_func with sequential_calibrate
    sequential_calibrate(model, forward_loop=forward_loop, calib_func=func, **kwargs)
else:
    # Path 2: direct call
    func(model, forward_loop=forward_loop, **kwargs)

Problem: This adds branching complexity to an already complex wrapper.

Suggestion: Consider if sequential_calibrate should be a standalone calibration method (registered in CalibrateModeRegistry) rather than a wrapper. This would follow the existing pattern for other calibration algorithms.


9. Exception-Based Flow Control

Location: modelopt/torch/quantization/utils.py:816

Using _EarlyStopForwardError to stop forward propagation is clever but non-obvious to readers.

Suggestion: Add a prominent comment explaining this mechanism:

class _EarlyStopForwardError(Exception):
    """Custom exception to enable stopping forward pass after collecting activations.

    This is raised after capturing layer inputs to avoid unnecessary computation.
    """


10. Missing __all__ Update

Location: modelopt/torch/quantization/model_calib.py:54

sequential_calibrate is imported but not added to __all__:

__all__ = ["awq", "local_hessian_calibrate", "max_calibrate", "smoothquant", "svdquant"]

Missing: "sequential_calibrate"


Summary

Issue | Severity | Status
  • calib_func call bug | 🔴 Critical | Must fix
  • _forward_w_data_collection bug | 🔴 Critical | Must fix
  • forward_loop validation | 🔴 Critical | Must fix
  • O(n²) complexity | 🟡 Design | Document or address
  • Hardcoded model detection | 🟡 Design | Consider refactoring
  • No wrapper unwrapping | 🟡 Design | Should fix
  • Method assertion | 🟡 Design | Consider config refactor
  • Code organization | 🟢 Polish | Optional
  • Exception documentation | 🟢 Polish | Optional

Recommendation: The feature is valuable, but I would request changes for the critical bugs before merge. The O(n²) complexity should at minimum be clearly documented as a known limitation.

if func is not None:
    if sequential:


Design feedback: Consider adding an explicit validation for forward_loop when use_sequential=True. Without it, the error from sequential_calibrate is harder to diagnose:

if sequential and forward_loop is None:
    raise ValueError("forward_loop must be provided when use_sequential=True")

This is a small addition that improves the developer experience.

@@ -1819,3 +1831,37 @@ def hessian_hook(module, input, output):
torch.cuda.empty_cache()

print_rank_0("GPTQ-lite quantization completed successfully")


Documentation suggestion: Consider adding a note about the computational complexity in the docstring. Users should understand that this implementation runs O(n) forward passes where n is the number of layers:

"""Sequential calibration - a sequential layer-by-layer calibration algorithm.

Note: This implementation runs O(n) full forward passes where n is the number of 
transformer layers. This is the simplest approach that handles arbitrary model 
architectures (including those with residual connections). Future optimizations 
may include activation caching.

Args:
    model: Model to be calibrated (must have identifiable transformer layers).
    ...
"""

This sets clear expectations about the trade-off: memory efficiency vs computation.

)

__all__ = ["awq", "local_hessian_calibrate", "max_calibrate", "smoothquant", "svdquant"]
__all__ = [


Nit: The __all__ export now includes "sequential_calibrate" which is good since it is now part of the public API. 👍


@modelopt-bot modelopt-bot left a comment


Design & Simplicity Review - Second Pass

Thanks for the clarifications on the PR! I acknowledge the following updates from your comments:

  1. calib_func call - Fixed in latest revision
  2. ✅ Early-stop-forward behavior - Intentional design to only capture inputs
  3. O(n²) complexity - Simplest implementation per design review; caching in followup PR
  4. FSDP/wrapper support - Acceptable as followup PR

Inline Comments Posted

I've left 4 inline review comments focusing on:

  1. mode.py:231 - Suggest adding explicit forward_loop validation when use_sequential=True for better error messages

  2. network.py:648 - Suggest adding unwrap_model() support (2-line fix) to enable FSDP/DDP models. Without this, users get confusing errors when models are wrapped

  3. model_calib.py:1833 - Request adding docstring documentation about O(n) forward passes so users understand the memory/compute trade-off

  4. model_calib.py:57 - Acknowledged the __all__ update includes sequential_calibrate 👍

Overall Assessment

The design is clean and follows the existing calibration pattern. The simple O(n) approach is pragmatic given the complexity of supporting arbitrary architectures with residual connections. Approving with minor suggestions that can be addressed here or in followups based on your timeline.

LGTM with the documentation suggestion strongly recommended before merge.

title="Enable sequential layer-by-layer calibration.",
description=(
"If True, the calibration algorithm is applied sequentially to each decoder block. "
"Outputs from one layer become inputs to the next, reducing memory usage for large models."
Collaborator:

I guess we can comment about the O(n²) complexity here.

"""Error to stop the forward pass after collection."""


class LayerActivationCollector:
Collaborator:

can we make it in a separate file?

Contributor (Author):

Since it's a generic helper that can help you get and cache (in the future) input activations wouldn't it make sense to have it here? Otherwise we would be creating a separate file just for this class.

Collaborator:

I think creating a separate file for this class does not hurt and improves the readability

delattr(module, orig_forward_cache_name)


def get_decoder_layers(model: nn.Module, granularity: str = "decoder") -> nn.ModuleList | None:
Contributor:

This breaks our modular plugin abstractions. Can we have a plugin based implementation for this?

Contributor (Author):

Related to #930

@realAsma (Contributor):

@sugunav14 @cjluo-nv I created a PR to make the sequential calib plugin/modular based - #930 could you please take a look?

@cjluo-nv (Collaborator):

@sugunav14 @cjluo-nv I created a PR to make the sequential calib plugin/modular based - #930 could you please take a look?

Thanks @realAsma . If it is not urgent, how about we land this PR first and then work on #930 ?

@sugunav14 sugunav14 force-pushed the svelury/sequential-calibrate branch from a9043cb to af2d97d Compare February 26, 2026 20:05
Comment on lines +1858 to +1868
for layer in transformer_layers:
    # Get updated input activations to the current layer
    layer_inputs = gettr.get_input_activations(layer, forward_loop)

    # Define a forward loop for the current layer
    def _layer_forward_loop(m, _inputs=layer_inputs):
        for args, kwargs_input in _inputs:
            m(*args, **kwargs_input)

    # Call calibration function
    calib_func(layer, _layer_forward_loop, **calib_kwargs)
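A side note on the `_inputs=layer_inputs` default argument in the quoted loop: it snapshots each iteration's activations at function-definition time. A minimal demonstration of the late-binding pitfall this avoids:

```python
# Python closures capture variables by reference, so a closure defined in a
# loop sees the variable's *final* value; a default argument freezes the
# value at definition time instead.
funcs_late, funcs_bound = [], []
for v in [1, 2, 3]:
    funcs_late.append(lambda: v)         # late binding: all read v == 3
    funcs_bound.append(lambda _v=v: _v)  # default arg snapshots v now

print([f() for f in funcs_late])   # [3, 3, 3]
print([f() for f in funcs_bound])  # [1, 2, 3]
```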
Contributor:

Is my understanding correct that for layer n, we will rerun n-1 layers? so basically there are duplicated compute?

Contributor (Author):

yeah. That's fixed in #930 which will be merged after this

Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
@sugunav14 sugunav14 force-pushed the svelury/sequential-calibrate branch from af2d97d to 63421b6 Compare February 26, 2026 21:39