11# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
33import copy
4- import io
54import time
65import warnings
76from dataclasses import asdict , dataclass , field
1514from megatron .core .utils import experimental_api
1615
1716
def serialize_tensor(tensor: torch.Tensor) -> List:
    """Serialize a tensor to a nested Python list.

    NOTE(review): dtype and device are not preserved by the list round-trip;
    `deserialize_tensor` rebuilds a CPU tensor with a dtype inferred from the
    values — confirm callers do not rely on exact dtype (e.g. float16).

    Args:
        tensor (Tensor): Tensor to serialize.

    Returns:
        (List) Nested-list representation of the tensor's values.
    """
    # Guard nvtx: profiling ranges are only meaningful (and guaranteed
    # available) with a CUDA build; presumably no-ops otherwise — this keeps
    # CPU-only environments working.
    use_nvtx = torch.cuda.is_available()
    if use_nvtx:
        torch.cuda.nvtx.range_push("serialize_tensor")
    try:
        # Simply convert the tensor into a plain list (moved to host first).
        return tensor.cpu().tolist()
    finally:
        # Pop in `finally` so the nvtx range stays balanced even on error.
        if use_nvtx:
            torch.cuda.nvtx.range_pop()
33+
34+
def deserialize_tensor(tensor_as_list: List) -> torch.Tensor:
    """Deserialize a tensor from its nested-list representation.

    Inverse of `serialize_tensor`. The reconstructed tensor lives on CPU and
    its dtype is inferred from the list values; the original dtype/device are
    not restored.

    Args:
        tensor_as_list (List): List representation of the tensor.

    Returns:
        (Tensor) Reconstructed tensor.
    """
    return torch.tensor(tensor_as_list)
4646
4747
def serialize(self) -> dict:
    """Converts the instance into a serializable dictionary.

    Returns:
        (dict) A dictionary representation of the instance suitable for
        serialization.
    """
    # Dataclass to dict. Do not use asdict(self): it has very high CPU
    # overhead and, if there are tensors, it will try to deepcopy them.
    obj = self.__dict__.copy()  # shallow dict copy
    obj["status"] = self.status.name if self.status else None
    obj["sampling_params"] = self.sampling_params.serialize() if self.sampling_params else None
    obj["inference_parameters"] = (
        self.inference_parameters.serialize() if self.inference_parameters else None
    )

    # Serialize tensors. Emit a *list* marker rather than a tuple: the
    # deserializer checks `isinstance(v, list)`, so a list survives both a
    # direct round-trip and a JSON/msgpack codec (which would coerce a tuple
    # to a list anyway).
    obj = {
        k: (["tensor", serialize_tensor(v)] if isinstance(v, torch.Tensor) else v)
        for k, v in obj.items()
    }
    return obj
@classmethod
def deserialize(cls, obj: dict) -> "InferenceRequest":
    """Rebuild a request instance from its serialized dictionary form.

    Args:
        obj (dict): Serialized request data (as produced by `serialize`).

    Returns:
        (InferenceRequest) Deserialized request.
    """
    # Construct first, then let the post-deserialize hook restore any fields
    # that were flattened to plain types (status enum, params, tensors).
    instance = cls(**obj)
    instance._post_deserialize(obj)
    return instance
