Skip to content

Commit af5de4d

Browse files
Use ORT APIs for registering EPs
1 parent 4fd0ed4 commit af5de4d

File tree

6 files changed

+45
-56
lines changed

6 files changed

+45
-56
lines changed

examples/c/src/common.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,11 @@ void RegisterEP(const std::string& ep, const std::string& ep_path) {
237237
}
238238

239239
std::cout << "Registering execution provider: " << ep_path << std::endl;
240-
240+
auto env = Ort::Env();
241241
if (ep.compare("cuda") == 0) {
242-
OgaRegisterExecutionProviderLibrary("CUDAExecutionProvider", ep_path.c_str());
242+
env.RegisterExecutionProviderLibrary("CUDAExecutionProvider", ep_path.c_str());
243243
} else if (ep.compare("NvTensorRtRtx") == 0) {
244-
OgaRegisterExecutionProviderLibrary("NvTensorRTRTXExecutionProvider", ep_path.c_str());
244+
env.RegisterExecutionProviderLibrary("NvTensorRTRTXExecutionProvider", ep_path.c_str());
245245
} else {
246246
std::cout << "Warning: EP registration not supported for " << ep << std::endl;
247247
std::cout << "Only 'cuda' and 'NvTensorRtRtx' support plug-in libraries." << std::endl;

examples/c/src/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include <CLI/CLI.hpp>
1717
#include <nlohmann/json.hpp>
18+
#include "onnxruntime_cxx_api.h"
1819
#include "ort_genai.h"
1920

2021
using Clock = std::chrono::high_resolution_clock;

examples/csharp/Common/Common.cs

Lines changed: 37 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using Microsoft.ML.OnnxRuntimeGenAI;
1+
using Microsoft.ML.OnnxRuntime;
2+
using Microsoft.ML.OnnxRuntimeGenAI;
23
using System.CommandLine;
34
using System.Reflection;
45
using System.Reflection.Metadata.Ecma335;
@@ -26,45 +27,41 @@ public static void SetLogger(bool inputs = true, bool outputs = true)
2627
Utils.SetLogBool("model_output_values", outputs);
2728
}
2829

29-
/**
30-
* TODO: Uncomment the below snippet to use Utils.RegisterEPLibrary once
31-
* the C# binding to Utils.RegisterEPLibrary is in a stable package release.
32-
*/
33-
34-
// /// <summary>
35-
// /// Register execution provider if path is provided
36-
// /// </summary>
37-
// /// <param name="ep">Name of execution provider to set</param>
38-
// /// <param name="ep_path">Path to execution provider to set</param>
39-
// /// <returns>
40-
// /// None
41-
// /// </returns>
42-
// public static void RegisterEP(string ep, string ep_path)
43-
// {
44-
// if (string.IsNullOrEmpty(ep_path))
45-
// {
46-
// return; // No library path specified, skip registration
47-
// }
48-
49-
// Console.WriteLine($"Registering execution provider: {ep_path}");
50-
51-
// if (string.Equals(ep, "cuda", StringComparison.OrdinalIgnoreCase))
52-
// {
53-
// Utils.RegisterExecutionProviderLibrary("CUDAExecutionProvider", ep_path);
54-
// }
55-
// else if (string.Equals(ep, "NvTensorRtRtx", StringComparison.OrdinalIgnoreCase))
56-
// {
57-
// Utils.RegisterExecutionProviderLibrary("NvTensorRTRTXExecutionProvider", ep_path);
58-
// }
59-
// else
60-
// {
61-
// Console.WriteLine($"Warning: EP registration not supported for {ep}");
62-
// Console.WriteLine("Only 'cuda' and 'NvTensorRtRtx' support plug-in libraries.");
63-
// return;
64-
// }
65-
66-
// Console.WriteLine($"Registered {ep} successfully!");
67-
// }
30+
/// <summary>
31+
/// Register execution provider if path is provided
32+
/// </summary>
33+
/// <param name="ep">Name of execution provider to set</param>
34+
/// <param name="ep_path">Path to execution provider to set</param>
35+
/// <returns>
36+
/// None
37+
/// </returns>
38+
public static void RegisterEP(string ep, string ep_path)
39+
{
40+
if (string.IsNullOrEmpty(ep_path))
41+
{
42+
return; // No library path specified, skip registration
43+
}
44+
45+
Console.WriteLine($"Registering execution provider: {ep_path}");
46+
47+
var ortEnv = OrtEnv.Instance();
48+
if (string.Equals(ep, "cuda", StringComparison.OrdinalIgnoreCase))
49+
{
50+
ortEnv.RegisterExecutionProviderLibrary("CUDAExecutionProvider", ep_path);
51+
}
52+
else if (string.Equals(ep, "NvTensorRtRtx", StringComparison.OrdinalIgnoreCase))
53+
{
54+
ortEnv.RegisterExecutionProviderLibrary("NvTensorRTRTXExecutionProvider", ep_path);
55+
}
56+
else
57+
{
58+
Console.WriteLine($"Warning: EP registration not supported for {ep}");
59+
Console.WriteLine("Only 'cuda' and 'NvTensorRtRtx' support plug-in libraries.");
60+
return;
61+
}
62+
63+
Console.WriteLine($"Registered {ep} successfully!");
64+
}
6865

