Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions src/generators.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include <iostream>
#include "span.h"
#include <memory>
#include <mutex>
#include <numeric>
#include <optional>
#include <queue>
Expand Down Expand Up @@ -145,11 +144,7 @@ struct OrtGlobals {

// Cache for dynamically built graph sessions (e.g., Cast, TopK operations)
// Destroyed before env_ to ensure proper cleanup order
struct SessionCache {
std::unordered_map<uint64_t, std::unique_ptr<OrtSession>> sessions_;
std::mutex mutex_;
};
SessionCache graph_session_cache_;
std::unordered_map<uint64_t, std::unique_ptr<OrtSession>> graph_session_cache_;

private:
OrtGlobals(const OrtGlobals&) = delete;
Expand Down
205 changes: 103 additions & 102 deletions src/models/graph_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,116 +109,117 @@ namespace GraphBuilder {
OrtModel* Build(const ModelConfig& config) {
const auto& model_editor_api = Ort::GetModelEditorApi();

OrtGraph* graph = nullptr;
OrtModel* model = nullptr;
std::vector<OrtOpAttr*> node_attributes;

try {
// Create graph
Ort::ThrowOnError(model_editor_api.CreateGraph(&graph));

// Create input ValueInfos
std::vector<OrtValueInfo*> graph_inputs;
for (const auto& input : config.inputs) {
OrtTensorTypeAndShapeInfo* tensor_info = nullptr;
Ort::ThrowOnError(Ort::api->CreateTensorTypeAndShapeInfo(&tensor_info));
Ort::ThrowOnError(Ort::api->SetTensorElementType(tensor_info, input.elem_type));
Ort::ThrowOnError(Ort::api->SetDimensions(tensor_info, input.shape.data(), input.shape.size()));

OrtTypeInfo* type_info = nullptr;
Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_info, &type_info));
Ort::api->ReleaseTensorTypeAndShapeInfo(tensor_info);

OrtValueInfo* value_info = nullptr;
Ort::ThrowOnError(model_editor_api.CreateValueInfo(input.name.c_str(), type_info, &value_info));
Ort::api->ReleaseTypeInfo(type_info);

graph_inputs.push_back(value_info);
}

// Create output ValueInfos
std::vector<OrtValueInfo*> graph_outputs;
for (const auto& output : config.outputs) {
OrtTensorTypeAndShapeInfo* tensor_info = nullptr;
Ort::ThrowOnError(Ort::api->CreateTensorTypeAndShapeInfo(&tensor_info));
Ort::ThrowOnError(Ort::api->SetTensorElementType(tensor_info, output.elem_type));
Ort::ThrowOnError(Ort::api->SetDimensions(tensor_info, output.shape.data(), output.shape.size()));

OrtTypeInfo* type_info = nullptr;
Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_info, &type_info));
Ort::api->ReleaseTensorTypeAndShapeInfo(tensor_info);

OrtValueInfo* value_info = nullptr;
Ort::ThrowOnError(model_editor_api.CreateValueInfo(output.name.c_str(), type_info, &value_info));
Ort::api->ReleaseTypeInfo(type_info);

graph_outputs.push_back(value_info);
}

// Set graph inputs and outputs (graph takes ownership of ValueInfos)
Ort::ThrowOnError(model_editor_api.SetGraphInputs(graph, graph_inputs.data(), graph_inputs.size()));
Ort::ThrowOnError(model_editor_api.SetGraphOutputs(graph, graph_outputs.data(), graph_outputs.size()));
// Use RAII wrappers for automatic cleanup
std::unique_ptr<OrtGraph> graph;
std::unique_ptr<OrtModel> model;
std::vector<OrtOpAttr*> node_attributes; // Manual management - CreateNode stores references, doesn't copy

// Create graph
OrtGraph* graph_ptr = nullptr;
Ort::ThrowOnError(model_editor_api.CreateGraph(&graph_ptr));
graph.reset(graph_ptr);

