
Commit 53a2b19

Optimizing post-processing of requests (#2920)
1 parent 389436b commit 53a2b19

3 files changed: +101 -54 lines changed

megatron/core/inference/data_parallel_inference_coordinator.py

Lines changed: 21 additions & 3 deletions
@@ -65,7 +65,7 @@ class DataParallelInferenceCoordinator:
         next_request_id (int): A counter for generating unique server-side request IDs.
     """

-    def __init__(self, inference_coordinator_port: int, data_parallel_size: int):
+    def __init__(self, inference_coordinator_port: int, data_parallel_size: int, tokenizer):
         """
         Initializes the inference coordinator.

@@ -116,6 +116,7 @@ def __init__(self, inference_coordinator_port: int, data_parallel_size: int):
         self.request_id_to_client_request_id = {}

         self.next_request_id = 0
+        self.tokenizer = tokenizer

     def get_next_data_parallel_rank(self):
         """
@@ -261,6 +262,7 @@ def start(self):
                finished_request_records = deserialized_payload[1]

                for finished_request_record in finished_request_records:
+                    self.detokenize(finished_request_record)
                    fid = finished_request_record["requests"][0]["request_id"]
                    client_identity = self.request_id_to_client_id[fid]
                    client_request_identity = self.request_id_to_client_request_id[fid]
@@ -280,9 +282,25 @@ def start(self):
            else:
                raise UnknownHeaderError(header)

+    def detokenize(self, finished_request_record):
+        """
+        Detokenizes the generated tokens in the finished request record.
+
+        This method uses the coordinator's tokenizer to convert the list of
+        generated token IDs back into human-readable text.
+
+        Args:
+            finished_request_record (dict): The record containing the generated
+                tokens to be detokenized. It is modified in place.
+        """
+        for request in finished_request_record["requests"]:
+            if request["prompt"] is None:
+                request["prompt"] = self.tokenizer.detokenize(request["prompt_tokens"][1])
+            request["generated_text"] = self.tokenizer.detokenize(request["generated_tokens"])
+
     @classmethod
     def entrypoint(
-        cls, ready_event: Event, inference_coordinator_port: int, data_parallel_size: int
+        cls, ready_event: Event, inference_coordinator_port: int, data_parallel_size: int, tokenizer
     ):
         """
         Class method to instantiate and run the coordinator, for use in a separate process.
@@ -296,7 +314,7 @@ def entrypoint(
            inference_coordinator_port (int): The port to bind to.
            data_parallel_size (int): The number of expected TP-coordinators.
         """
-        coordinator = cls(inference_coordinator_port, data_parallel_size)
+        coordinator = cls(inference_coordinator_port, data_parallel_size, tokenizer)
         ready_event.set()
         try:
            coordinator.start()
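
Note: with the tokenizer now passed through entrypoint, the coordinator detokenizes finished requests in its own process, off the engine's critical path. A minimal sketch of how a coordinator process could be launched with the new signature follows; the multiprocessing wiring, helper name, and port value are illustrative (in practice the engine does this inside start_listening_to_data_parallel_coordinator):

    import multiprocessing as mp

    from megatron.core.inference.data_parallel_inference_coordinator import (
        DataParallelInferenceCoordinator,
    )

    def launch_coordinator(port: int, data_parallel_size: int, tokenizer):
        # Hypothetical helper: spawn the coordinator with the tokenizer so that
        # detokenization happens in the coordinator process, not in the engine.
        ready_event = mp.Event()
        proc = mp.Process(
            target=DataParallelInferenceCoordinator.entrypoint,
            args=(ready_event, port, data_parallel_size, tokenizer),  # tokenizer is the new argument
        )
        proc.start()
        ready_event.wait()  # entrypoint() sets this once the coordinator is constructed
        return proc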

megatron/core/inference/engines/dynamic_engine.py

Lines changed: 23 additions & 12 deletions
@@ -413,6 +413,7 @@ async def start_listening_to_data_parallel_coordinator(
                 coordinator_ready_event,
                 inference_coordinator_port,
                 get_pg_size(self.pg_collection.dp),
+                self.controller.tokenizer,
             ),
         )
         self.inference_coordinator_process.start()
@@ -1205,6 +1206,7 @@ async def async_bookkeep(
             cuda_graph_request_count (int): The CUDA graph batch size matching this step.
         """
         # Increment finished_request_count.
+        range_push("bookkeeping")
         cuda_graph_request_count = None

         if step_result is not None:
@@ -1248,26 +1250,33 @@ async def async_bookkeep(
            finished_request_records.append(failed_entry.record)
            failed_entry.future.set_result(failed_entry.record)
        self.failed_request_ids.clear()
+        range_pop()

-        # Detokenize all finished requests (critical for InferenceClient, which
-        # doesn't necessarily have the tokenizer).
-        for record in finished_request_records:
-            for request in record.requests:
-                if request.prompt is None:
-                    request.prompt = self.controller.tokenizer.detokenize(
-                        request.prompt_tokens.tolist()
+        # Detokenize all finished requests if not using
+        # the coordinator. Otherwise, the coordinator will
+        # overlap detokenization with the engine.
+        if not self.use_coordinator:
+            range_push("detokenization")
+            for record in finished_request_records:
+                for request in record.requests:
+                    if request.prompt is None:
+                        request.prompt = self.controller.tokenizer.detokenize(
+                            request.prompt_tokens.tolist()
+                        )
+                    request.generated_text = self.controller.tokenizer.detokenize(
+                        request.generated_tokens
                     )
-                request.generated_text = self.controller.tokenizer.detokenize(
-                    request.generated_tokens
-                )
+            range_pop()

         # Handle necessary ZMQ DP coordinator communication.
         if self.use_coordinator and self.is_mp_coordinator and finished_request_records:
+            range_push("coordinator_communication")
             payload = msgpack.packb(
                 [Headers.ENGINE_REPLY.value, [r.serialize() for r in finished_request_records]],
                 use_bin_type=True,
             )
             self.socket_for_receiving_requests.send(payload)
+            range_pop()

         # Log KV cache utilization stats to W&B
         if context_state["kv_stats"] is not None:
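
Note: with detokenization moved off the engine, async_bookkeep only ships token IDs in the ENGINE_REPLY payload, and the coordinator turns them into text while the engine keeps stepping. A rough sketch of that msgpack round trip (ZMQ transport omitted; the header value and record contents are made up, real records come from DynamicInferenceRequestRecord.serialize()):

    import msgpack

    ENGINE_REPLY = 3  # illustrative stand-in for Headers.ENGINE_REPLY.value

    # Engine side: pack the header plus the serialized finished-request records.
    finished_request_records = [
        {"requests": [{"request_id": 7, "prompt": None,
                       "prompt_tokens": ("tensor", [101, 2054]),
                       "generated_tokens": [2003, 102]}]},
    ]
    payload = msgpack.packb([ENGINE_REPLY, finished_request_records], use_bin_type=True)

    # Coordinator side: unpack, detokenize, then route the record back to the client.
    deserialized_payload = msgpack.unpackb(payload, raw=False)
    assert deserialized_payload[0] == ENGINE_REPLY
    for finished_request_record in deserialized_payload[1]:
        # coordinator.detokenize(finished_request_record) runs here, filling in
        # "prompt" and "generated_text" from the token IDs.
        fid = finished_request_record["requests"][0]["request_id"]
        assert fid == 7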
@@ -1461,7 +1470,7 @@ def schedule_requests(self) -> int:
            int: The number of messages that were received and processed in this batch.
        """

-        torch.cuda.nvtx.range_push("drain_zmq_socket")
+        range_push("drain_zmq_socket")
        all_messages = []
        if self.is_mp_coordinator:
            while True:
@@ -1494,7 +1503,7 @@ def schedule_requests(self) -> int:
        else:
            all_messages = []

-        torch.cuda.nvtx.range_pop()
+        range_pop()
        for message in all_messages:
            data = msgpack.unpackb(message, raw=False)
            header = Headers(data[0])
@@ -1507,7 +1516,9 @@ def schedule_requests(self) -> int:
            if header == Headers.SUBMIT_REQUEST:
                request_id, prompt, sampling_params = data[1:]
                sampling_params = SamplingParams.deserialize(sampling_params)
+                range_push("add_request")
                self.add_request(request_id, prompt, sampling_params)
+                range_pop()
            elif header == Headers.PAUSE:
                # Pause thyself.
                self.received_pause = True
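
Note: the bare range_push / range_pop calls in this file replace direct torch.cuda.nvtx calls, but their definition is not part of this diff. A minimal sketch of what such helpers could look like, assuming they simply forward to torch.cuda.nvtx and can be turned into no-ops when profiling is disabled:

    import torch

    PROFILING_ENABLED = True  # hypothetical switch; the real helpers may gate this differently

    def range_push(name: str) -> None:
        # Open a named NVTX range so the region shows up in Nsight Systems timelines.
        if PROFILING_ENABLED and torch.cuda.is_available():
            torch.cuda.nvtx.range_push(name)

    def range_pop() -> None:
        # Close the most recently opened NVTX range.
        if PROFILING_ENABLED and torch.cuda.is_available():
            torch.cuda.nvtx.range_pop()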

megatron/core/inference/inference_request.py

Lines changed: 57 additions & 39 deletions
@@ -1,7 +1,6 @@
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

 import copy
-import io
 import time
 import warnings
 from dataclasses import asdict, dataclass, field
@@ -15,33 +14,34 @@
 from megatron.core.utils import experimental_api


-def serialize_tensor(tensor: torch.Tensor) -> bytes:
+def serialize_tensor(tensor: torch.Tensor) -> List:
     """Serialize tensor to bytes.

     Args:
         tensor (Tensor): Tensor.

     Returns:
-        (bytes) Byte representation of tensor.
+        (List) Tensor as a list
     """
-    buffer = io.BytesIO()
-    torch.save(tensor, buffer)
-    buffer.seek(0)
-    tensor_bytes = buffer.read()
-    return tensor_bytes
+    torch.cuda.nvtx.range_push("serialize_tensor")

+    # simply convert tensor into a list
+    tensor = tensor.cpu().tolist()

-def deserialize_tensor(tensor_bytes: bytes) -> torch.Tensor:
+    torch.cuda.nvtx.range_pop()
+    return tensor
+
+
+def deserialize_tensor(tensor_as_list: List) -> torch.Tensor:
     """Deserialize tensor from bytes.

     Args:
-        tensor_bytes (bytes): Byte representation of tensor.
+        tensor_as_list (List): List representation of tensor.

     Returns:
         (Tensor) Tensor.
     """
-    buffer = io.BytesIO(tensor_bytes)
-    tensor = torch.load(buffer)
+    tensor = torch.tensor(tensor_as_list)
     return tensor

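Note: serializing tensors as plain Python lists keeps them directly msgpack-encodable and drops the pickle/torch.save buffer overhead of the old path. A small round-trip sketch of the new scheme (the helpers mirror serialize_tensor / deserialize_tensor above; the token values are made up):

    import msgpack
    import torch

    def serialize_tensor(tensor: torch.Tensor) -> list:
        # Mirrors the new list-based helper above.
        return tensor.cpu().tolist()

    def deserialize_tensor(tensor_as_list: list) -> torch.Tensor:
        return torch.tensor(tensor_as_list)

    tokens = torch.tensor([101, 2054, 2003, 102])  # made-up token IDs
    wire = msgpack.packb(("tensor", serialize_tensor(tokens)), use_bin_type=True)
    tag, payload = msgpack.unpackb(wire, raw=False)
    restored = deserialize_tensor(payload)
    assert tag == "tensor" and torch.equal(restored, tokens)
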
@@ -99,17 +99,21 @@ def serialize(self) -> dict:
             (dict) A dictionary representation of the instance suitable for
                 serialization.
         """
-
         # Dataclass to dict.
-        obj = asdict(self)
+        # do not use asdict(self) - it has very high CPU overheads
+        # and if there are tensors, it will try to deepcopy them
+        obj = self.__dict__.copy()  # shallow dict copy
         obj["status"] = self.status.name if self.status else None
+        obj["sampling_params"] = self.sampling_params.serialize() if self.sampling_params else None
+        obj["inference_parameters"] = (
+            self.inference_parameters.serialize() if self.inference_parameters else None
+        )

         # Serialize tensors.
         obj = {
             k: (("tensor", serialize_tensor(v)) if isinstance(v, torch.Tensor) else v)
             for k, v in obj.items()
         }
-
         return obj

     @classmethod
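
Note: the comment above is the crux of the optimization: dataclasses.asdict() recurses into every field and deep-copies values, tensors included, whereas a shallow __dict__ copy only duplicates top-level references. A toy illustration of that difference (class and field names are hypothetical):

    from dataclasses import asdict, dataclass, field

    import torch

    @dataclass
    class Toy:
        request_id: int = 0
        prompt_tokens: torch.Tensor = field(default_factory=lambda: torch.arange(4096))

    toy = Toy()

    deep = asdict(toy)             # recursively deep-copies every value, tensor included
    shallow = toy.__dict__.copy()  # copies only the top-level references

    assert deep["prompt_tokens"] is not toy.prompt_tokens  # asdict allocated a new tensor
    assert shallow["prompt_tokens"] is toy.prompt_tokens   # shallow copy shares the tensor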
@@ -125,14 +129,31 @@ def deserialize(cls, obj: dict) -> "InferenceRequest":

         # Initialize request.
         request = cls(**obj)
-        request.status = None if obj["status"] is None else Status[obj["status"]]
+        request._post_deserialize(obj)
+        return request

-        # Deserialize tensors.
+    def _post_deserialize(self, obj: dict):
+        """
+        This is called after the dataclass is initialized to handle any special
+        deserialization logic.
+        """
+        # Deserialize status.
+        self.status = None if obj["status"] is None else Status[obj["status"]]
+        self.sampling_params = (
+            None
+            if obj["sampling_params"] is None
+            else SamplingParams.deserialize(obj["sampling_params"])
+        )
+        self.inference_parameters = (
+            None
+            if obj["inference_parameters"] is None
+            else SamplingParams.deserialize(obj["inference_parameters"])
+        )
+
+        # Deserialize tensors and sampling params.
         for k, v in obj.items():
             if isinstance(v, list) and len(v) == 2 and v[0] == "tensor":
-                setattr(request, k, deserialize_tensor(v[1]))
-
-        return request
+                setattr(self, k, deserialize_tensor(v[1]))


 class DynamicInferenceEventType(Enum):
@@ -197,15 +218,18 @@ def serialize(self) -> dict:
         """

         # Dataclass to dict.
-        obj = asdict(self)
+        torch.cuda.nvtx.range_push("DynamicInferenceEvent.serialize")
+        # do not use asdict(self) - it has very high CPU overheads
+        # and if there are tensors, it will try to deepcopy them
+        obj = self.__dict__.copy()
         obj["type"] = self.type.name

         # Serialize payload.
         if self.payload:
             from .contexts.dynamic_context import ContextErrorFactory  # avoid circular import.

             obj["payload"] = ContextErrorFactory.serialize(self.payload)
-
+        torch.cuda.nvtx.range_pop()
         return obj

     @classmethod
@@ -247,7 +271,7 @@ class DynamicInferenceRequest(InferenceRequest):
     # remaining prompt tokens are used for chunked prefill
     remaining_prompt_tokens: Optional[torch.Tensor] = None
     latency: Optional[float] = None
-    finished_chunk_token_count = 0
+    finished_chunk_token_count: int = 0
     stop_word_ids: Optional[List[List[int]]] = None  # Tokenized stop words (populated internally)

     def __post_init__(self):
def __post_init__(self):
@@ -275,30 +299,22 @@ def __str__(self):
275299
)
276300
)
277301

278-
def serialize(self):
302+
def serialize(self) -> dict:
279303
"""Converts the instance into a serializable dictionary.
280304
281305
Returns:
282306
(dict) A dictionary representation of the instance suitable for
283307
serialization.
284308
"""
309+
torch.cuda.nvtx.range_push("DynamicInferenceRequest.serialize")
285310
obj = super().serialize()
286311
obj["events"] = [e.serialize() for e in self.events]
312+
torch.cuda.nvtx.range_pop()
287313
return obj
288314

289-
@classmethod
290-
def deserialize(cls, obj: dict) -> "DynamicInferenceRequest":
291-
"""Deserialize request.
292-
293-
Args:
294-
obj (dict): Serialized request data.
295-
296-
Returns:
297-
(DynamicInferenceRequest) Deserialized request.
298-
"""
299-
request = super().deserialize(obj)
300-
request.events = [DynamicInferenceEvent.deserialize(e) for e in obj["events"]]
301-
return request
315+
def _post_deserialize(self, obj):
316+
super()._post_deserialize(obj)
317+
self.events = [DynamicInferenceEvent.deserialize(e) for e in obj["events"]]
302318

303319
@property
304320
def tracked_metadata(self) -> List[Any]:
@@ -517,8 +533,10 @@ def serialize(self) -> dict:
517533
(dict) A dictionary representation of the instance suitable for
518534
serialization.
519535
"""
520-
obj = asdict(self)
521-
obj["requests"] = [r.serialize() for r in self.requests]
536+
torch.cuda.nvtx.range_push("DynamicInferenceRequestRecord.serialize")
537+
obj = self.__dict__.copy() # shallow dict copy
538+
obj["requests"] = [r.serialize() for r in obj["requests"]]
539+
torch.cuda.nvtx.range_pop()
522540
return obj
523541

524542
@classmethod
