 from aimet_torch.v2.nn import compute_param_encodings
 from aimet_common.utils import AimetLogger
 from .decoder_processor import get_transformer_processor
-from .omniquant_config import OmniquantConfig
 from .let_modules import LETModule
 from ._utils import (
     _convert_sim_to_letsim,

 OMNIQUANT_ARTIFACT_DIR = "./aimet_omniquant_artifact/"
 OMNIQUANT_METADATA_SAFETENSOR_NAME = "aimet_omniquant_metadata.safetensor"
-
+OMNIQUANT_COMPUTE_SQNR = True
+OMNIQUANT_LR = 5e-4  # Learning rate for the blockwise LET/encoding-parameter optimizer.
+CACHE_ON_CPU = True  # Will be removed once the blockwise sampler is used.

 _logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)

 class Omniquant:
     """
     Omniquant for Post Training Quantization (PTQ)
     """
+    # pylint: disable=too-many-arguments
     @classmethod
-    def apply_omniquant(cls, quant_sim: QuantizationSimModel, model: torch.nn.Module, omniquant_config: OmniquantConfig, dataloader,
-                        forward_fn: Callable, output_path: str = OMNIQUANT_ARTIFACT_DIR) -> torch.nn.Module:
+    def apply_omniquant(cls, quant_sim: QuantizationSimModel, model: torch.nn.Module, dataloader,
+                        forward_fn: Callable, num_epoch: int, output_path: str = OMNIQUANT_ARTIFACT_DIR) -> torch.nn.Module:
         """
         Returns the model with Omniquant-optimized weights and saves metadata in safetensors format to the output path.
         The metadata safetensor can be used in update_lora_weights to update LoRA adapter weights for a PEFT LoRA model.

         :param quant_sim: QuantizationSimModel object to optimize with Omniquant.
         :param model: Original fp32 model from which quant_sim was created.
-        :param omniquant_config: Configuration for Omniquant optimization.
         :param dataloader: Dataloader used to train the model.
         :param forward_fn: Model forward function used to cache intermediate data.
             Expected to take the model and inputs as arguments, e.g. lambda model, inputs: model(*inputs)
+        :param num_epoch: Number of epochs to train each block with Omniquant.
         :param output_path: Path to save the {layer_name: scale} metadata safetensor.
         :return: Model with Omniquant weights.
         """
+        num_batch = len(dataloader)
+
         @contextlib.contextmanager
         def disable_dynamic_cache():
             # Disable dynamic_cache for LET blockwise training, and restore it after optimization.
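A minimal call sketch for the new signature, for orientation only: `sim`, `fp_model`, and `calib_loader` are assumed placeholders (a QuantizationSimModel, the original fp32 model, and a calibration dataloader), and the epoch count is an arbitrary example value.

    # Illustrative caller for the new apply_omniquant signature (placeholders, not aimet-provided objects).
    optimized_model = Omniquant.apply_omniquant(
        quant_sim=sim,                                       # QuantizationSimModel built from fp_model
        model=fp_model,                                      # original fp32 model
        dataloader=calib_loader,                             # yields calibration batches
        forward_fn=lambda model, inputs: model(*inputs),     # form expected by the docstring
        num_epoch=20,                                        # arbitrary example value
        output_path="./aimet_omniquant_artifact/",           # default OMNIQUANT_ARTIFACT_DIR
    )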
@@ -105,11 +109,10 @@ def disable_dynamic_cache():
                 quant_sim.model.config.use_cache, model.config.use_cache = quant_sim_use_cache_bool, model_use_cache_bool
         output_path = Path(output_path)
         os.makedirs(output_path, exist_ok=True)
-        cls._validate_omniquant_config(omniquant_config)
-        _logger.info(omniquant_config)
+
         start_omq_optmztn_time = time.perf_counter()
         with disable_dynamic_cache():
-            cls._apply_omniquant(quant_sim, model, omniquant_config, dataloader, forward_fn, output_path)
+            cls._apply_omniquant(quant_sim, model, dataloader, forward_fn, num_epoch, num_batch, output_path)
         total_omq_optmztn_time = time.perf_counter() - start_omq_optmztn_time
         _logger.info("Took %.4f seconds for omq optimization", total_omq_optmztn_time)

@@ -119,17 +122,18 @@ def disable_dynamic_cache():
     # pylint: disable=too-many-arguments
     # pylint: disable=too-many-statements
     @classmethod
-    def _apply_omniquant(cls, quant_sim: QuantizationSimModel, model: torch.nn.Module, omniquant_config: OmniquantConfig,
-                         dataloader, forward_fn, output_path: str) -> torch.nn.Module:
+    def _apply_omniquant(cls, quant_sim: QuantizationSimModel, model: torch.nn.Module, dataloader, forward_fn,
+                         num_epoch: int, num_batch: int, output_path: str) -> torch.nn.Module:
         """
         Implementation that runs Omniquant optimization block by block. Returns the model with optimized weights.

         :param quant_sim: QuantizationSimModel object to optimize with Omniquant.
         :param model: Original fp32 model from which quant_sim was created.
-        :param omniquant_config: Configuration for Omniquant optimization.
         :param dataloader: Dataloader used to train the model.
         :param forward_fn: Model forward function used to cache intermediate data.
             Expected to take the model and inputs as arguments, e.g. lambda model, inputs: model(*inputs)
+        :param num_epoch: Number of epochs to train each block with Omniquant.
+        :param num_batch: Number of batches in the dataloader.
         :param output_path: Path where artifacts are stored.
         :return: Model with Omniquant weights.
         """
@@ -148,11 +152,11 @@ def _apply_omniquant(cls, quant_sim: QuantizationSimModel, model: torch.nn.Modul
         num_repeats = model.config.num_attention_heads // model.config.num_key_value_heads
         with tempfile.TemporaryDirectory() as tempdir:
             cached_dir = os.path.join(tempdir, 'cached_dataset')
-            cached_dataset = CachedDataset(dataloader, omniquant_config.num_batch, cached_dir)
+            cached_dataset = CachedDataset(dataloader, num_batch, cached_dir)

             cached_fp_dataset, cached_quant_dataset = get_block_inputs(
-                model, quant_sim, ".".join([transformer_processor.transformer_block_list_path, "0"]), cached_dataset, omniquant_config.cache_on_cpu,
-                forward_fn, omniquant_config.num_batch, cached_dir, incl_kwargs=True
+                model, quant_sim, ".".join([transformer_processor.transformer_block_list_path, "0"]), cached_dataset, CACHE_ON_CPU,
+                forward_fn, num_batch, cached_dir, incl_kwargs=True
             )
             for block_num, (fp_block, qt_block) in enumerate(zip(fp_transformer_block_list, qt_transformer_block_list)):
                 qt_let_pair_list = transformer_processor.get_let_module_pair(qt_block)
@@ -174,26 +178,26 @@ def set_qt_params_trainable(qt_block):
                 encoding_params, param_names = cls._get_trainable_params(qt_block)
                 let_params = cls._get_let_params(qt_let_pair_list)
                 grouped_params = [
-                    {"params": encoding_params, "lr": omniquant_config.omq_lr, "weight_decay": 0.},
-                    {"params": let_params, "lr": omniquant_config.omq_lr, "weight_decay": 0.},
+                    {"params": encoding_params, "lr": OMNIQUANT_LR, "weight_decay": 0.},
+                    {"params": let_params, "lr": OMNIQUANT_LR, "weight_decay": 0.},
                 ]

                 optimizer = torch.optim.AdamW(grouped_params)
                 loss_fn = torch.nn.MSELoss(reduction="sum")

                 _logger.info("Starting blockwise training for params")
-                for epoch in tqdm(range(omniquant_config.num_epoch)):
+                for epoch in tqdm(range(num_epoch)):
                     sqnr_list = []
                     loss_list = []
-                    for batch_num in range(omniquant_config.num_batch):
+                    for batch_num in range(num_batch):
                         fp_input, qt_input = cached_fp_dataset[batch_num], cached_quant_dataset[batch_num]
                         # Do block-wise training.
-                        loss, sqnr = cls._block_wise_training_step(omniquant_config, fp_input, qt_input, fp_block, qt_block, qt_let_pair_list, optimizer, loss_fn, omniquant_config.compute_sqnr, model.device)
+                        loss, sqnr = cls._block_wise_training_step(fp_input, qt_input, fp_block, qt_block, qt_let_pair_list, optimizer, loss_fn, model.device)
                         sqnr_list += [sqnr]
                         loss_list += [loss]
                     loss_mean = torch.stack(loss_list).mean()
                     log_msg = f"layer {block_num} epoch {epoch} | loss: {loss_mean:.3f}"
-                    if omniquant_config.compute_sqnr:
+                    if OMNIQUANT_COMPUTE_SQNR:
                         sqnr_mean = torch.stack(sqnr_list).mean()
                         log_msg += f" | sqnr: {sqnr_mean:.3f}"

@@ -208,7 +212,7 @@ def set_qt_params_trainable(qt_block):
                 freeze_let_optimized_param_quantizers(qt_block)
                 # TODO: decide whether compute_param_encodings should be called after blockwise training
                 get_block_outputs(
-                    fp_block, qt_block, False, cached_fp_dataset, cached_quant_dataset, omniquant_config.cache_on_cpu,
+                    fp_block, qt_block, False, cached_fp_dataset, cached_quant_dataset, CACHE_ON_CPU,
                     lambda decoder_block, *args, **kwargs: decoder_block(*args, **kwargs), model.device, cached_dir
                 )
             # pylint: disable=protected-access
@@ -224,29 +228,25 @@ def set_qt_params_trainable(qt_block):
     # pylint: disable=too-many-arguments
     @classmethod
     def _block_wise_training_step(cls,
-                                  omniquant_config,
                                   fp_input,
                                   qt_input,
                                   fp_block,
                                   qt_block,
                                   qt_let_pair_list,
                                   optimizer,
                                   loss_fn,
-                                  compute_sqnr: bool,
                                   device: str):
         """
         Run block-wise training on LET parameters. Use the fp_block output as ground truth and the qt_block output as
-        model output. Use omniquant_config.input_symmetry to choose input for fp and qt block.
+        the model output.

-        :param omniquant_config: Configuration for Omniquant optimization.
         :param fp_input: Input to the block, i.e. the output of the previous block in the fp model.
         :param qt_input: Input to the block, i.e. the output of the previous block in the quantized model.
         :param fp_block: Decoder block in the fp model.
         :param qt_block: Decoder block in the quantized model.
         :param qt_let_pair_list: let_pair_list in the quantized model. Used to get the LET training parameters.
         :param optimizer: Optimizer used for LET blockwise training.
         :param loss_fn: Loss function used for LET blockwise training.
-        :param compute_sqnr: Computes SQNR between the fp and qt blocks during blockwise training.
         """
         optimizer.zero_grad()
         def _process_block_input(_block_input):
@@ -258,19 +258,20 @@ def _process_block_input(_block_input):
             _kwargs = _move_to_device(_kwargs, device)
             return _args, _kwargs

+        input_symmetry = "fpfp"
         # Get target output (ground truth)
-        target_input = fp_input if omniquant_config.input_symmetry.startswith("fp") else qt_input
+        target_input = fp_input if input_symmetry.startswith("fp") else qt_input
         _args, _kwargs = _process_block_input(target_input)
         fp_outputs = fp_block(*_args, **_kwargs)[0]

         # Get model output (prediction)
-        model_input = fp_input if omniquant_config.input_symmetry.endswith("fp") else qt_input
+        model_input = fp_input if input_symmetry.endswith("fp") else qt_input
         _qt_args, _qt_kwargs = _process_block_input(model_input)
         target_op = fp_block(*_qt_args, **_qt_kwargs)[0]
         qt_output = qt_block(*_qt_args, **_qt_kwargs)[0]

         with torch.no_grad():
-            sqnr = (torch.tensor(get_sqnr(target_op, qt_output)) if omniquant_config.compute_sqnr else None)
+            sqnr = (torch.tensor(get_sqnr(target_op, qt_output)) if OMNIQUANT_COMPUTE_SQNR else None)

         loss = loss_fn(qt_output, target_op)
         loss.backward()
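With input_symmetry hard-coded to "fpfp", both the ground-truth fp_block and the quantized qt_block are fed the cached fp inputs: the first two letters select the input source for the fp block (ground truth), the last two select the input source for the qt block (prediction). The SQNR logged next to the MSE loss comes from aimet's get_sqnr; the helper below is only a sketch of the conventional definition (signal power over error power, in dB) and may differ from get_sqnr in details such as reduction or units.

    import torch

    def sqnr_db(target: torch.Tensor, prediction: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
        # Conventional SQNR in dB: 10 * log10(signal power / quantization-error power).
        noise = target - prediction
        signal_power = target.pow(2).mean()
        noise_power = noise.pow(2).mean().clamp_min(eps)
        return 10.0 * torch.log10(signal_power / noise_power)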
@@ -279,15 +280,6 @@ def _process_block_input(_block_input):

         return loss.detach().cpu(), sqnr

-    @classmethod
-    def _validate_omniquant_config(cls, omniquant_config: OmniquantConfig):
-        """ Validate omniquant config """
-        # input_symmetry should be one of qtqt, qtfp, fpqt, fpfp
-        input_symmetry_error_msg = f"Expect omniquant_config.input_symmetry be one of qtqt, qtfp, fpqt, fpfp but got {omniquant_config.input_symmetry}."
-        assert len(omniquant_config.input_symmetry) == 4, input_symmetry_error_msg
-        assert omniquant_config.input_symmetry[:2] in ("qt", "fp"), input_symmetry_error_msg
-        assert omniquant_config.input_symmetry[2:] in ("qt", "fp"), input_symmetry_error_msg
-
     @classmethod
     # pylint: disable=protected-access
     def _dump_meta_data(cls, model, output_path):
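Since apply_omniquant saves a {layer_name: scale} metadata safetensor to output_path, a caller could inspect it with the safetensors library roughly as follows (illustrative; the file name matches OMNIQUANT_METADATA_SAFETENSOR_NAME above and the path assumes the default artifact directory):

    # Hypothetical inspection of the saved Omniquant metadata (not an aimet API).
    from safetensors.torch import load_file

    scales = load_file("./aimet_omniquant_artifact/aimet_omniquant_metadata.safetensor")
    for layer_name, scale in scales.items():
        print(layer_name, tuple(scale.shape))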