// Create input ValueInfos
std::vector<std::unique_ptr<OrtValueInfo>> graph_inputs_owned;
std::vector<OrtValueInfo*> graph_inputs;
for (const auto& input : config.inputs) {
OrtTensorTypeAndShapeInfo* tensor_info_raw = nullptr;
Ort::ThrowOnError(Ort::api->CreateTensorTypeAndShapeInfo(&tensor_info_raw));
std::unique_ptr<OrtTensorTypeAndShapeInfo> tensor_info(tensor_info_raw);
Ort::ThrowOnError(Ort::api->SetTensorElementType(tensor_info.get(), input.elem_type));
Ort::ThrowOnError(Ort::api->SetDimensions(tensor_info.get(), input.shape.data(), input.shape.size()));

OrtTypeInfo* type_info_raw = nullptr;
Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_info.get(), &type_info_raw));
std::unique_ptr<OrtTypeInfo> type_info(type_info_raw);

OrtValueInfo* value_info = nullptr;
Ort::ThrowOnError(model_editor_api.CreateValueInfo(input.name.c_str(), type_info.get(), &value_info));

graph_inputs.push_back(value_info);
graph_inputs_owned.emplace_back(value_info);
}

// Create node attributes
for (const auto& attr : config.attributes) {
node_attributes.push_back(CreateOpAttr(attr));
}
// Create output ValueInfos
std::vector<std::unique_ptr<OrtValueInfo>> graph_outputs_owned;
std::vector<OrtValueInfo*> graph_outputs;
for (const auto& output : config.outputs) {
OrtTensorTypeAndShapeInfo* tensor_info_raw = nullptr;
Ort::ThrowOnError(Ort::api->CreateTensorTypeAndShapeInfo(&tensor_info_raw));
std::unique_ptr<OrtTensorTypeAndShapeInfo> tensor_info(tensor_info_raw);
Ort::ThrowOnError(Ort::api->SetTensorElementType(tensor_info.get(), output.elem_type));
Ort::ThrowOnError(Ort::api->SetDimensions(tensor_info.get(), output.shape.data(), output.shape.size()));

OrtTypeInfo* type_info_raw = nullptr;
Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_info.get(), &type_info_raw));
std::unique_ptr<OrtTypeInfo> type_info(type_info_raw);

OrtValueInfo* value_info = nullptr;
Ort::ThrowOnError(model_editor_api.CreateValueInfo(output.name.c_str(), type_info.get(), &value_info));

graph_outputs.push_back(value_info);
graph_outputs_owned.emplace_back(value_info);
}

// Create input/output name vectors
std::vector<const char*> input_names;
for (const auto& input : config.inputs) {
input_names.push_back(input.name.c_str());
}
// Set graph inputs and outputs (graph takes ownership of ValueInfos)
Ort::ThrowOnError(model_editor_api.SetGraphInputs(graph.get(), graph_inputs.data(), graph_inputs.size()));
Ort::ThrowOnError(model_editor_api.SetGraphOutputs(graph.get(), graph_outputs.data(), graph_outputs.size()));

std::vector<const char*> output_names;
for (const auto& output : config.outputs) {
output_names.push_back(output.name.c_str());
}
// Release ownership since graph took it
for (auto& vi : graph_inputs_owned) vi.release();
for (auto& vi : graph_outputs_owned) vi.release();

// Create node
OrtNode* node = nullptr;
Ort::ThrowOnError(model_editor_api.CreateNode(
config.op_type.c_str(),
"", // empty domain = ONNX domain
(config.op_type + "_node").c_str(),
input_names.data(),
input_names.size(),
output_names.data(),
output_names.size(),
node_attributes.empty() ? nullptr : node_attributes.data(),
node_attributes.size(),
&node));

// Add node to graph (graph takes ownership of node)
Ort::ThrowOnError(model_editor_api.AddNodeToGraph(graph, node));
// Release node attributes - CreateNode made its own copy
for (auto* attr : node_attributes) {
Ort::api->ReleaseOpAttr(attr);
}
node_attributes.clear();
// Create model with opset
const char* domain_name = "";
Ort::ThrowOnError(model_editor_api.CreateModel(&domain_name, &config.opset_version, 1, &model));
// Create node attributes
for (const auto& attr : config.attributes) {
node_attributes.push_back(CreateOpAttr(attr));
}

// Add graph to model (model takes ownership of graph)
Ort::ThrowOnError(model_editor_api.AddGraphToModel(model, graph));
graph = nullptr; // model now owns graph
// Create input/output name vectors
std::vector<const char*> input_names;
for (const auto& input : config.inputs) {
input_names.push_back(input.name.c_str());
}

