diff --git a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
index 06fe0b32b4..f9849c9a6c 100644
--- a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
+++ b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
@@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.
using System;
+using System.Collections.Generic;
using Microsoft.ML.Data;
namespace Microsoft.ML.Runtime;
@@ -104,6 +105,16 @@ internal interface IHostEnvironmentInternal : IHostEnvironment
/// GPU device ID to run execution on, <see langword="null"/> to run on CPU.
/// </summary>
int? GpuDeviceId { get; set; }
+
+ bool TryAddOption<T>(string name, T value);
+
+ void SetOption<T>(string name, T value);
+
+ bool TryGetOption<T>(string name, out T value);
+
+ T GetOptionOrDefault<T>(string name);
+
+ bool RemoveOption(string name);
}
/// <summary>
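The interface above only declares the contract; here is a minimal sketch of how a consumer is expected to drive it (the `"cacheSize"` option name is purely illustrative, not something ML.NET defines):

```csharp
using System;
using Microsoft.ML.Runtime;

static void ConfigureEnvironment(IHostEnvironmentInternal env)
{
    // TryAddOption only succeeds when the name is not already present;
    // SetOption adds or overwrites unconditionally.
    if (!env.TryAddOption("cacheSize", 1024))
        env.SetOption("cacheSize", 2048);

    // TryGetOption reports failure instead of throwing when the name is
    // missing or the stored value is not of the requested type.
    if (env.TryGetOption("cacheSize", out int size))
        Console.WriteLine($"cacheSize = {size}");
}
```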
diff --git a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
index 0bb18628df..7ab2620169 100644
--- a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
+++ b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
@@ -334,6 +334,10 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)
public bool FallbackToCpu { get; set; }
+#pragma warning disable MSML_NoInstanceInitializers // Need this to have a default value.
+ protected Dictionary<string, object> Options { get; } = [];
+#pragma warning restore MSML_NoInstanceInitializers
+
protected readonly TEnv Root;
// This is non-null iff this environment was a fork of another. Disposing a fork
// doesn't free temp files. That is handled when the master is disposed.
@@ -567,4 +571,91 @@ public virtual void PrintMessageNormalized(TextWriter writer, string message, bo
else if (!removeLastNewLine)
writer.WriteLine();
}
+
+ /// <summary>
+ /// Tries to add a new runtime option.
+ /// </summary>
+ /// <typeparam name="T">Type of the option value.</typeparam>
+ /// <param name="name">Name of the option to add.</param>
+ /// <param name="value">Value to set.</param>
+ /// <returns><see langword="true"/> if successful, <see langword="false"/> otherwise.</returns>
+ /// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
+ public bool TryAddOption<T>(string name, T value)
+ {
+ if (string.IsNullOrWhiteSpace(name))
+ throw new ArgumentNullException(nameof(name));
+
+ if (Options.ContainsKey(name))
+ return false;
+ SetOption(name, value);
+ return true;
+ }
+
+ /// <summary>
+ /// Adds or sets the option <paramref name="name"/> with the given <paramref name="value"/>. <paramref name="value"/> is cast to <see cref="object"/>.
+ /// </summary>
+ /// <typeparam name="T">Type of the option value.</typeparam>
+ /// <param name="name">Name of the option to set.</param>
+ /// <param name="value">Value to set.</param>
+ /// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
+ public void SetOption<T>(string name, T value)
+ {
+ if (string.IsNullOrWhiteSpace(name))
+ throw new ArgumentNullException(nameof(name));
+ Options[name] = value;
+ }
+
+ /// <summary>
+ /// Gets an option by <paramref name="name"/> and returns <see langword="true"/> if that option has been added and <see langword="false"/> otherwise.
+ /// </summary>
+ /// <typeparam name="T">Type of the option value.</typeparam>
+ /// <param name="name">Name of the option to get.</param>
+ /// <param name="value">Option's value of type <typeparamref name="T"/>.</param>
+ /// <returns><see langword="true"/> if the option was able to be retrieved, else <see langword="false"/>.</returns>
+ /// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
+ public bool TryGetOption<T>(string name, out T value)
+ {
+ if (string.IsNullOrWhiteSpace(name))
+ throw new ArgumentNullException(nameof(name));
+
+ if (!Options.TryGetValue(name, out var val) || val is not T)
+ {
+ value = default;
+ return false;
+ }
+ value = (T)val;
+ return true;
+ }
+
+ /// <summary>
+ /// Gets either the option stored under <paramref name="name"/>, or adds the default value of <typeparamref name="T"/> under that <paramref name="name"/> and returns it.
+ /// </summary>
+ /// <typeparam name="T">Type of the option value.</typeparam>
+ /// <param name="name">Name of the option to get.</param>
+ /// <returns>Option's value of type <typeparamref name="T"/>.</returns>
+ /// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
+ public T GetOptionOrDefault<T>(string name)
+ {
+ if (string.IsNullOrWhiteSpace(name))
+ throw new ArgumentNullException(nameof(name));
+
+ if (Options.TryGetValue(name, out object value))
+ return (T)value;
+
+ SetOption(name, default(T));
+ return default;
+ }
+
+ /// <summary>
+ /// Removes an option.
+ /// </summary>
+ /// <param name="name">Name of the option to remove.</param>
+ /// <returns><see langword="true"/> if successfully removed, else <see langword="false"/>.</returns>
+ /// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
+ public bool RemoveOption(string name)
+ {
+ if (string.IsNullOrWhiteSpace(name))
+ throw new ArgumentNullException(nameof(name));
+ return Options.Remove(name);
+ }
}
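Two behaviors of the implementation above are worth spelling out: values are stored as `object`, so `TryGetOption` returns `false` on a type mismatch rather than throwing, and `GetOptionOrDefault` materializes `default(T)` into the dictionary on first access. A small sketch of those semantics (assuming an environment instance that exposes these members):

```csharp
env.SetOption("threshold", 0.5);                          // boxed double under the hood

// Wrong type: the lookup fails quietly instead of throwing an InvalidCastException.
bool ok = env.TryGetOption("threshold", out int asInt);   // ok == false

// First access to a missing name stores default(T) and returns it.
double d = env.GetOptionOrDefault<double>("missing");     // d == 0.0; "missing" now exists
```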
diff --git a/src/Microsoft.ML.Data/MLContext.cs b/src/Microsoft.ML.Data/MLContext.cs
index f0e1986c5e..c966e5b6be 100644
--- a/src/Microsoft.ML.Data/MLContext.cs
+++ b/src/Microsoft.ML.Data/MLContext.cs
@@ -191,5 +191,11 @@ private static bool InitializeOneDalDispatchingEnabled()
return false;
}
}
+
+ public bool TryAddOption<T>(string name, T value) => _env.TryAddOption(name, value);
+ public void SetOption<T>(string name, T value) => _env.SetOption(name, value);
+ public bool TryGetOption<T>(string name, out T value) => _env.TryGetOption(name, out value);
+ public T GetOptionOrDefault<T>(string name) => _env.GetOptionOrDefault<T>(name);
+ public bool RemoveOption(string name) => _env.RemoveOption(name);
}
}
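With the forwarders above, the option store becomes reachable from the public `MLContext` surface. A minimal usage sketch (the option name is illustrative):

```csharp
var ml = new MLContext(seed: 1);

ml.SetOption("experimentTag", "run-42");               // add or overwrite
if (ml.TryGetOption("experimentTag", out string tag))
    Console.WriteLine(tag);                            // "run-42"

bool removed = ml.RemoveOption("experimentTag");       // true: it was present
```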
diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxSessionOptions.cs b/src/Microsoft.ML.OnnxTransformer/OnnxSessionOptions.cs
new file mode 100644
index 0000000000..4401a4491c
--- /dev/null
+++ b/src/Microsoft.ML.OnnxTransformer/OnnxSessionOptions.cs
@@ -0,0 +1,150 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Microsoft.ML.Data;
+using Microsoft.ML.OnnxRuntime;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Transforms.Onnx;
+
+namespace Microsoft.ML.Transforms.Onnx
+{
+ public static class OnnxSessionOptionsExtensions
+ {
+ private const string OnnxSessionOptionsName = "OnnxSessionOptions";
+
+ public static OnnxSessionOptions GetOnnxSessionOption(this IHostEnvironment env)
+ {
+ if (env is IHostEnvironmentInternal localEnvironment)
+ {
+ return localEnvironment.GetOptionOrDefault<OnnxSessionOptions>(OnnxSessionOptionsName);
+ }
+
+ throw new ArgumentException("Onnx session options are not supported by this environment.");
+ }
+
+ public static void SetOnnxSessionOption(this IHostEnvironment env, OnnxSessionOptions onnxSessionOptions)
+ {
+ if (env is IHostEnvironmentInternal localEnvironment)
+ {
+ localEnvironment.SetOption(OnnxSessionOptionsName, onnxSessionOptions);
+ }
+ else
+ throw new ArgumentException("Onnx session options are not supported by this environment.");
+ }
+ }
+
+ public sealed class OnnxSessionOptions
+ {
+ internal void CopyTo(SessionOptions sessionOptions)
+ {
+ sessionOptions.EnableMemoryPattern = EnableMemoryPattern;
+ sessionOptions.ProfileOutputPathPrefix = ProfileOutputPathPrefix;
+ sessionOptions.EnableProfiling = EnableProfiling;
+ sessionOptions.OptimizedModelFilePath = OptimizedModelFilePath;
+ sessionOptions.EnableCpuMemArena = EnableCpuMemArena;
+ if (!PerSessionThreads)
+ sessionOptions.DisablePerSessionThreads();
+ sessionOptions.LogId = LogId;
+ sessionOptions.LogSeverityLevel = LogSeverityLevel;
+ sessionOptions.LogVerbosityLevel = LogVerbosityLevel;
+ sessionOptions.InterOpNumThreads = InterOpNumThreads;
+ sessionOptions.IntraOpNumThreads = IntraOpNumThreads;
+ sessionOptions.GraphOptimizationLevel = GraphOptimizationLevel;
+ sessionOptions.ExecutionMode = ExecutionMode;
+ }
+
+ /// <summary>
+ /// Enables the use of the memory allocation patterns in the first Run() call for subsequent runs. Default = true.
+ /// </summary>
+#pragma warning disable MSML_NoInstanceInitializers // No initializers on instance fields or properties
+ public bool EnableMemoryPattern { get; set; } = true;
+
+ /// <summary>
+ /// Path prefix to use for output of profiling data.
+ /// </summary>
+ public string ProfileOutputPathPrefix { get; set; } = "onnxruntime_profile_"; // this is the same default in C++ implementation
+
+ /// <summary>
+ /// Enables profiling of InferenceSession.Run() calls. Default is false.
+ /// </summary>
+ public bool EnableProfiling { get; set; } = false;
+
+ /// <summary>
+ /// Sets the file path to save the optimized model after graph level transformations. Default is empty, which implies saving is disabled.
+ /// </summary>
+ public string OptimizedModelFilePath { get; set; } = string.Empty;
+
+ /// <summary>
+ /// Enables Arena allocator for the CPU memory allocations. Default is true.
+ /// </summary>
+ public bool EnableCpuMemArena { get; set; } = true;
+
+ /// <summary>
+ /// Per session threads. Default is true.
+ /// If false, all sessions in the process use a global thread pool.
+ /// </summary>
+ public bool PerSessionThreads { get; set; } = true;
+
+ /// <summary>
+ /// Sets the number of threads used to parallelize the execution within nodes.
+ /// A value of 0 means ORT will pick a default. Only used when <see cref="PerSessionThreads"/> is false.
+ /// </summary>
+ public int GlobalIntraOpNumThreads { get; set; } = 0;
+
+ /// <summary>
+ /// Sets the number of threads used to parallelize the execution of the graph (across nodes).
+ /// If sequential execution is enabled this value is ignored.
+ /// A value of 0 means ORT will pick a default. Only used when <see cref="PerSessionThreads"/> is false.
+ /// </summary>
+ public int GlobalInterOpNumThreads { get; set; } = 0;
+
+ /// <summary>
+ /// Log Id to be used for the session. Default is empty string.
+ /// </summary>
+ public string LogId { get; set; } = string.Empty;
+
+ /// <summary>
+ /// Log Severity Level for the session logs. Default = ORT_LOGGING_LEVEL_WARNING.
+ /// </summary>
+ public OrtLoggingLevel LogSeverityLevel { get; set; } = OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING;
+
+ /// <summary>
+ /// Log Verbosity Level for the session logs. Default = 0. Valid values are >= 0.
+ /// This takes effect only when the LogSeverityLevel is set to ORT_LOGGING_LEVEL_VERBOSE.
+ /// </summary>
+ public int LogVerbosityLevel { get; set; } = 0;
+
+ /// <summary>
+ /// Sets the number of threads used to parallelize the execution within nodes.
+ /// A value of 0 means ORT will pick a default.
+ /// </summary>
+ public int IntraOpNumThreads { get; set; } = 0;
+
+ /// <summary>
+ /// Sets the number of threads used to parallelize the execution of the graph (across nodes).
+ /// If sequential execution is enabled this value is ignored.
+ /// A value of 0 means ORT will pick a default.
+ /// </summary>
+ public int InterOpNumThreads { get; set; } = 0;
+
+ /// <summary>
+ /// Sets the graph optimization level for the session. Default is set to ORT_ENABLE_ALL.
+ /// </summary>
+ public GraphOptimizationLevel GraphOptimizationLevel { get; set; } = GraphOptimizationLevel.ORT_ENABLE_ALL;
+
+ /// <summary>
+ /// Sets the execution mode for the session. Default is set to ORT_SEQUENTIAL.
+ /// See [ONNX_Runtime_Perf_Tuning.md] for more details.
+ /// </summary>
+ public ExecutionMode ExecutionMode { get; set; } = ExecutionMode.ORT_SEQUENTIAL;
+#pragma warning restore MSML_NoInstanceInitializers // No initializers on instance fields or properties
+
+ public delegate SessionOptions CreateOnnxSessionOptions();
+
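+ /// <summary>
+ /// Optional factory for the SessionOptions to use. When set, the returned
+ /// SessionOptions is used verbatim and the individual properties above are ignored.
+ /// </summary>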
+ public CreateOnnxSessionOptions CreateSessionOptions { get; set; }
+ }
+}
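Putting the new type together: construct an `OnnxSessionOptions`, register it through the extension method, and the ONNX sessions created by that environment afterwards pick it up. A sketch under those assumptions (the model path is a placeholder):

```csharp
var ml = new MLContext();

ml.SetOnnxSessionOption(new OnnxSessionOptions
{
    IntraOpNumThreads = 4,
    GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED,
    EnableProfiling = true,
});

// If full control is needed instead, CreateSessionOptions supplies the
// SessionOptions verbatim and the individual properties are ignored:
// options.CreateSessionOptions = () => new SessionOptions { LogId = "custom" };

var pipeline = ml.Transforms.ApplyOnnxModel(
    outputColumnName: "softmaxout_1",
    inputColumnName: "data_0",
    modelFile: "path/to/model.onnx");                  // placeholder path
```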
diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
index c4413dd50e..0357c7c040 100644
--- a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
+++ b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
@@ -262,7 +262,7 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
Host.CheckNonWhiteSpace(options.ModelFile, nameof(options.ModelFile));
Host.CheckIO(File.Exists(options.ModelFile), "Model file {0} does not exist.", options.ModelFile);
// Because we cannot delete the user file, ownModelFile should be false.
- Model = new OnnxModel(options.ModelFile, options.GpuDeviceId, options.FallbackToCpu, ownModelFile: false, shapeDictionary: shapeDictionary, options.RecursionLimit,
+ Model = new OnnxModel(env, options.ModelFile, options.GpuDeviceId, options.FallbackToCpu, ownModelFile: false, shapeDictionary: shapeDictionary, options.RecursionLimit,
options.InterOpNumThreads, options.IntraOpNumThreads);
}
else
diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
index 9dbd92c3c1..8461531b47 100644
--- a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
+++ b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
@@ -7,6 +7,7 @@
using System.IO;
using System.Linq;
using Microsoft.ML.Data;
+using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
@@ -158,6 +159,7 @@ public OnnxVariableInfo(string name, OnnxShape shape, Type typeInOnnxRuntime, Da
/// <summary>
/// Constructs OnnxModel object from file.
/// </summary>
+ /// <param name="env">The host environment.</param>
/// <param name="modelFile">Model file path.</param>
/// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
/// <param name="fallbackToCpu">If true, resumes CPU execution quietly upon GPU error.</param>
@@ -167,29 +169,61 @@ public OnnxVariableInfo(string name, OnnxShape shape, Type typeInOnnxRuntime, Da
/// <param name="recursionLimit">Optional, specifies the Protobuf CodedInputStream recursion limit. Default value is 100.</param>
/// <param name="interOpNumThreads">Controls the number of threads used to parallelize the execution of the graph (across nodes).</param>
/// <param name="intraOpNumThreads">Controls the number of threads to use to run the model.</param>
- public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
+ public OnnxModel(IHostEnvironment env, string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
bool ownModelFile = false, IDictionary shapeDictionary = null, int recursionLimit = 100,
int? interOpNumThreads = null, int? intraOpNumThreads = null)
{
// If we don't own the model file, _disposed should be false to prevent deleting user's file.
_disposed = false;
- if (gpuDeviceId != null)
+ OnnxSessionOptions onnxSessionOptions = null;
+
+ if (env is IHostEnvironmentInternal localEnvironment)
+ onnxSessionOptions = localEnvironment.GetOnnxSessionOption();
+
+ onnxSessionOptions ??= new OnnxSessionOptions();
+
+ if (!onnxSessionOptions.PerSessionThreads && !OrtEnv.IsCreated)
+ {
+ EnvironmentCreationOptions environmentCreationOptions = new EnvironmentCreationOptions()
+ {
+ threadOptions = new OrtThreadingOptions()
+ {
+ GlobalInterOpNumThreads = onnxSessionOptions.GlobalInterOpNumThreads,
+ GlobalIntraOpNumThreads = onnxSessionOptions.GlobalIntraOpNumThreads,
+ }
+ };
+ // No need to capture the return value; this also sets the singleton.
+ OrtEnv.CreateInstanceWithOptions(ref environmentCreationOptions);
+ }
+
+ if (onnxSessionOptions.CreateSessionOptions != null)
+ {
+ _session = new InferenceSession(modelFile, onnxSessionOptions.CreateSessionOptions());
+ }
+ else if (gpuDeviceId != null)
{
try
{
- _session = new InferenceSession(modelFile,
- SessionOptions.MakeSessionOptionWithCudaProvider(gpuDeviceId.Value));
+ SessionOptions sessionOptions = SessionOptions.MakeSessionOptionWithCudaProvider(gpuDeviceId.Value);
+ onnxSessionOptions.CopyTo(sessionOptions);
+
+ sessionOptions.InterOpNumThreads = interOpNumThreads ?? onnxSessionOptions.InterOpNumThreads;
+ sessionOptions.IntraOpNumThreads = intraOpNumThreads ?? onnxSessionOptions.IntraOpNumThreads;
+
+ _session = new InferenceSession(modelFile, sessionOptions);
}
catch (OnnxRuntimeException)
{
if (fallbackToCpu)
{
- var sessionOptions = new SessionOptions()
- {
- InterOpNumThreads = interOpNumThreads.GetValueOrDefault(),
- IntraOpNumThreads = intraOpNumThreads.GetValueOrDefault()
- };
+ SessionOptions sessionOptions = new SessionOptions();
+ onnxSessionOptions.CopyTo(sessionOptions);
+
+ sessionOptions.InterOpNumThreads = interOpNumThreads ?? onnxSessionOptions.InterOpNumThreads;
+ sessionOptions.IntraOpNumThreads = intraOpNumThreads ?? onnxSessionOptions.IntraOpNumThreads;
+
_session = new InferenceSession(modelFile, sessionOptions);
}
else
@@ -199,11 +233,12 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
}
else
{
- var sessionOptions = new SessionOptions()
- {
- InterOpNumThreads = interOpNumThreads.GetValueOrDefault(),
- IntraOpNumThreads = intraOpNumThreads.GetValueOrDefault()
- };
+ SessionOptions sessionOptions = new SessionOptions();
+ onnxSessionOptions.CopyTo(sessionOptions);
+
+ sessionOptions.InterOpNumThreads = interOpNumThreads ?? onnxSessionOptions.InterOpNumThreads;
+ sessionOptions.IntraOpNumThreads = intraOpNumThreads ?? onnxSessionOptions.IntraOpNumThreads;
+
_session = new InferenceSession(modelFile, sessionOptions);
}
@@ -372,7 +407,7 @@ public static OnnxModel CreateFromBytes(byte[] modelBytes, IHostEnvironment env,
{
var tempModelFile = Path.Combine(((IHostEnvironmentInternal)env).TempFilePath, Path.GetRandomFileName());
File.WriteAllBytes(tempModelFile, modelBytes);
- return new OnnxModel(tempModelFile, gpuDeviceId, fallbackToCpu,
+ return new OnnxModel(env, tempModelFile, gpuDeviceId, fallbackToCpu,
ownModelFile: true, shapeDictionary: shapeDictionary, recursionLimit);
}
@@ -401,7 +436,7 @@ public static OnnxModel CreateFromStream(Stream modelBytes, IHostEnvironment env
modelBytes.Seek(0, SeekOrigin.Begin);
modelBytes.CopyTo(fileStream);
}
- return new OnnxModel(tempModelFile, gpuDeviceId, fallbackToCpu,
+ return new OnnxModel(env, tempModelFile, gpuDeviceId, fallbackToCpu,
ownModelFile: false, shapeDictionary: shapeDictionary, recursionLimit);
}
diff --git a/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs b/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs
index 8bae1e9c26..01713d8ed1 100644
--- a/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs
+++ b/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs
@@ -8,6 +8,7 @@
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Model;
+using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.RunTests;
using Microsoft.ML.Runtime;
using Microsoft.ML.TestFramework.Attributes;
@@ -383,13 +384,47 @@ public void OnnxModelScenario()
(onnxTransformer as IDisposable)?.Dispose();
}
+ [OnnxFact]
+ public void OnnxModelCustomOptions()
+ {
+ var modelFile = "squeezenet/00000001/model.onnx";
+ var env = new ConsoleEnvironment(seed: 1);
+ var samplevector = GetSampleArrayData();
+
+ var dataView = ML.Data.LoadFromEnumerable(
+ new TestData[] {
+ new TestData()
+ {
+ data_0 = samplevector
+ }
+ });
+
+ // Setting per session threads to true should work.
+ OnnxSessionOptions onnxSessionOptions = new OnnxSessionOptions()
+ {
+ PerSessionThreads = true
+ };
+ ML.SetOnnxSessionOption(onnxSessionOptions);
+ var pipeline = ML.Transforms.ApplyOnnxModel("softmaxout_1", "data_0", modelFile, gpuDeviceId: _gpuDeviceId, fallbackToCpu: _fallbackToCpu);
+ var onnxTransformer = pipeline.Fit(dataView);
+
+ // Setting per session threads to false after the OrtEnv has already been created with per-session threads should throw.
+ onnxSessionOptions.PerSessionThreads = false;
+ onnxSessionOptions.GlobalIntraOpNumThreads = 1;
+ onnxSessionOptions.GlobalInterOpNumThreads = 1;
+
+ ML.SetOnnxSessionOption(onnxSessionOptions);
+ Assert.Throws<OnnxRuntimeException>(() => ML.Transforms.ApplyOnnxModel("softmaxout_1", "data_0", modelFile, gpuDeviceId: _gpuDeviceId, fallbackToCpu: _fallbackToCpu));
+
+ (onnxTransformer as IDisposable)?.Dispose();
+ }
+
[OnnxFact]
public void OnnxModelMultiInput()
{
var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "twoinput", "twoinput.onnx");
var env = new ConsoleEnvironment(seed: 1);
var samplevector = GetSampleArrayData();
-
var dataView = ML.Data.LoadFromEnumerable(
new TestDataMulti[] {
new TestDataMulti()
@@ -774,7 +809,7 @@ public void TestOnnxModelNotDisposal()
var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "zipmap", "TestZipMapInt64.onnx");
// Create ONNX model from the model file.
- var onnxModel = new OnnxModel(modelFile);
+ var onnxModel = new OnnxModel(ML, modelFile);
// Check if a temporary file is created for storing the byte[].
Assert.True(File.Exists(onnxModel.ModelStream.Name));