
Commit e567e54

Hsu, Mu-Chien (quic-muchhsu) authored and committed
Remove omniquant config (#4816)
* save trained LET scale to _cached_prev/foll_scale when folding
* fix forward function
* add forward function to API input
* remove omniquant config
* fix bug
* remove lr and input sym
* hardcode input_symmetry

Signed-off-by: Mu-Chien Hsu <quic_muchhsu@quicinc.com>
Co-authored-by: Mu-Chien Hsu <quic_muchhsu@quicinc.com>
1 parent 7236fb2 commit e567e54
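In practical terms, the public entry point no longer takes an OmniquantConfig: the epoch count is passed directly and the number of batches is derived from len(dataloader). Below is a rough usage sketch of the call after this change; `model`, `dataloader`, and the QuantizationSimModel setup are illustrative placeholders (not part of this commit), and the num_epoch value is arbitrary.

# Sketch only: `quant_sim`, `model`, and `dataloader` are assumed to exist already.
from aimet_torch.experimental.omniquant import Omniquant

# forward_fn must accept (model, inputs) and run a forward pass, per the docstring below.
forward_fn = lambda model, inputs: model(*inputs)

optimized_model = Omniquant.apply_omniquant(
    quant_sim=quant_sim,      # QuantizationSimModel created from `model`
    model=model,              # original fp32 model
    dataloader=dataloader,    # num_batch is now taken as len(dataloader)
    forward_fn=forward_fn,
    num_epoch=20,             # illustrative; previously OmniquantConfig.num_epoch
    output_path="./aimet_omniquant_artifact/",
)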

File tree

3 files changed (+30, -87 lines)


TrainingExtensions/torch/src/python/aimet_torch/experimental/omniquant/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -37,4 +37,3 @@
 
 # pylint: disable=missing-docstring
 from .omniquant_optimizer import Omniquant
-from .omniquant_config import OmniquantConfig
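With the re-export above removed, only the optimizer remains importable from this experimental package; a two-line sketch of the effect (assuming the package is installed):

from aimet_torch.experimental.omniquant import Omniquant        # still works after this commit
# from aimet_torch.experimental.omniquant import OmniquantConfig  # would now raise ImportError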

TrainingExtensions/torch/src/python/aimet_torch/experimental/omniquant/omniquant_config.py

Lines changed: 0 additions & 48 deletions
This file was deleted.
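The deleted file itself is collapsed on this page. For orientation only, here is a hypothetical sketch of the kind of fields the removed OmniquantConfig carried, inferred from the omniquant_config.* accesses deleted in omniquant_optimizer.py below; the dataclass form and the default values are assumptions, not the actual deleted code.

# Hypothetical reconstruction -- field names inferred from omniquant_config.*
# references removed below; defaults and the dataclass form are assumptions.
from dataclasses import dataclass

@dataclass
class OmniquantConfig:
    num_batch: int                 # batches cached from the dataloader (now len(dataloader))
    num_epoch: int                 # blockwise training epochs (now an apply_omniquant argument)
    omq_lr: float = 5e-4           # learning rate (now the OMNIQUANT_LR constant)
    input_symmetry: str = "fpfp"   # one of "qtqt", "qtfp", "fpqt", "fpfp" (now hardcoded)
    compute_sqnr: bool = True      # now the OMNIQUANT_COMPUTE_SQNR constant
    cache_on_cpu: bool = True      # now the CACHE_ON_CPU constant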

TrainingExtensions/torch/src/python/aimet_torch/experimental/omniquant/omniquant_optimizer.py

Lines changed: 30 additions & 38 deletions
@@ -56,7 +56,6 @@
 from aimet_torch.v2.nn import compute_param_encodings
 from aimet_common.utils import AimetLogger
 from .decoder_processor import get_transformer_processor
-from .omniquant_config import OmniquantConfig
 from .let_modules import LETModule
 from ._utils import (
     _convert_sim_to_letsim,
@@ -70,30 +69,35 @@
 
 OMNIQUANT_ARTIFACT_DIR = "./aimet_omniquant_artifact/"
 OMNIQUANT_METADATA_SAFETENSOR_NAME = "aimet_omniquant_metadata.safetensor"
-
+OMNIQUANT_COMPUTE_SQNR = True
+OMNIQUANT_LR = 5e-4 # 1st fp/qt to choose input source on fp block to get ground truth. 2nd fp/qt to choose input source on qt block to get prediction.
+CACHE_ON_CPU = True # Will be removed after using blockwise sampler.
 
 _logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
 
 class Omniquant:
     """
     Omniquant for Post Training Quantization (PTQ)
     """
+    # pylint: disable=too-many-arguments
     @classmethod
-    def apply_omniquant(cls, quant_sim: QuantizationSimModel, model: torch.nn.Module, omniquant_config: OmniquantConfig, dataloader,
-                        forward_fn: Callable, output_path: str = OMNIQUANT_ARTIFACT_DIR) -> torch.nn.Module:
+    def apply_omniquant(cls, quant_sim: QuantizationSimModel, model: torch.nn.Module, dataloader,
+                        forward_fn: Callable, num_epoch: int, output_path: str = OMNIQUANT_ARTIFACT_DIR) -> torch.nn.Module:
         """
         Returns model with with omniquant weight, and save metadata in safetensor format to output path. Metadata safetensor
         can be used in update_lora_weights to update lora adaptor weights for peft lora model.
 
         :param quant_sim: QuantizationSimModel object to optimize with Omniquant.
         :param model: Original fp32 model from which quant_sim was created.
-        :param omniquant_config: Configuration for Omniquant optimization.
         :param dataloader: Dataloader used to train model.
         :param forward_fn: Model forward function used to cache intermediate data.
                            Expect to have model and inputs as function argument. e.g. lambda model, inputs: model(*inputs)
+        :param num_epoch: Epochs to train each block with omniquant.
         :param output_path: Path to save {layer_name: scale} metadata safetensor.
         :return: Model with Omniquant weights.
         """
+        num_batch = len(dataloader)
+
         @contextlib.contextmanager
         def disable_dynamic_cache():
             # Disable dynamic_cache for LET blockwise training, and restore after optimization.
@@ -105,11 +109,10 @@ def disable_dynamic_cache():
             quant_sim.model.config.use_cache, model.config.use_cache = quant_sim_use_cache_bool, model_use_cache_bool
         output_path = Path(output_path)
         os.makedirs(output_path, exist_ok=True)
-        cls._validate_omniquant_config(omniquant_config)
-        _logger.info(omniquant_config)
+
         start_omq_optmztn_time = time.perf_counter()
         with disable_dynamic_cache():
-            cls._apply_omniquant(quant_sim, model, omniquant_config, dataloader, forward_fn, output_path)
+            cls._apply_omniquant(quant_sim, model, dataloader, forward_fn, num_epoch, num_batch, output_path)
         total_omq_optmztn_time= time.perf_counter() - start_omq_optmztn_time
         _logger.info("Took %.4f seconds for omq optimization ", total_omq_optmztn_time)
 
@@ -119,17 +122,18 @@ def disable_dynamic_cache():
     # pylint: disable=too-many-arguments
     # pylint: disable=too-many-statements
     @classmethod
-    def _apply_omniquant(cls, quant_sim: QuantizationSimModel, model: torch.nn.Module, omniquant_config: OmniquantConfig,
-                         dataloader, forward_fn, output_path: str) -> torch.nn.Module:
+    def _apply_omniquant(cls, quant_sim: QuantizationSimModel, model: torch.nn.Module, dataloader, forward_fn,
+                         num_epoch: int, num_batch: int, output_path: str) -> torch.nn.Module:
         """
         Implemenatation to run omniquant optimization block by block. Return model with optimized weights.
 
         :param quant_sim: QuantizationSimModel object to optimize with Omniquant.
         :param model: Original fp32 model from which quant_sim was created.
-        :param omniquant_config: Configuration for Omniquant optimization.
         :param dataloader: Dataloader used to train model.
         :param forward_fn: Model forward function used to cache intermediate data.
                            Expect to have model and inputs as function argument. e.g. lambda model, inputs: model(*inputs)
+        :param num_epoch: Epochs to train each block with omniquant.
+        :param num_batch: Number of batches in dataloader.
         :param output_path: Path where to store artifacts.
         :return: Model with Omniquant weights.
         """
@@ -148,11 +152,11 @@ def _apply_omniquant(cls, quant_sim: QuantizationSimModel, model: torch.nn.Modul
         num_repeats = model.config.num_attention_heads//model.config.num_key_value_heads
         with tempfile.TemporaryDirectory() as tempdir:
             cached_dir = os.path.join(tempdir, 'cached_dataset')
-            cached_dataset = CachedDataset(dataloader, omniquant_config.num_batch, cached_dir)
+            cached_dataset = CachedDataset(dataloader, num_batch, cached_dir)
 
             cached_fp_dataset, cached_quant_dataset = get_block_inputs(
-                model, quant_sim, ".".join([transformer_processor.transformer_block_list_path, "0"]), cached_dataset, omniquant_config.cache_on_cpu,
-                forward_fn, omniquant_config.num_batch, cached_dir, incl_kwargs=True
+                model, quant_sim, ".".join([transformer_processor.transformer_block_list_path, "0"]), cached_dataset, CACHE_ON_CPU,
+                forward_fn, num_batch, cached_dir, incl_kwargs=True
             )
             for block_num, (fp_block, qt_block) in enumerate(zip(fp_transformer_block_list, qt_transformer_block_list)):
                 qt_let_pair_list = transformer_processor.get_let_module_pair(qt_block)
@@ -174,26 +178,26 @@ def set_qt_params_trainable(qt_block):
                 encoding_params, param_names = cls._get_trainable_params(qt_block)
                 let_params = cls._get_let_params(qt_let_pair_list)
                 grouped_params = [
-                    {"params": encoding_params, "lr": omniquant_config.omq_lr, "weight_decay": 0.},
-                    {"params": let_params, "lr": omniquant_config.omq_lr, "weight_decay": 0.},
+                    {"params": encoding_params, "lr": OMNIQUANT_LR, "weight_decay": 0.},
+                    {"params": let_params, "lr": OMNIQUANT_LR, "weight_decay": 0.},
                 ]
 
                 optimizer = torch.optim.AdamW(grouped_params)
                 loss_fn = torch.nn.MSELoss(reduction="sum")
 
                 _logger.info("Starting blockwise training for params")
-                for epoch in tqdm(range(omniquant_config.num_epoch)):
+                for epoch in tqdm(range(num_epoch)):
                     sqnr_list = []
                     loss_list = []
-                    for batch_num in range(omniquant_config.num_batch):
+                    for batch_num in range(num_batch):
                         fp_input, qt_input = cached_fp_dataset[batch_num], cached_quant_dataset[batch_num]
                         # Do block-wise training.
-                        loss, sqnr = cls._block_wise_training_step(omniquant_config, fp_input, qt_input, fp_block, qt_block, qt_let_pair_list, optimizer, loss_fn, omniquant_config.compute_sqnr, model.device)
+                        loss, sqnr = cls._block_wise_training_step(fp_input, qt_input, fp_block, qt_block, qt_let_pair_list, optimizer, loss_fn, model.device)
                         sqnr_list += [sqnr]
                         loss_list += [loss]
                     loss_mean = torch.stack(loss_list).mean()
                     log_msg = f"layer {block_num} epoch {epoch} | loss: {loss_mean:.3f}"
-                    if omniquant_config.compute_sqnr:
+                    if OMNIQUANT_COMPUTE_SQNR:
                         sqnr_mean = torch.stack(sqnr_list).mean()
                         log_msg += f"{log_msg} | sqnr: {sqnr_mean:.3f}"
 
@@ -208,7 +212,7 @@ def set_qt_params_trainable(qt_block):
                 freeze_let_optimized_param_quantizers(qt_block)
                 # TODO if should call compute_param_encodings after blockwise training
                 get_block_outputs(
-                    fp_block, qt_block, False, cached_fp_dataset, cached_quant_dataset, omniquant_config.cache_on_cpu,
+                    fp_block, qt_block, False, cached_fp_dataset, cached_quant_dataset, CACHE_ON_CPU,
                     lambda decoder_block, *args, **kwargs: decoder_block(*args, **kwargs), model.device, cached_dir
                 )
         # pylint: disable=protected-access
@@ -224,29 +228,25 @@ def set_qt_params_trainable(qt_block):
     # pylint: disable=too-many-arguments
     @classmethod
     def _block_wise_training_step(cls,
-                                  omniquant_config,
                                   fp_input,
                                   qt_input,
                                   fp_block,
                                   qt_block,
                                   qt_let_pair_list,
                                   optimizer,
                                   loss_fn,
-                                  compute_sqnr : bool,
                                   device : str):
         """
         Run block-wise traing on LET parameters. Use fp_block output as ground truth and qt_block output as
-        model output. Use omniquant_config.input_symmetry to choose input for fp and qt block.
+        model output.
 
-        :param omniquant_config: Configuration for Omniquant optimization.
         :param fp_input: block output from previous block in fp model.
         :param qt_input: block output from previous block in qt model.
         :param fp_block: decoder block in fp model.
         :param qt_block: decoder block in qt model.
         :param qt_let_pair_list: let_pair_list in qt model. Use to get LET training parameters.
         :param optimizer: optimizer used for LET blockwise training
         :param loss_fn: loss_fn used for LET bloackwise training
-        :param compute_sqnr: Computes sqnr between fp and qt block during blockwise training
         """
         optimizer.zero_grad()
         def _process_block_input(_block_input):
@@ -258,19 +258,20 @@ def _process_block_input(_block_input):
             _kwargs = _move_to_device(_kwargs, device)
             return _args, _kwargs
 
+        input_symmetry = "fpfp"
         # Get target output (ground truth)
-        target_input = fp_input if omniquant_config.input_symmetry.startswith("fp") else qt_input
+        target_input = fp_input if input_symmetry.startswith("fp") else qt_input
         _args, _kwargs = _process_block_input(target_input)
         fp_outputs = fp_block(*_args, **_kwargs)[0]
 
         # Get model output (prediction)
-        model_input = fp_input if omniquant_config.input_symmetry.endswith("fp") else qt_input
+        model_input = fp_input if input_symmetry.endswith("fp") else qt_input
         _qt_args, _qt_kwargs = _process_block_input(model_input)
         target_op = fp_block(*_qt_args, **_qt_kwargs)[0]
         qt_output = qt_block(*_qt_args, **_qt_kwargs)[0]
 
         with torch.no_grad():
-            sqnr = (torch.tensor(get_sqnr(target_op, qt_output)) if omniquant_config.compute_sqnr else None)
+            sqnr = (torch.tensor(get_sqnr(target_op, qt_output)) if OMNIQUANT_COMPUTE_SQNR else None)
 
         loss = loss_fn(qt_output, target_op)
         loss.backward()
@@ -279,15 +280,6 @@ def _process_block_input(_block_input):
 
         return loss.detach().cpu(), sqnr
 
-    @classmethod
-    def _validate_omniquant_config(cls, omniquant_config: OmniquantConfig):
-        """ Validate omniquant config """
-        # input_symmetry should be one of qtqt, qtfp, fpqt, fpfp
-        input_symmetry_error_msg = f"Expect omniquant_config.input_symmetry be one of qtqt, qtfp, fpqt, fpfp but got {omniquant_config.input_symmetry}."
-        assert len(omniquant_config.input_symmetry) == 4, input_symmetry_error_msg
-        assert omniquant_config.input_symmetry[:2] in ("qt", "fp"), input_symmetry_error_msg
-        assert omniquant_config.input_symmetry[2:] in ("qt", "fp"), input_symmetry_error_msg
-
     @classmethod
     # pylint: disable=protected-access
     def _dump_meta_data(cls, model, output_path):
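One behavioral note from the hunks above: input_symmetry is no longer user-configurable, and the validator that accepted "qtqt", "qtfp", "fpqt", "fpfp" is removed; the training step now hardcodes "fpfp", so the fp block inputs drive both the ground-truth and prediction paths. A minimal standalone illustration of that selection logic (a sketch, not the AIMET code itself; the two input values are placeholders):

# Standalone sketch of the input-selection logic after this commit.
fp_input, qt_input = "fp_block_input", "qt_block_input"   # placeholders for cached block inputs

input_symmetry = "fpfp"   # hardcoded in _block_wise_training_step after this change

# First two characters pick the input fed to the fp block (ground truth),
# last two characters pick the input fed to the quantized block (prediction).
target_input = fp_input if input_symmetry.startswith("fp") else qt_input
model_input = fp_input if input_symmetry.endswith("fp") else qt_input

assert (target_input, model_input) == ("fp_block_input", "fp_block_input")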

0 commit comments