135+ def _post_deserialize (self , obj : dict ):
136+ """
137+ This is called after the dataclass is initialized to handle any special
138+ deserialization logic.
139+ """
140+ # Deserialize status.
141+ self .status = None if obj ["status" ] is None else Status [obj ["status" ]]
142+ self .sampling_params = (
143+ None
144+ if obj ["sampling_params" ] is None
145+ else SamplingParams .deserialize (obj ["sampling_params" ])
146+ )
147+ self .inference_parameters = (
148+ None
149+ if obj ["inference_parameters" ] is None
150+ else SamplingParams .deserialize (obj ["inference_parameters" ])
151+ )
152+
153+ # Deserialize tensors and sampling params.
131154 for k , v in obj .items ():
132155 if isinstance (v , list ) and len (v ) == 2 and v [0 ] == "tensor" :
133- setattr (request , k , deserialize_tensor (v [1 ]))
134-
135- return request
156+ setattr (self , k , deserialize_tensor (v [1 ]))
136157
137158
138159class DynamicInferenceEventType (Enum ):
def serialize(self) -> dict:
    """Converts the event into a serializable dictionary.

    Returns:
        (dict) A dictionary representation of the instance suitable for
        serialization.
    """
    # Guard nvtx for CPU-only builds; profiling ranges are CUDA-side only.
    use_nvtx = torch.cuda.is_available()
    if use_nvtx:
        torch.cuda.nvtx.range_push("DynamicInferenceEvent.serialize")
    try:
        # Dataclass to dict. Do not use asdict(self): very high CPU overhead,
        # and it would deepcopy any tensors.
        obj = self.__dict__.copy()
        obj["type"] = self.type.name

        # Serialize payload.
        if self.payload:
            from .contexts.dynamic_context import ContextErrorFactory  # avoid circular import.

            obj["payload"] = ContextErrorFactory.serialize(self.payload)
        return obj
    finally:
        # Pop in `finally` so a raising payload serializer cannot leave the
        # nvtx range stack unbalanced.
        if use_nvtx:
            torch.cuda.nvtx.range_pop()
210234
211235 @classmethod
@@ -247,7 +271,7 @@ class DynamicInferenceRequest(InferenceRequest):
247271 # remaining prompt tokens are used for chunked prefill
248272 remaining_prompt_tokens : Optional [torch .Tensor ] = None
249273 latency : Optional [float ] = None
250- finished_chunk_token_count = 0
274+ finished_chunk_token_count : int = 0
251275 stop_word_ids : Optional [List [List [int ]]] = None # Tokenized stop words (populated internally)
252276
253277 def __post_init__ (self ):
@@ -275,30 +299,22 @@ def __str__(self):
275299 )
276300 )
277301
def serialize(self) -> dict:
    """Converts the instance into a serializable dictionary.

    Returns:
        (dict) A dictionary representation of the instance suitable for
        serialization.
    """
    # Guard nvtx for CPU-only builds.
    use_nvtx = torch.cuda.is_available()
    if use_nvtx:
        torch.cuda.nvtx.range_push("DynamicInferenceRequest.serialize")
    try:
        obj = super().serialize()
        # Events carry their own serialize(); the base-class shallow copy
        # would otherwise leave them as live objects in the dict.
        obj["events"] = [event.serialize() for event in self.events]
        return obj
    finally:
        # Balance the nvtx range even if an event fails to serialize.
        if use_nvtx:
            torch.cuda.nvtx.range_pop()
288314
def _post_deserialize(self, obj: dict) -> None:
    """Finish deserialization: restore base-class fields, then rebuild events.

    Replaces the former `deserialize` override; the base classmethod calls
    this hook after constructing the instance.

    Args:
        obj (dict): Serialized request data.
    """
    super()._post_deserialize(obj)
    # `serialize()` always writes an "events" key, so index it directly.
    self.events = [DynamicInferenceEvent.deserialize(e) for e in obj["events"]]
302318
303319 @property
304320 def tracked_metadata (self ) -> List [Any ]:
def serialize(self) -> dict:
    """Converts the instance into a serializable dictionary.

    Returns:
        (dict) A dictionary representation of the instance suitable for
        serialization.
    """
    # Guard nvtx for CPU-only builds.
    use_nvtx = torch.cuda.is_available()
    if use_nvtx:
        torch.cuda.nvtx.range_push("DynamicInferenceRequestRecord.serialize")
    try:
        # Shallow copy instead of asdict(): avoids deep-copying tensors and
        # the associated CPU overhead.
        obj = self.__dict__.copy()  # shallow dict copy
        obj["requests"] = [request.serialize() for request in obj["requests"]]
        return obj
    finally:
        # Keep nvtx push/pop balanced on the error path too.
        if use_nvtx:
            torch.cuda.nvtx.range_pop()
524542 @classmethod
0 commit comments