return model;
std::vector<const char*> output_names;
for (const auto& output : config.outputs) {
output_names.push_back(output.name.c_str());
}

} catch (...) {
// Clean up on error
for (auto* attr : node_attributes) {
Ort::api->ReleaseOpAttr(attr);
}
if (graph != nullptr) {
Ort::api->ReleaseGraph(graph);
}
if (model != nullptr) {
Ort::api->ReleaseModel(model);
}
throw;
// Create node
OrtNode* node_ptr = nullptr;
Ort::ThrowOnError(model_editor_api.CreateNode(
config.op_type.c_str(),
"", // empty domain = ONNX domain
(config.op_type + "_node").c_str(),
input_names.data(),
input_names.size(),
output_names.data(),
output_names.size(),
node_attributes.empty() ? nullptr : node_attributes.data(),
node_attributes.size(),
&node_ptr));
std::unique_ptr<OrtNode> node(node_ptr);

// Add node to graph (graph takes ownership of node)
Ort::ThrowOnError(model_editor_api.AddNodeToGraph(graph.get(), node.get()));
node.release(); // graph now owns node

// Create model with opset
const char* domain_name = "";
OrtModel* model_ptr = nullptr;
Ort::ThrowOnError(model_editor_api.CreateModel(&domain_name, &config.opset_version, 1, &model_ptr));
model.reset(model_ptr);

// Add graph to model (model takes ownership of graph)
Ort::ThrowOnError(model_editor_api.AddGraphToModel(model.get(), graph.get()));
graph.release(); // model now owns graph

// Release node attributes - must be done AFTER model is built since CreateNode stores references
for (auto* attr : node_attributes) {
Ort::api->ReleaseOpAttr(attr);
}

return model.release(); // Return ownership to caller
}

} // namespace GraphBuilder
Expand Down
9 changes: 3 additions & 6 deletions src/models/graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include "../generators.h"
#include <functional>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <cstring>

Expand Down Expand Up @@ -75,10 +74,8 @@ OrtSession* GetOrCreateSession(
auto& cache = GetOrtGlobals()->graph_session_cache_;
uint64_t key = GenerateCacheKey(config, ep_name);

std::lock_guard<std::mutex> lock(cache.mutex_);

auto it = cache.sessions_.find(key);
if (it != cache.sessions_.end()) {
auto it = cache.find(key);
if (it != cache.end()) {
return it->second.get();
}

Expand All @@ -92,7 +89,7 @@ OrtSession* GetOrCreateSession(
Ort::api->ReleaseModel(model);

OrtSession* session_ptr = session.get();
cache.sessions_[key] = std::move(session);
cache[key] = std::move(session);

return session_ptr;
}
Expand Down
32 changes: 32 additions & 0 deletions src/models/onnxruntime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,38 @@ struct OrtOpAttr {
Ort::Abstract make_abstract;
};

/// <summary>
/// This struct provides life time management for OrtGraph used in Model Editor API
/// </summary>
struct OrtGraph {
static void operator delete(void* p) { Ort::api->ReleaseGraph(reinterpret_cast<OrtGraph*>(p)); }
Ort::Abstract make_abstract;
};

/// <summary>
/// This struct provides life time management for OrtModel used in Model Editor API
/// </summary>
struct OrtModel {
static void operator delete(void* p) { Ort::api->ReleaseModel(reinterpret_cast<OrtModel*>(p)); }
Ort::Abstract make_abstract;
};

/// <summary>
/// This struct provides life time management for OrtValueInfo used in Model Editor API
/// </summary>
struct OrtValueInfo {
static void operator delete(void* p) { Ort::api->ReleaseValueInfo(reinterpret_cast<OrtValueInfo*>(p)); }
Ort::Abstract make_abstract;
};

/// <summary>
/// This struct provides life time management for OrtNode used in Model Editor API
/// </summary>
struct OrtNode {
static void operator delete(void* p) { Ort::api->ReleaseNode(reinterpret_cast<OrtNode*>(p)); }
Ort::Abstract make_abstract;
};

/// <summary>
/// This class wraps a raw pointer OrtKernelContext* that is being passed
/// to the custom kernel Compute() method. Use it to safely access context
Expand Down
Loading