diff --git a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
index 06fe0b32b4..f9849c9a6c 100644
--- a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
+++ b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information.
 
 using System;
+using System.Collections.Generic;
 using Microsoft.ML.Data;
 
 namespace Microsoft.ML.Runtime;
@@ -104,6 +105,16 @@ internal interface IHostEnvironmentInternal : IHostEnvironment
     /// GPU device ID to run execution on, <see langword="null"/> to run on CPU.
     /// </summary>
     int? GpuDeviceId { get; set; }
+
+    bool TryAddOption<T>(string name, T value);
+
+    void SetOption<T>(string name, T value);
+
+    bool TryGetOption<T>(string name, out T value);
+
+    T GetOptionOrDefault<T>(string name);
+
+    bool RemoveOption(string name);
 }
 
 /// <summary>
diff --git a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
index 0bb18628df..7ab2620169 100644
--- a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
+++ b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
@@ -334,6 +334,10 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)
 
         public bool FallbackToCpu { get; set; }
 
+#pragma warning disable MSML_NoInstanceInitializers // Need this to have a default value.
+        protected Dictionary<string, object> Options { get; } = [];
+#pragma warning restore MSML_NoInstanceInitializers
+
         protected readonly TEnv Root;
         // This is non-null iff this environment was a fork of another. Disposing a fork
        // doesn't free temp files. That is handled when the master is disposed.
@@ -567,4 +571,91 @@ public virtual void PrintMessageNormalized(TextWriter writer, string message, bool removeLastNewLine)
             else if (!removeLastNewLine)
                 writer.WriteLine();
         }
+
+        /// <summary>
+        /// Tries to add a new runtime option.
+        /// </summary>
+        /// <typeparam name="T">Type of the option's value.</typeparam>
+        /// <param name="name">Name of the option to add.</param>
+        /// <param name="value">Value to set.</param>
+        /// <returns><see langword="true"/> if successful, <see langword="false"/> otherwise.</returns>
+        /// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
+        public bool TryAddOption<T>(string name, T value)
+        {
+            if (string.IsNullOrWhiteSpace(name))
+                throw new ArgumentNullException(nameof(name));
+
+            if (Options.ContainsKey(name))
+                return false;
+            SetOption(name, value);
+            return true;
+        }
+
+        /// <summary>
+        /// Adds or sets the option <paramref name="name"/> to the given <paramref name="value"/>.
+        /// The value is stored as an <see cref="object"/>.
+        /// </summary>
+        /// <typeparam name="T">Type of the option's value.</typeparam>
+        /// <param name="name">Name of the option to set.</param>
+        /// <param name="value">Value to set.</param>
+        /// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
+        public void SetOption<T>(string name, T value)
+        {
+            if (string.IsNullOrWhiteSpace(name))
+                throw new ArgumentNullException(nameof(name));
+            Options[name] = value;
+        }
+
+        /// <summary>
+        /// Gets the option with the given <paramref name="name"/>, returning <see langword="true"/>
+        /// if it has been added and is of type <typeparamref name="T"/>, <see langword="false"/> otherwise.
+        /// </summary>
+        /// <typeparam name="T">Type of the option's value.</typeparam>
+        /// <param name="name">Name of the option to get.</param>
+        /// <param name="value">The option's value, of type <typeparamref name="T"/>.</param>
+        /// <returns><see langword="true"/> if the option could be retrieved, <see langword="false"/> otherwise.</returns>
+        /// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
+        public bool TryGetOption<T>(string name, out T value)
+        {
+            if (string.IsNullOrWhiteSpace(name))
+                throw new ArgumentNullException(nameof(name));
+
+            if (!Options.TryGetValue(name, out var val) || val is not T)
+            {
+                value = default;
+                return false;
+            }
+            value = (T)val;
+            return true;
+        }
+
+        /// <summary>
+        /// Gets the option stored under <paramref name="name"/>; if none exists, adds the default
+        /// value of <typeparamref name="T"/> under that name and returns it.
+        /// </summary>
+        /// <typeparam name="T">Type of the option's value.</typeparam>
+        /// <param name="name">Name of the option to get.</param>
+        /// <returns>The option's value, of type <typeparamref name="T"/>.</returns>
+        /// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
+        public T GetOptionOrDefault<T>(string name)
+        {
+            if (string.IsNullOrWhiteSpace(name))
+                throw new ArgumentNullException(nameof(name));
+
+            if (!Options.TryGetValue(name, out object value))
+                SetOption<T>(name, default);
+            else
+                return (T)value;
+            return (T)Options[name];
+        }
+
+        /// <summary>
+        /// Removes an option.
+        /// </summary>
+        /// <param name="name">Name of the option to remove.</param>
+        /// <returns><see langword="true"/> if successfully removed, <see langword="false"/> otherwise.</returns>
+        /// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
+        public bool RemoveOption(string name)
+        {
+            if (string.IsNullOrWhiteSpace(name))
+                throw new ArgumentNullException(nameof(name));
+            return Options.Remove(name);
+        }
     }
diff --git a/src/Microsoft.ML.Data/MLContext.cs b/src/Microsoft.ML.Data/MLContext.cs
index f0e1986c5e..c966e5b6be 100644
--- a/src/Microsoft.ML.Data/MLContext.cs
+++ b/src/Microsoft.ML.Data/MLContext.cs
@@ -191,5 +191,11 @@ private static bool InitializeOneDalDispatchingEnabled()
                 return false;
             }
         }
+
+        public bool TryAddOption<T>(string name, T value) => _env.TryAddOption(name, value);
+        public void SetOption<T>(string name, T value) => _env.SetOption(name, value);
+        public bool TryGetOption<T>(string name, out T value) => _env.TryGetOption(name, out value);
+        public T GetOptionOrDefault<T>(string name) => _env.GetOptionOrDefault<T>(name);
+        public bool RemoveOption(string name) => _env.RemoveOption(name);
     }
 }
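Aside (a minimal sketch, not part of the diff): this is how the new option bag behaves when driven through MLContext, per the semantics of the implementation above. The option name "CustomCacheSize" is an arbitrary illustration, not an option ML.NET defines.

    using System;
    using Microsoft.ML;

    var ml = new MLContext();

    // TryAddOption only succeeds when the name is not present yet.
    ml.TryAddOption("CustomCacheSize", 42);    // returns true
    ml.TryAddOption("CustomCacheSize", 7);     // returns false; already added

    // SetOption adds or overwrites unconditionally.
    ml.SetOption("CustomCacheSize", 7);

    // TryGetOption fails when the name is missing or the stored value is not a T.
    if (ml.TryGetOption("CustomCacheSize", out int size))
        Console.WriteLine(size);               // prints 7

    // GetOptionOrDefault stores default(T) under an unknown name and returns it.
    int unset = ml.GetOptionOrDefault<int>("NotSetYet");   // 0; "NotSetYet" now exists

    ml.RemoveOption("CustomCacheSize");        // returns true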
diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxSessionOptions.cs b/src/Microsoft.ML.OnnxTransformer/OnnxSessionOptions.cs
new file mode 100644
index 0000000000..4401a4491c
--- /dev/null
+++ b/src/Microsoft.ML.OnnxTransformer/OnnxSessionOptions.cs
@@ -0,0 +1,150 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Microsoft.ML.Data;
+using Microsoft.ML.OnnxRuntime;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Transforms.Onnx;
+
+namespace Microsoft.ML.Transforms.Onnx
+{
+    public static class OnnxSessionOptionsExtensions
+    {
+        private const string OnnxSessionOptionsName = "OnnxSessionOptions";
+
+        public static OnnxSessionOptions GetOnnxSessionOption(this IHostEnvironment env)
+        {
+            if (env is IHostEnvironmentInternal localEnvironment)
+            {
+                return localEnvironment.GetOptionOrDefault<OnnxSessionOptions>(OnnxSessionOptionsName);
+            }
+
+            throw new ArgumentException("The host environment does not support ONNX session options.", nameof(env));
+        }
+
+        public static void SetOnnxSessionOption(this IHostEnvironment env, OnnxSessionOptions onnxSessionOptions)
+        {
+            if (env is IHostEnvironmentInternal localEnvironment)
+            {
+                localEnvironment.SetOption(OnnxSessionOptionsName, onnxSessionOptions);
+            }
+            else
+                throw new ArgumentException("The host environment does not support ONNX session options.", nameof(env));
+        }
+    }
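For orientation (a sketch, not part of the diff): the extension methods above store the options object in the environment's option bag under the "OnnxSessionOptions" key, so user code configures ONNX sessions like this. The property values are arbitrary examples.

    using Microsoft.ML;
    using Microsoft.ML.Transforms.Onnx;

    var ml = new MLContext();
    var onnxOptions = new OnnxSessionOptions
    {
        EnableProfiling = true,   // arbitrary example values
        IntraOpNumThreads = 4
    };
    ml.SetOnnxSessionOption(onnxOptions);

    // OnnxModel reads the same object back when a session is created:
    OnnxSessionOptions stored = ml.GetOnnxSessionOption();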
+
+    public sealed class OnnxSessionOptions
+    {
+        internal void CopyTo(SessionOptions sessionOptions)
+        {
+            sessionOptions.EnableMemoryPattern = EnableMemoryPattern;
+            sessionOptions.ProfileOutputPathPrefix = ProfileOutputPathPrefix;
+            sessionOptions.EnableProfiling = EnableProfiling;
+            sessionOptions.OptimizedModelFilePath = OptimizedModelFilePath;
+            sessionOptions.EnableCpuMemArena = EnableCpuMemArena;
+            if (!PerSessionThreads)
+                sessionOptions.DisablePerSessionThreads();
+            sessionOptions.LogId = LogId;
+            sessionOptions.LogSeverityLevel = LogSeverityLevel;
+            sessionOptions.LogVerbosityLevel = LogVerbosityLevel;
+            sessionOptions.InterOpNumThreads = InterOpNumThreads;
+            sessionOptions.IntraOpNumThreads = IntraOpNumThreads;
+            sessionOptions.GraphOptimizationLevel = GraphOptimizationLevel;
+            sessionOptions.ExecutionMode = ExecutionMode;
+        }
+
+        /// <summary>
+        /// Enables the use of the memory allocation patterns in the first Run() call for subsequent runs. Default is true.
+        /// </summary>
+#pragma warning disable MSML_NoInstanceInitializers // No initializers on instance fields or properties
+        public bool EnableMemoryPattern { get; set; } = true;
+
+        /// <summary>
+        /// Path prefix to use for output of profiling data.
+        /// </summary>
+        public string ProfileOutputPathPrefix { get; set; } = "onnxruntime_profile_"; // this is the same default as in the C++ implementation
+
+        /// <summary>
+        /// Enables profiling of InferenceSession.Run() calls. Default is false.
+        /// </summary>
+        public bool EnableProfiling { get; set; } = false;
+
+        /// <summary>
+        /// Sets the file path where the optimized model is saved after graph-level transformations.
+        /// Default is empty, which implies saving is disabled.
+        /// </summary>
+        public string OptimizedModelFilePath { get; set; } = string.Empty;
+
+        /// <summary>
+        /// Enables the arena allocator for CPU memory allocations. Default is true.
+        /// </summary>
+        public bool EnableCpuMemArena { get; set; } = true;
+
+        /// <summary>
+        /// Per-session threads. Default is true.
+        /// If false, all sessions in the process use a global thread pool.
+        /// </summary>
+        public bool PerSessionThreads { get; set; } = true;
+
+        /// <summary>
+        /// Sets the number of threads used to parallelize the execution within nodes.
+        /// A value of 0 means ORT will pick a default. Only used when <see cref="PerSessionThreads"/> is false.
+        /// </summary>
+        public int GlobalIntraOpNumThreads { get; set; } = 0;
+
+        /// <summary>
+        /// Sets the number of threads used to parallelize the execution of the graph (across nodes).
+        /// If sequential execution is enabled this value is ignored.
+        /// A value of 0 means ORT will pick a default. Only used when <see cref="PerSessionThreads"/> is false.
+        /// </summary>
+        public int GlobalInterOpNumThreads { get; set; } = 0;
+
+        /// <summary>
+        /// Log Id to be used for the session. Default is an empty string.
+        /// </summary>
+        public string LogId { get; set; } = string.Empty;
+
+        /// <summary>
+        /// Log severity level for the session logs. Default is ORT_LOGGING_LEVEL_WARNING.
+        /// </summary>
+        public OrtLoggingLevel LogSeverityLevel { get; set; } = OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING;
+
+        /// <summary>
+        /// Log verbosity level for the session logs. Default is 0. Valid values are >= 0.
+        /// This only takes effect when LogSeverityLevel is set to ORT_LOGGING_LEVEL_VERBOSE.
+        /// </summary>
+        public int LogVerbosityLevel { get; set; } = 0;
+
+        /// <summary>
+        /// Sets the number of threads used to parallelize the execution within nodes.
+        /// A value of 0 means ORT will pick a default.
+        /// </summary>
+        public int IntraOpNumThreads { get; set; } = 0;
+
+        /// <summary>
+        /// Sets the number of threads used to parallelize the execution of the graph (across nodes).
+        /// If sequential execution is enabled this value is ignored.
+        /// A value of 0 means ORT will pick a default.
+        /// </summary>
+        public int InterOpNumThreads { get; set; } = 0;
+
+        /// <summary>
+        /// Sets the graph optimization level for the session. Default is ORT_ENABLE_ALL.
+        /// </summary>
+        public GraphOptimizationLevel GraphOptimizationLevel { get; set; } = GraphOptimizationLevel.ORT_ENABLE_ALL;
+
+        /// <summary>
+        /// Sets the execution mode for the session. Default is ORT_SEQUENTIAL.
+        /// See [ONNX_Runtime_Perf_Tuning.md] for more details.
+        /// </summary>
+        public ExecutionMode ExecutionMode { get; set; } = ExecutionMode.ORT_SEQUENTIAL;
+#pragma warning restore MSML_NoInstanceInitializers // No initializers on instance fields or properties
+
+        /// <summary>
+        /// Delegate used to create a fully custom SessionOptions.
+        /// </summary>
+        public delegate SessionOptions CreateOnnxSessionOptions();
+
+        /// <summary>
+        /// When set, this factory is used verbatim and the properties above are ignored.
+        /// </summary>
+        public CreateOnnxSessionOptions CreateSessionOptions { get; set; }
+    }
+}
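If the canned properties are not enough, the CreateSessionOptions hook hands OnnxModel a fully custom SessionOptions; per the constructor change below, when the delegate is set it is used verbatim and the property-based configuration (including the CUDA branch) is skipped. A sketch, assuming the ONNX Runtime CUDA execution provider package is referenced:

    var onnxOptions = new OnnxSessionOptions
    {
        CreateSessionOptions = () =>
        {
            var sessionOptions = new SessionOptions();
            sessionOptions.AppendExecutionProvider_CUDA(0);   // assumes the CUDA EP is available
            sessionOptions.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED;
            return sessionOptions;
        }
    };
    ml.SetOnnxSessionOption(onnxOptions);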
diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
index c4413dd50e..0357c7c040 100644
--- a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
+++ b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
@@ -262,7 +262,7 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
                 Host.CheckNonWhiteSpace(options.ModelFile, nameof(options.ModelFile));
                 Host.CheckIO(File.Exists(options.ModelFile), "Model file {0} does not exist.", options.ModelFile);
                 // Because we cannot delete the user file, ownModelFile should be false.
-                Model = new OnnxModel(options.ModelFile, options.GpuDeviceId, options.FallbackToCpu, ownModelFile: false, shapeDictionary: shapeDictionary, options.RecursionLimit,
+                Model = new OnnxModel(env, options.ModelFile, options.GpuDeviceId, options.FallbackToCpu, ownModelFile: false, shapeDictionary: shapeDictionary, options.RecursionLimit,
                     options.InterOpNumThreads, options.IntraOpNumThreads);
             }
             else
diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
index 9dbd92c3c1..8461531b47 100644
--- a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
+++ b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
@@ -7,6 +7,7 @@
 using System.IO;
 using System.Linq;
 using Microsoft.ML.Data;
+using Microsoft.ML.Internal.Utilities;
 using Microsoft.ML.Model.OnnxConverter;
 using Microsoft.ML.OnnxRuntime;
 using Microsoft.ML.OnnxRuntime.Tensors;
@@ -158,6 +159,7 @@ public OnnxVariableInfo(string name, OnnxShape shape, Type typeInOnnxRuntime, DataViewType mlNetType, Func<NamedOnnxValue, object> caster)
         /// <summary>
         /// Constructs OnnxModel object from file.
         /// </summary>
+        /// <param name="env">The host environment.</param>
         /// <param name="modelFile">Model file path.</param>
         /// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
         /// <param name="fallbackToCpu">If true, resumes CPU execution quietly upon GPU error.</param>
@@ -167,29 +169,61 @@ public OnnxVariableInfo(string name, OnnxShape shape, Type typeInOnnxRuntime, DataViewType mlNetType, Func<NamedOnnxValue, object> caster)
         /// <param name="recursionLimit">Optional, specifies the Protobuf CodedInputStream recursion limit. Default value is 100.</param>
         /// <param name="interOpNumThreads">Controls the number of threads used to parallelize the execution of the graph (across nodes).</param>
         /// <param name="intraOpNumThreads">Controls the number of threads to use to run the model.</param>
-        public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
+        public OnnxModel(IHostEnvironment env, string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
             bool ownModelFile = false, IDictionary<string, int[]> shapeDictionary = null, int recursionLimit = 100,
             int? interOpNumThreads = null, int? intraOpNumThreads = null)
         {
             // If we don't own the model file, _disposed should be false to prevent deleting user's file.
             _disposed = false;
 
-            if (gpuDeviceId != null)
+            OnnxSessionOptions onnxSessionOptions = default;
+
+            if (env is IHostEnvironmentInternal localEnvironment)
+                onnxSessionOptions = localEnvironment.GetOnnxSessionOption();
+
+            if (onnxSessionOptions == default)
+                onnxSessionOptions = new OnnxSessionOptions();
+
+            if (!onnxSessionOptions.PerSessionThreads && !OrtEnv.IsCreated)
+            {
+                EnvironmentCreationOptions environmentCreationOptions = new EnvironmentCreationOptions()
+                {
+                    threadOptions = new OrtThreadingOptions()
+                    {
+                        GlobalInterOpNumThreads = onnxSessionOptions.GlobalInterOpNumThreads,
+                        GlobalIntraOpNumThreads = onnxSessionOptions.GlobalIntraOpNumThreads,
+                    }
+                };
+                // No need to capture the return value; this also sets the singleton.
+                OrtEnv.CreateInstanceWithOptions(ref environmentCreationOptions);
+            }
+
+            if (onnxSessionOptions.CreateSessionOptions != null)
+            {
+                _session = new InferenceSession(modelFile, onnxSessionOptions.CreateSessionOptions());
+            }
+            else if (gpuDeviceId != null)
             {
                 try
                 {
-                    _session = new InferenceSession(modelFile,
-                        SessionOptions.MakeSessionOptionWithCudaProvider(gpuDeviceId.Value));
+                    SessionOptions sessionOptions = SessionOptions.MakeSessionOptionWithCudaProvider(gpuDeviceId.Value);
+                    onnxSessionOptions.CopyTo(sessionOptions);
+
+                    sessionOptions.InterOpNumThreads = interOpNumThreads ?? onnxSessionOptions.InterOpNumThreads;
+                    sessionOptions.IntraOpNumThreads = intraOpNumThreads ?? onnxSessionOptions.IntraOpNumThreads;
+
+                    _session = new InferenceSession(modelFile, sessionOptions);
                 }
                 catch (OnnxRuntimeException)
                 {
                     if (fallbackToCpu)
                     {
-                        var sessionOptions = new SessionOptions()
-                        {
-                            InterOpNumThreads = interOpNumThreads.GetValueOrDefault(),
-                            IntraOpNumThreads = intraOpNumThreads.GetValueOrDefault()
-                        };
+                        SessionOptions sessionOptions = new SessionOptions();
+                        onnxSessionOptions.CopyTo(sessionOptions);
+
+                        sessionOptions.InterOpNumThreads = interOpNumThreads ?? onnxSessionOptions.InterOpNumThreads;
+                        sessionOptions.IntraOpNumThreads = intraOpNumThreads ?? onnxSessionOptions.IntraOpNumThreads;
+
                         _session = new InferenceSession(modelFile, sessionOptions);
                     }
                     else
@@ -199,11 +233,12 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
             }
             else
             {
-                var sessionOptions = new SessionOptions()
-                {
-                    InterOpNumThreads = interOpNumThreads.GetValueOrDefault(),
-                    IntraOpNumThreads = intraOpNumThreads.GetValueOrDefault()
-                };
+                SessionOptions sessionOptions = new SessionOptions();
+                onnxSessionOptions.CopyTo(sessionOptions);
+
+                sessionOptions.InterOpNumThreads = interOpNumThreads ?? onnxSessionOptions.InterOpNumThreads;
+                sessionOptions.IntraOpNumThreads = intraOpNumThreads ?? onnxSessionOptions.IntraOpNumThreads;
+
                 _session = new InferenceSession(modelFile, sessionOptions);
             }
 
@@ -372,7 +407,7 @@ public static OnnxModel CreateFromBytes(byte[] modelBytes, IHostEnvironment env,
         {
             var tempModelFile = Path.Combine(((IHostEnvironmentInternal)env).TempFilePath, Path.GetRandomFileName());
             File.WriteAllBytes(tempModelFile, modelBytes);
-            return new OnnxModel(tempModelFile, gpuDeviceId, fallbackToCpu,
+            return new OnnxModel(env, tempModelFile, gpuDeviceId, fallbackToCpu,
                 ownModelFile: true, shapeDictionary: shapeDictionary, recursionLimit);
         }
@@ -401,7 +436,7 @@ public static OnnxModel CreateFromStream(Stream modelBytes, IHostEnvironment env
                 modelBytes.Seek(0, SeekOrigin.Begin);
                 modelBytes.CopyTo(fileStream);
             }
-            return new OnnxModel(tempModelFile, gpuDeviceId, fallbackToCpu,
+            return new OnnxModel(env, tempModelFile, gpuDeviceId, fallbackToCpu,
                 ownModelFile: false, shapeDictionary: shapeDictionary, recursionLimit);
         }
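One consequence worth spelling out (my reading of the constructor above, not wording from the PR): OrtEnv is a process-wide singleton, and the global thread pool is only configured when the first session triggers OrtEnv creation. Global threading must therefore be set up before any ONNX session exists in the process and cannot be switched on afterwards, which is exactly what the new test below verifies. Configuration would look like:

    // Must run before the first InferenceSession is created anywhere in the process.
    var onnxOptions = new OnnxSessionOptions
    {
        PerSessionThreads = false,    // all sessions share one global ORT thread pool
        GlobalIntraOpNumThreads = 2,  // arbitrary example values
        GlobalInterOpNumThreads = 1
    };
    ml.SetOnnxSessionOption(onnxOptions);
    // The next ApplyOnnxModel / OnnxModel call creates the OrtEnv with these thread options.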
diff --git a/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs b/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs
index 8bae1e9c26..01713d8ed1 100644
--- a/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs
+++ b/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs
@@ -8,6 +8,7 @@
 using System.Linq;
 using Microsoft.ML.Data;
 using Microsoft.ML.Model;
+using Microsoft.ML.OnnxRuntime;
 using Microsoft.ML.RunTests;
 using Microsoft.ML.Runtime;
 using Microsoft.ML.TestFramework.Attributes;
@@ -383,13 +384,47 @@ public void OnnxModelScenario()
             (onnxTransformer as IDisposable)?.Dispose();
         }
 
+        [OnnxFact]
+        public void OnnxModelCustomOptions()
+        {
+            var modelFile = "squeezenet/00000001/model.onnx";
+            var env = new ConsoleEnvironment(seed: 1);
+            var samplevector = GetSampleArrayData();
+
+            var dataView = ML.Data.LoadFromEnumerable(
+                new TestData[] {
+                    new TestData()
+                    {
+                        data_0 = samplevector
+                    }
+                });
+
+            // Setting per-session threads to true should work.
+            OnnxSessionOptions onnxSessionOptions = new OnnxSessionOptions()
+            {
+                PerSessionThreads = true
+            };
+            ML.SetOnnxSessionOption(onnxSessionOptions);
+            var pipeline = ML.Transforms.ApplyOnnxModel("softmaxout_1", "data_0", modelFile, gpuDeviceId: _gpuDeviceId, fallbackToCpu: _fallbackToCpu);
+            var onnxTransformer = pipeline.Fit(dataView);
+
+            // Switching to per-session threads disabled after the OrtEnv has already been created should throw.
+            onnxSessionOptions.PerSessionThreads = false;
+            onnxSessionOptions.GlobalIntraOpNumThreads = 1;
+            onnxSessionOptions.GlobalInterOpNumThreads = 1;
+
+            ML.SetOnnxSessionOption(onnxSessionOptions);
+            Assert.Throws<OnnxRuntimeException>(() => ML.Transforms.ApplyOnnxModel("softmaxout_1", "data_0", modelFile, gpuDeviceId: _gpuDeviceId, fallbackToCpu: _fallbackToCpu));
+
+            (onnxTransformer as IDisposable)?.Dispose();
+        }
+
         [OnnxFact]
         public void OnnxModelMultiInput()
         {
             var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "twoinput", "twoinput.onnx");
             var env = new ConsoleEnvironment(seed: 1);
             var samplevector = GetSampleArrayData();
-
             var dataView = ML.Data.LoadFromEnumerable(
                 new TestDataMulti[] {
                     new TestDataMulti()
@@ -774,7 +809,7 @@ public void TestOnnxModelNotDisposal()
             var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "zipmap", "TestZipMapInt64.onnx");
 
             // Create ONNX model from the model file.
-            var onnxModel = new OnnxModel(modelFile);
+            var onnxModel = new OnnxModel(ML, modelFile);
 
             // Check that a temporary file is created for storing the byte[].
             Assert.True(File.Exists(onnxModel.ModelStream.Name));