Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 84 additions & 95 deletions src/models/graph_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,116 +109,105 @@ namespace GraphBuilder {
OrtModel* Build(const ModelConfig& config) {
const auto& model_editor_api = Ort::GetModelEditorApi();

OrtGraph* graph = nullptr;
OrtModel* model = nullptr;
std::vector<OrtOpAttr*> node_attributes;

try {
// Create graph
Ort::ThrowOnError(model_editor_api.CreateGraph(&graph));

// Create input ValueInfos
std::vector<OrtValueInfo*> graph_inputs;
for (const auto& input : config.inputs) {
OrtTensorTypeAndShapeInfo* tensor_info = nullptr;
Ort::ThrowOnError(Ort::api->CreateTensorTypeAndShapeInfo(&tensor_info));
Ort::ThrowOnError(Ort::api->SetTensorElementType(tensor_info, input.elem_type));
Ort::ThrowOnError(Ort::api->SetDimensions(tensor_info, input.shape.data(), input.shape.size()));
// Create graph using RAII wrapper
auto graph = OrtGraph::Create();

OrtTypeInfo* type_info = nullptr;
Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_info, &type_info));
Ort::api->ReleaseTensorTypeAndShapeInfo(tensor_info);
// Create input ValueInfos
std::vector<std::unique_ptr<OrtValueInfo>> graph_inputs;
for (const auto& input : config.inputs) {
OrtTensorTypeAndShapeInfo* tensor_info = nullptr;
Ort::ThrowOnError(Ort::api->CreateTensorTypeAndShapeInfo(&tensor_info));
Ort::ThrowOnError(Ort::api->SetTensorElementType(tensor_info, input.elem_type));
Ort::ThrowOnError(Ort::api->SetDimensions(tensor_info, input.shape.data(), input.shape.size()));

OrtValueInfo* value_info = nullptr;
Ort::ThrowOnError(model_editor_api.CreateValueInfo(input.name.c_str(), type_info, &value_info));
Ort::api->ReleaseTypeInfo(type_info);
auto value_info = OrtValueInfo::Create(input.name.c_str(), tensor_info);
Ort::api->ReleaseTensorTypeAndShapeInfo(tensor_info);

graph_inputs.push_back(value_info);
}

// Create output ValueInfos
std::vector<OrtValueInfo*> graph_outputs;
for (const auto& output : config.outputs) {
OrtTensorTypeAndShapeInfo* tensor_info = nullptr;
Ort::ThrowOnError(Ort::api->CreateTensorTypeAndShapeInfo(&tensor_info));
Ort::ThrowOnError(Ort::api->SetTensorElementType(tensor_info, output.elem_type));
Ort::ThrowOnError(Ort::api->SetDimensions(tensor_info, output.shape.data(), output.shape.size()));
graph_inputs.push_back(std::move(value_info));
}

OrtTypeInfo* type_info = nullptr;
Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_info, &type_info));
Ort::api->ReleaseTensorTypeAndShapeInfo(tensor_info);
// Create output ValueInfos
std::vector<std::unique_ptr<OrtValueInfo>> graph_outputs;
for (const auto& output : config.outputs) {
OrtTensorTypeAndShapeInfo* tensor_info = nullptr;
Ort::ThrowOnError(Ort::api->CreateTensorTypeAndShapeInfo(&tensor_info));
Ort::ThrowOnError(Ort::api->SetTensorElementType(tensor_info, output.elem_type));
Ort::ThrowOnError(Ort::api->SetDimensions(tensor_info, output.shape.data(), output.shape.size()));

OrtValueInfo* value_info = nullptr;
Ort::ThrowOnError(model_editor_api.CreateValueInfo(output.name.c_str(), type_info, &value_info));
Ort::api->ReleaseTypeInfo(type_info);
auto value_info = OrtValueInfo::Create(output.name.c_str(), tensor_info);
Ort::api->ReleaseTensorTypeAndShapeInfo(tensor_info);

graph_outputs.push_back(value_info);
}
graph_outputs.push_back(std::move(value_info));
}

