
Commit 9a1218c

Lee, Kyunggeun (quic-kyunggeu) authored and committed
Update QuantizerBase.load_state_dict function signature
Signed-off-by: Kyunggeun Lee <[email protected]>
Co-authored-by: Kyunggeun Lee <[email protected]>
1 parent 2b0a1fc commit 9a1218c

3 files changed: +14, -13 lines

TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/base/quantizer.py

Lines changed: 2 additions & 2 deletions

@@ -169,7 +169,7 @@ def state_dict(self, *args, **kwargs): # pylint: disable=arguments-differ

         return state_dict

-    def load_state_dict(self, state_dict, strict: bool = True): # pylint:disable=arguments-differ
+    def load_state_dict(self, state_dict, *args, **kwargs): # pylint:disable=arguments-differ
         if "_extra_state" not in state_dict:
             is_initialized = OrderedDict(
                 {
@@ -180,7 +180,7 @@ def load_state_dict(self, state_dict, strict: bool = True): # pylint:disable=ar
             )
             state_dict["_extra_state"] = is_initialized

-        ret = super().load_state_dict(state_dict, strict)
+        ret = super().load_state_dict(state_dict, *args, **kwargs)

         if version.parse(torch.__version__) < version.parse("1.10"):
             # This is for backward compatibility with torch < 1.10
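Taken together, the two hunks above change QuantizerBase.load_state_dict from hard-coding a strict flag to forwarding whatever positional and keyword arguments the caller supplies on to nn.Module.load_state_dict. A minimal sketch of the forwarding pattern, using a hypothetical nn.Module subclass rather than the real QuantizerBase:

import torch

class QuantizerSketch(torch.nn.Module):
    """Hypothetical stand-in for QuantizerBase, not the AIMET class itself."""
    def __init__(self):
        super().__init__()
        self.register_buffer("min", torch.zeros(10))
        self.register_buffer("max", torch.zeros(10))

    def load_state_dict(self, state_dict, *args, **kwargs):  # pylint:disable=arguments-differ
        # Forward everything to nn.Module.load_state_dict unchanged, so
        # callers can keep passing strict= as before and also use newer
        # keywords such as assign= (available since PyTorch 2.1) without
        # this override having to mirror each parameter.
        return super().load_state_dict(state_dict, *args, **kwargs)

q = QuantizerSketch()
q.load_state_dict({"min": -torch.ones(10)}, strict=False)  # strict= passes through

Forwarding *args/**kwargs keeps the override compatible with any keyword the base method grows over time, at the cost of a slightly less self-documenting signature.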

TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/float/quantizer.py

Lines changed: 2 additions & 2 deletions

@@ -194,7 +194,7 @@ def set_extra_state(self, state):
         self.mantissa_bits = state["mantissa_bits"].item()
         super().set_extra_state(state)

-    def load_state_dict(self, state_dict, strict: bool = True):
+    def load_state_dict(self, state_dict, *args, **kwargs):
         if "maxval" in state_dict:
             if self.maxval is None:
                 del self.maxval
@@ -203,7 +203,7 @@ def load_state_dict(self, state_dict, strict: bool = True):
                 del self.maxval
                 self.register_buffer("maxval", None)

-        ret = super().load_state_dict(state_dict, strict)
+        ret = super().load_state_dict(state_dict, *args, **kwargs)
         return ret

     @property
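The same forwarding change lands here, wrapped around the existing maxval buffer bookkeeping: before delegating to the parent, the method swaps the maxval buffer between a None placeholder and a real tensor so that loading can succeed. A self-contained sketch of that pattern, assuming a simplified module in place of the real float quantizer:

import torch

class FloatQuantizerSketch(torch.nn.Module):
    """Hypothetical simplification; only the maxval handling is modeled."""
    def __init__(self):
        super().__init__()
        # maxval starts as a None placeholder (an uninitialized buffer).
        self.register_buffer("maxval", None)

    def load_state_dict(self, state_dict, *args, **kwargs):
        if "maxval" in state_dict and self.maxval is None:
            # A None buffer cannot receive a tensor during loading, so
            # re-register it with a real tensor of the incoming shape first.
            del self.maxval
            self.register_buffer("maxval", torch.empty_like(state_dict["maxval"]))
        return super().load_state_dict(state_dict, *args, **kwargs)

m = FloatQuantizerSketch()
m.load_state_dict({"maxval": torch.tensor(448.0)})
print(m.maxval)  # tensor(448.)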

TrainingExtensions/torch/test/python/v2/quantization/affine/test_affine_quantizer.py

Lines changed: 10 additions & 9 deletions

@@ -842,7 +842,8 @@ def test_invalid_encoding_analyzer():

 @torch.no_grad()
 @pytest.mark.cuda
-def test_is_initialized(x):
+@pytest.mark.parametrize("assign", [True, False])
+def test_is_initialized(x, assign: bool):
     """
     When: Instantiate a quantizer object
     Then:
@@ -920,7 +921,7 @@ def test_is_initialized(x):
         symmetric=True,
         encoding_analyzer=MinMaxEncodingAnalyzer((10,)),
     )
-    qdq.load_state_dict({"min": -torch.ones(10), "max": torch.ones(10)})
+    qdq.load_state_dict({"min": -torch.ones(10), "max": torch.ones(10)}, assign=assign)
     assert qdq.is_initialized()

     """
@@ -933,9 +934,9 @@ def test_is_initialized(x):
         symmetric=True,
         encoding_analyzer=MinMaxEncodingAnalyzer((10,)),
     )
-    qdq.load_state_dict({"min": -torch.ones(10)}, strict=False)
+    qdq.load_state_dict({"min": -torch.ones(10)}, strict=False, assign=assign)
     assert not qdq.is_initialized() # False; max is not initialized yet
-    qdq.load_state_dict({"max": torch.ones(10)}, strict=False)
+    qdq.load_state_dict({"max": torch.ones(10)}, strict=False, assign=assign)
     assert qdq.is_initialized()

     """
@@ -949,17 +950,17 @@ def test_is_initialized(x):
         encoding_analyzer=MinMaxEncodingAnalyzer((10,)),
     )
     uninitialized_state_dict = qdq.state_dict()
-    qdq.load_state_dict(uninitialized_state_dict)
+    qdq.load_state_dict(uninitialized_state_dict, assign=assign)
     assert not qdq.is_initialized()

     qdq.min.mul_(1.0)
     partially_initialized_state_dict = qdq.state_dict()
-    qdq.load_state_dict(partially_initialized_state_dict)
+    qdq.load_state_dict(partially_initialized_state_dict, assign=assign)
     assert not qdq.is_initialized()

     qdq.max.mul_(1.0)
     fully_initialized_state_dict = qdq.state_dict()
-    qdq.load_state_dict(fully_initialized_state_dict)
+    qdq.load_state_dict(fully_initialized_state_dict, assign=assign)
     assert qdq.is_initialized()

     """
@@ -981,7 +982,7 @@ def test_is_initialized(x):
         symmetric=True,
         encoding_analyzer=MinMaxEncodingAnalyzer((10,)),
     )
-    qdq.load_state_dict({"min": -torch.ones(10), "max": torch.ones(10)})
+    qdq.load_state_dict({"min": -torch.ones(10), "max": torch.ones(10)}, assign=assign)
     qdq = copy.deepcopy(qdq)
     assert qdq.is_initialized()

@@ -1005,7 +1006,7 @@ def test_is_initialized(x):
         symmetric=True,
         encoding_analyzer=MinMaxEncodingAnalyzer((10,)),
     )
-    qdq.load_state_dict({"min": -torch.ones(10), "max": torch.ones(10)})
+    qdq.load_state_dict({"min": -torch.ones(10), "max": torch.ones(10)}, assign=assign)
     out_before = qdq(x.view(-1, 10))
     res = pickle.dumps(qdq)
     qdq = pickle.loads(res)
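On the test side, the existing test_is_initialized is parametrized over assign so that every load_state_dict call is exercised with both assign=True and assign=False. A trimmed illustration of the pattern, using a plain nn.Linear instead of the AIMET quantizer (the assign= keyword requires PyTorch >= 2.1):

import pytest
import torch

# pytest runs this test once per assign value. assign=True makes
# load_state_dict keep references to the incoming tensors instead of
# copying their values into the existing parameters; either way the
# loaded values must round-trip.
@pytest.mark.parametrize("assign", [True, False])
def test_load_state_dict_roundtrip(assign):
    module = torch.nn.Linear(4, 4)
    state = {k: v.clone() for k, v in module.state_dict().items()}
    module.load_state_dict(state, assign=assign)
    assert torch.equal(module.weight, state["weight"])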