6966
/// <summary>
7067
/// Get Config object and set EP-specific and search-specific options inside it

examples/csharp/README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,15 @@ dotnet build ModelMM -c Release
3737

3838
```powershell
3939
# Prerequisite: navigate to the compiled binaries. This is an example. Your navigation may change depending on your target.
40-
cd ./ModelChat/build/Debug/net8.0/
40+
cd ./ModelChat/bin/Debug/net8.0/
4141
4242
# The `model-chat` script allows for multi-turn conversations.
4343
.\ModelChat.exe -m {path to model folder} -e {execution provider}
4444
```
4545

4646
```powershell
4747
# Prerequisite: navigate to the compiled binaries. This is an example. Your navigation may change depending on your target.
48-
cd ./ModelMM/build/Debug/net8.0/
48+
cd ./ModelMM/bin/Debug/net8.0/
4949
5050
# The `model-mm` script works for multi-modal models and streams the output text token by token.
5151
.\ModelMM.exe -m {path to model folder} -e {execution provider}
@@ -55,15 +55,15 @@ cd ./ModelMM/build/Debug/net8.0/
5555

5656
```bash
5757
# Prerequisite: navigate to the compiled binaries. This is an example. Your navigation may change depending on your target.
58-
cd ./ModelChat/build/Debug/net8.0/
58+
cd ./ModelChat/bin/Debug/net8.0/
5959

6060
# The `model-chat` script allows for multi-turn conversations.
6161
./ModelChat -m {path to model folder} -e {execution provider}
6262
```
6363

6464
```bash
6565
# Prerequisite: navigate to the compiled binaries. This is an example. Your navigation may change depending on your target.
66-
cd ./ModelMM/build/Debug/net8.0/
66+
cd ./ModelMM/bin/Debug/net8.0/
6767

6868
# The `model-mm` script works for multi-modal models and streams the output text token by token.
6969
./ModelMM -m {path to model folder} -e {execution provider}

src/csharp/NativeMethods.cs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -333,10 +333,6 @@ public static extern UIntPtr OgaSequencesGetSequenceCount(IntPtr /* const OgaSeq
333333
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
334334
public static extern IntPtr /* OgaResult* */ OgaGetCurrentGpuDeviceId(out IntPtr /* int32_t */ deviceId);
335335

336-
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
337-
public static extern IntPtr /* OgaResult* */ OgaRegisterExecutionProviderLibrary(byte[] /* const char* */ name,
338-
byte[] /* const char* */ path);
339-
340336
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
341337
public static extern void OgaShutdown();
342338

src/csharp/Utils.cs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,6 @@ public static void SetLogString(string name, string value)
5959
{
6060
Result.VerifySuccess(NativeMethods.OgaSetLogString(StringUtils.ToUtf8(name), StringUtils.ToUtf8(value)));
6161
}
62-
63-
public static void RegisterExecutionProviderLibrary(string name, string path)
64-
{
65-
Result.VerifySuccess(NativeMethods.OgaRegisterExecutionProviderLibrary(StringUtils.ToUtf8(name), StringUtils.ToUtf8(path)));
66-
}
6762
}
6863

6964
internal class StringUtils

0 commit comments

Comments
 (0)