// Set graph inputs and outputs (graph takes ownership of ValueInfos)
Ort::ThrowOnError(model_editor_api.SetGraphInputs(graph, graph_inputs.data(), graph_inputs.size()));
Ort::ThrowOnError(model_editor_api.SetGraphOutputs(graph, graph_outputs.data(), graph_outputs.size()));
// Set graph inputs and outputs (graph takes ownership of ValueInfos)
std::vector<OrtValueInfo*> input_ptrs;
input_ptrs.reserve(graph_inputs.size());
for (auto& vi : graph_inputs) {
input_ptrs.push_back(vi.get());
}

// Create node attributes
for (const auto& attr : config.attributes) {
node_attributes.push_back(CreateOpAttr(attr));
}
std::vector<OrtValueInfo*> output_ptrs;
output_ptrs.reserve(graph_outputs.size());
for (auto& vi : graph_outputs) {
output_ptrs.push_back(vi.get());
}

// Create input/output name vectors
std::vector<const char*> input_names;
for (const auto& input : config.inputs) {
input_names.push_back(input.name.c_str());
}
Ort::ThrowOnError(model_editor_api.SetGraphInputs(graph.get(), input_ptrs.data(), input_ptrs.size()));
Ort::ThrowOnError(model_editor_api.SetGraphOutputs(graph.get(), output_ptrs.data(), output_ptrs.size()));

std::vector<const char*> output_names;
for (const auto& output : config.outputs) {
output_names.push_back(output.name.c_str());
}
// Release ownership since graph took it
for (auto& vi : graph_inputs) vi.release();
for (auto& vi : graph_outputs) vi.release();

// Create node
OrtNode* node = nullptr;
Ort::ThrowOnError(model_editor_api.CreateNode(
config.op_type.c_str(),
"", // empty domain = ONNX domain
(config.op_type + "_node").c_str(),
input_names.data(),
input_names.size(),
output_names.data(),
output_names.size(),
node_attributes.empty() ? nullptr : node_attributes.data(),
node_attributes.size(),
&node));

// Add node to graph (graph takes ownership of node)
Ort::ThrowOnError(model_editor_api.AddNodeToGraph(graph, node));
// Release node attributes - CreateNode made its own copy
for (auto* attr : node_attributes) {
Ort::api->ReleaseOpAttr(attr);
}
node_attributes.clear();
// Create model with opset
const char* domain_name = "";
Ort::ThrowOnError(model_editor_api.CreateModel(&domain_name, &config.opset_version, 1, &model));
// Create node attributes
std::vector<OrtOpAttr*> node_attributes;
for (const auto& attr : config.attributes) {
node_attributes.push_back(CreateOpAttr(attr));
}

// Add graph to model (model takes ownership of graph)
Ort::ThrowOnError(model_editor_api.AddGraphToModel(model, graph));
graph = nullptr; // model now owns graph
// Create input/output name vectors
std::vector<const char*> input_names;
for (const auto& input : config.inputs) {
input_names.push_back(input.name.c_str());
}

return model;
std::vector<const char*> output_names;
for (const auto& output : config.outputs) {
output_names.push_back(output.name.c_str());
}

} catch (...) {
// Clean up on error
for (auto* attr : node_attributes) {
Ort::api->ReleaseOpAttr(attr);
}
if (graph != nullptr) {
Ort::api->ReleaseGraph(graph);
}
if (model != nullptr) {
Ort::api->ReleaseModel(model);
}
throw;
// Create node using RAII wrapper
auto node = OrtNode::Create(
config.op_type.c_str(),
"", // empty domain = ONNX domain
(config.op_type + "_node").c_str(),
input_names.data(),
input_names.size(),
output_names.data(),
output_names.size(),
node_attributes.empty() ? nullptr : node_attributes.data(),
node_attributes.size());

// Add node to graph (graph takes ownership of node)
Ort::ThrowOnError(model_editor_api.AddNodeToGraph(graph.get(), node.get()));
node.release(); // graph now owns node

// Create model with opset using RAII wrapper
const char* domain_name = "";
int opset = config.opset_version;
auto model = OrtModel::Create(&domain_name, &opset, 1);

// Add graph to model (model takes ownership of graph)
Ort::ThrowOnError(model_editor_api.AddGraphToModel(model.get(), graph.get()));
graph.release(); // model now owns graph

// Release node attributes - must be done AFTER model is built since CreateNode stores references
for (auto* attr : node_attributes) {
Ort::api->ReleaseOpAttr(attr);
}

return model.release(); // Return ownership to caller
}

} // namespace GraphBuilder
Expand Down
39 changes: 39 additions & 0 deletions src/models/onnxruntime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,45 @@ struct OrtOpAttr {
Ort::Abstract make_abstract;
};

