Skip to content

Commit 32c982e

Browse files
authored
Add tensor_count property for ControlMessage (#2078)
For ControlMessage, `msg.tensors().count` is a common pattern, but calling `msg.tensors()` incurs more overhead than one might expect. Add a `tensor_count` property so callers can avoid that overhead. Closes #1876 ## By Submitting this PR I confirm: - I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md). - When the PR is ready for review, new or existing tests cover these changes. - When the PR is ready for review, the documentation is up to date with these changes. Authors: - Yuchen Zhang (https://github.com/yczhang-nv) Approvers: - David Gardner (https://github.com/dagardner-nv) URL: #2078
1 parent 59aeaca commit 32c982e

File tree

17 files changed

+68
-38
lines changed

17 files changed

+68
-38
lines changed

examples/log_parsing/inference.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,16 @@ class TritonInferenceLogParsing(TritonInferenceWorker):
5757
"""
5858

5959
def build_output_message(self, msg: ControlMessage) -> ControlMessage:
60-
seq_ids = cp.zeros((msg.tensors().count, 3), dtype=cp.uint32)
61-
seq_ids[:, 0] = cp.arange(0, msg.tensors().count, dtype=cp.uint32)
60+
seq_ids = cp.zeros((msg.tensor_count(), 3), dtype=cp.uint32)
61+
seq_ids[:, 0] = cp.arange(0, msg.tensor_count(), dtype=cp.uint32)
6262
seq_ids[:, 2] = msg.tensors().get_tensor('seq_ids')[:, 2]
6363

6464
memory = TensorMemory(
65-
count=msg.tensors().count,
65+
count=msg.tensor_count(),
6666
tensors={
67-
'confidences': cp.zeros((msg.tensors().count, self._inputs[list(self._inputs.keys())[0]].shape[1])),
68-
'labels': cp.zeros((msg.tensors().count, self._inputs[list(self._inputs.keys())[0]].shape[1])),
69-
'input_ids': cp.zeros((msg.tensors().count, msg.tensors().get_tensor('input_ids').shape[1])),
67+
'confidences': cp.zeros((msg.tensor_count(), self._inputs[list(self._inputs.keys())[0]].shape[1])),
68+
'labels': cp.zeros((msg.tensor_count(), self._inputs[list(self._inputs.keys())[0]].shape[1])),
69+
'input_ids': cp.zeros((msg.tensor_count(), msg.tensors().get_tensor('input_ids').shape[1])),
7070
'seq_ids': seq_ids
7171
})
7272

@@ -154,19 +154,19 @@ def _convert_one_response(output: ControlMessage, inf: ControlMessage, res: Tens
154154
seq_offset = seq_ids[0, 0].item()
155155
seq_count = seq_ids[-1, 0].item() + 1 - seq_offset
156156

157-
input_ids[batch_offset:inf.tensors().count + batch_offset, :] = inf.tensors().get_tensor('input_ids')
158-
out_seq_ids[batch_offset:inf.tensors().count + batch_offset, :] = seq_ids
157+
input_ids[batch_offset:inf.tensor_count() + batch_offset, :] = inf.tensors().get_tensor('input_ids')
158+
out_seq_ids[batch_offset:inf.tensor_count() + batch_offset, :] = seq_ids
159159

160160
resp_confidences = res.get_tensor('confidences')
161161
resp_labels = res.get_tensor('labels')
162162

163163
# Two scenarios:
164-
if (inf.payload().count == inf.tensors().count):
164+
if (inf.payload().count == inf.tensor_count()):
165165
assert seq_count == res.count
166-
confidences[batch_offset:inf.tensors().count + batch_offset, :] = resp_confidences
167-
labels[batch_offset:inf.tensors().count + batch_offset, :] = resp_labels
166+
confidences[batch_offset:inf.tensor_count() + batch_offset, :] = resp_confidences
167+
labels[batch_offset:inf.tensor_count() + batch_offset, :] = resp_labels
168168
else:
169-
assert inf.tensors().count == res.count
169+
assert inf.tensor_count() == res.count
170170

171171
mess_ids = seq_ids[:, 0].get().tolist()
172172

python/morpheus/morpheus/_lib/include/morpheus/messages/control.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include "morpheus/export.h" // for MORPHEUS_EXPORT
2121
#include "morpheus/messages/meta.hpp" // for MessageMeta
22+
#include "morpheus/types.hpp"
2223
#include "morpheus/utilities/json_types.hpp" // for json_t
2324

2425
#include <pybind11/pytypes.h> // for object, dict, list
@@ -197,6 +198,13 @@ class MORPHEUS_EXPORT ControlMessage
197198
*/
198199
void tensors(const std::shared_ptr<TensorMemory>& tensor_memory);
199200

201+
/**
202+
* @brief Get the length of tensors in the tensor memory.
203+
*
204+
* @return The length of tensors in the tensor memory.
205+
*/
206+
TensorIndex tensor_count();
207+
200208
/**
201209
* @brief Get the type of task associated with the control message.
202210
* @return An enum value indicating the task type.
@@ -262,6 +270,7 @@ class MORPHEUS_EXPORT ControlMessage
262270
ControlMessageType m_cm_type{ControlMessageType::NONE};
263271
std::shared_ptr<MessageMeta> m_payload{nullptr};
264272
std::shared_ptr<TensorMemory> m_tensors{nullptr};
273+
TensorIndex m_tensor_count{0};
265274

266275
morpheus::utilities::json_t m_tasks{};
267276
morpheus::utilities::json_t m_config{};

python/morpheus/morpheus/_lib/messages/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class ControlMessage():
7373
def task_type(self) -> ControlMessageType: ...
7474
@typing.overload
7575
def task_type(self, task_type: ControlMessageType) -> None: ...
76+
def tensor_count(self) -> int: ...
7677
@typing.overload
7778
def tensors(self) -> TensorMemory: ...
7879
@typing.overload

python/morpheus/morpheus/_lib/messages/module.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ PYBIND11_MODULE(messages, _module)
290290
py::arg("meta"))
291291
.def("tensors", pybind11::overload_cast<>(&ControlMessage::tensors))
292292
.def("tensors", pybind11::overload_cast<const std::shared_ptr<TensorMemory>&>(&ControlMessage::tensors))
293+
.def("tensor_count", &ControlMessage::tensor_count)
293294
.def("remove_task", &ControlMessage::remove_task, py::arg("task_type"))
294295
.def("set_metadata", &ControlMessage::set_metadata, py::arg("key"), py::arg("value"))
295296
.def("task_type", pybind11::overload_cast<>(&ControlMessage::task_type))

python/morpheus/morpheus/_lib/src/messages/control.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,10 @@ ControlMessage::ControlMessage(const morpheus::utilities::json_t& _config) :
5959

6060
ControlMessage::ControlMessage(const ControlMessage& other)
6161
{
62-
m_cm_type = other.m_cm_type;
63-
m_payload = other.m_payload;
64-
m_tensors = other.m_tensors;
62+
m_cm_type = other.m_cm_type;
63+
m_payload = other.m_payload;
64+
m_tensors = other.m_tensors;
65+
m_tensor_count = other.m_tensor_count;
6566

6667
m_config = other.m_config;
6768
m_tasks = other.m_tasks;
@@ -256,7 +257,13 @@ std::shared_ptr<TensorMemory> ControlMessage::tensors()
256257

257258
void ControlMessage::tensors(const std::shared_ptr<TensorMemory>& tensors)
258259
{
259-
m_tensors = tensors;
260+
m_tensors = tensors;
261+
m_tensor_count = tensors ? tensors->count : 0;
262+
}
263+
264+
TensorIndex ControlMessage::tensor_count()
265+
{
266+
return m_tensor_count;
260267
}
261268

262269
ControlMessageType ControlMessage::task_type()

python/morpheus/morpheus/_lib/src/stages/inference_client_stage.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ static ShapeType get_seq_ids(const std::shared_ptr<ControlMessage>& message)
5858
auto seq_ids = message->tensors()->get_tensor("seq_ids");
5959
const auto item_size = seq_ids.dtype().item_size();
6060

61-
ShapeType host_seq_ids(message->tensors()->count);
61+
ShapeType host_seq_ids(message->tensor_count());
6262
MRC_CHECK_CUDA(cudaMemcpy2D(host_seq_ids.data(),
6363
item_size,
6464
seq_ids.data(),
@@ -82,7 +82,7 @@ static TensorObject get_tensor(std::shared_ptr<ControlMessage> message, std::str
8282

8383
static void reduce_outputs(std::shared_ptr<ControlMessage> const& message, TensorMap& output_tensors)
8484
{
85-
if (message->payload()->count() == message->tensors()->count)
85+
if (message->payload()->count() == message->tensor_count())
8686
{
8787
return;
8888
}

python/morpheus/morpheus/_lib/tests/messages/test_control_message.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "morpheus/messages/control.hpp" // for ControlMessage
2222
#include "morpheus/messages/memory/tensor_memory.hpp" // for TensorMemory
2323
#include "morpheus/messages/meta.hpp" // for MessageMeta
24+
#include "morpheus/types.hpp"
2425
#include "morpheus/utilities/json_types.hpp" // for PythonByteContainer
2526

2627
#include <gtest/gtest.h> // for Message, TestPartResult, AssertionResult, TestInfo
@@ -298,7 +299,8 @@ TEST_F(TestControlMessage, SetAndGetTensorMemory)
298299
{
299300
auto msg = ControlMessage();
300301

301-
auto tensorMemory = std::make_shared<TensorMemory>(0);
302+
TensorIndex count = 5;
303+
auto tensorMemory = std::make_shared<TensorMemory>(count);
302304
// Optionally, modify tensorMemory here if it has any mutable state to test
303305

304306
// Set the tensor memory
@@ -309,6 +311,7 @@ TEST_F(TestControlMessage, SetAndGetTensorMemory)
309311

310312
// Verify that the retrieved tensor memory matches what was set
311313
EXPECT_EQ(tensorMemory, retrievedTensorMemory);
314+
EXPECT_EQ(count, msg.tensor_count());
312315
}
313316

314317
// Test setting TensorMemory to nullptr

python/morpheus/morpheus/messages/control_message.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(self, config_or_message: typing.Union["ControlMessage", dict] = Non
4646

4747
self._payload: MessageMeta = None
4848
self._tensors: TensorMemory = None
49+
self._tensor_count: int = 0
4950

5051
self._tasks: dict[str, deque] = defaultdict(deque)
5152
self._timestamps: dict[str, datetime] = {}
@@ -147,9 +148,13 @@ def payload(self, payload: MessageMeta = None) -> MessageMeta | None:
147148
def tensors(self, tensors: TensorMemory = None) -> TensorMemory | None:
148149
if tensors is not None:
149150
self._tensors = tensors
151+
self._tensor_count = tensors.count
150152

151153
return self._tensors
152154

155+
def tensor_count(self) -> int:
156+
return self._tensor_count
157+
153158
def task_type(self, new_task_type: ControlMessageType = None) -> ControlMessageType:
154159
if new_task_type is not None:
155160
self._type = new_task_type

python/morpheus/morpheus/stages/inference/identity_inference_stage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ def __init__(self, inf_queue: ProducerConsumerQueue, c: Config):
4545
self._seq_length = c.feature_length
4646

4747
def calc_output_dims(self, msg: ControlMessage) -> typing.Tuple:
48-
return (msg.tensors().count, self._seq_length)
48+
return (msg.tensor_count(), self._seq_length)
4949

5050
def process(self, batch: ControlMessage, callback: typing.Callable[[TensorMemory], None]):
5151

5252
def tmp(batch: ControlMessage, f):
53-
count = batch.tensors().count
53+
count = batch.tensor_count()
5454
f(TensorMemory(
5555
count=count,
5656
tensors={'probs': cp.zeros((count, self._seq_length), dtype=cp.float32)},

python/morpheus/morpheus/stages/inference/inference_stage.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def set_output_fut(resp: TensorMemory, inner_batch, batch_future: mrc.Future):
244244
nonlocal outstanding_requests
245245
nonlocal batch_offset
246246
mess = self._convert_one_response(output_message, inner_batch, resp, batch_offset)
247-
batch_offset += inner_batch.tensors().count
247+
batch_offset += inner_batch.tensor_count()
248248
outstanding_requests -= 1
249249

250250
batch_future.set_result(mess)
@@ -359,13 +359,13 @@ def _convert_one_response(output: ControlMessage, inf: ControlMessage, res: Tens
359359
seq_count = seq_ids[-1, 0].item() + 1 - seq_offset
360360

361361
# Two scenarios:
362-
if (inf.payload().count == inf.tensors().count):
362+
if (inf.payload().count == inf.tensor_count()):
363363
assert seq_count == res.count
364364

365365
# In message and out message have same count. Just use probs as is
366366
probs[seq_offset:seq_offset + seq_count, :] = resp_probs
367367
else:
368-
assert inf.tensors().count == res.count
368+
assert inf.tensor_count() == res.count
369369

370370
mess_ids = seq_ids[:, 0].get().tolist()
371371

0 commit comments

Comments
 (0)