/// <summary>
/// This struct provides life time management for OrtGraph used in Model Editor API
/// </summary>
struct OrtGraph {
static std::unique_ptr<OrtGraph> Create();
static void operator delete(void* p) { Ort::api->ReleaseGraph(reinterpret_cast<OrtGraph*>(p)); }
Ort::Abstract make_abstract;
};

/// <summary>
/// This struct provides life time management for OrtModel used in Model Editor API
/// </summary>
struct OrtModel {
static std::unique_ptr<OrtModel> Create(const char** domain_names, const int* opset_versions, size_t num_domains);
static void operator delete(void* p) { Ort::api->ReleaseModel(reinterpret_cast<OrtModel*>(p)); }
Ort::Abstract make_abstract;
};

/// <summary>
/// This struct provides life time management for OrtValueInfo used in Model Editor API
/// </summary>
struct OrtValueInfo {
static std::unique_ptr<OrtValueInfo> Create(const char* name, const OrtTensorTypeAndShapeInfo* tensor_info);
static void operator delete(void* p) { Ort::api->ReleaseValueInfo(reinterpret_cast<OrtValueInfo*>(p)); }
Ort::Abstract make_abstract;
};

/// <summary>
/// This struct provides life time management for OrtNode used in Model Editor API
/// </summary>
struct OrtNode {
static std::unique_ptr<OrtNode> Create(const char* op_type, const char* domain, const char* name,
const char** input_names, size_t num_inputs,
const char** output_names, size_t num_outputs,
OrtOpAttr** attributes, size_t num_attributes);
static void operator delete(void* p) { Ort::api->ReleaseNode(reinterpret_cast<OrtNode*>(p)); }
Ort::Abstract make_abstract;
};

/// <summary>
/// This class wraps a raw pointer OrtKernelContext* that is being passed
/// to the custom kernel Compute() method. Use it to safely access context
Expand Down
40 changes: 40 additions & 0 deletions src/models/onnxruntime_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -1338,6 +1338,46 @@ inline std::unique_ptr<OrtOpAttr> OrtOpAttr::Create(const char* name, const void
return std::unique_ptr<OrtOpAttr>{p};
}

inline std::unique_ptr<OrtGraph> OrtGraph::Create() {
OrtGraph* p;
Ort::ThrowOnError(Ort::GetModelEditorApi().CreateGraph(&p));
return std::unique_ptr<OrtGraph>{p};
}

inline std::unique_ptr<OrtModel> OrtModel::Create(const char** domain_names, const int* opset_versions, size_t num_domains) {
OrtModel* p;
Ort::ThrowOnError(Ort::GetModelEditorApi().CreateModel(domain_names, opset_versions, num_domains, &p));
return std::unique_ptr<OrtModel>{p};
}

inline std::unique_ptr<OrtValueInfo> OrtValueInfo::Create(const char* name, const OrtTensorTypeAndShapeInfo* tensor_info) {
const auto& model_editor_api = Ort::GetModelEditorApi();

OrtTypeInfo* type_info;
Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_info, &type_info));

OrtValueInfo* p;
auto status = model_editor_api.CreateValueInfo(name, type_info, &p);
Ort::api->ReleaseTypeInfo(type_info);
Ort::ThrowOnError(status);

return std::unique_ptr<OrtValueInfo>{p};
}

inline std::unique_ptr<OrtNode> OrtNode::Create(const char* op_type, const char* domain, const char* name,
const char** input_names, size_t num_inputs,
const char** output_names, size_t num_outputs,
OrtOpAttr** attributes, size_t num_attributes) {
OrtNode* p;
Ort::ThrowOnError(Ort::GetModelEditorApi().CreateNode(
op_type, domain, name,
input_names, num_inputs,
output_names, num_outputs,
attributes, num_attributes,
&p));
return std::unique_ptr<OrtNode>{p};
}

inline std::unique_ptr<OrtKernelInfo> OrtKernelInfo::Clone() const {
OrtKernelInfo* p;
Ort::ThrowOnError(Ort::api->CopyKernelInfo(this, &p));
Expand Down
Loading