Skip to content

Commit

Permalink
Adds in a way to add settings for the MLContext. (#7273)
Browse files Browse the repository at this point in the history
* api, no tests

* updates from pr

* fixed rebase errors and pr comments

* updates based on ONNX team
  • Loading branch information
michaelgsharp authored Oct 22, 2024
1 parent 869dc9f commit a7a6d88
Show file tree
Hide file tree
Showing 7 changed files with 347 additions and 19 deletions.
11 changes: 11 additions & 0 deletions src/Microsoft.ML.Core/Data/IHostEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using Microsoft.ML.Data;

namespace Microsoft.ML.Runtime;
Expand Down Expand Up @@ -104,6 +105,16 @@ internal interface IHostEnvironmentInternal : IHostEnvironment
/// GPU device ID to run execution on, <see langword="null" /> to run on CPU.
/// </summary>
int? GpuDeviceId { get; set; }

bool TryAddOption<T>(string name, T value);

void SetOption<T>(string name, T value);

bool TryGetOption<T>(string name, out T value);

T GetOptionOrDefault<T>(string name);

bool RemoveOption(string name);
}

/// <summary>
Expand Down
91 changes: 91 additions & 0 deletions src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,10 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)

public bool FallbackToCpu { get; set; }

#pragma warning disable MSML_NoInstanceInitializers // Need this to have a default value.
protected Dictionary<string, object> Options { get; } = [];
#pragma warning restore MSML_NoInstanceInitializers

protected readonly TEnv Root;
// This is non-null iff this environment was a fork of another. Disposing a fork
// doesn't free temp files. That is handled when the master is disposed.
Expand Down Expand Up @@ -567,4 +571,91 @@ public virtual void PrintMessageNormalized(TextWriter writer, string message, bo
else if (!removeLastNewLine)
writer.WriteLine();
}

/// <summary>
/// Trys to add a new runtime option.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="name">Name of the option to add.</param>
/// <param name="value">Value to set.</param>
/// <returns><see langword="true"/> if successful. <see langword="false"/> otherwise.</returns>
/// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
public bool TryAddOption<T>(string name, T value)
{
if (string.IsNullOrWhiteSpace(name))
throw new ArgumentNullException(nameof(name));

if (Options.ContainsKey(name))
return false;
SetOption(name, value);
return true;
}

/// <summary>
/// Adds or Sets the <paramref name="value"/> with the given <paramref name="name"/>. Is cast to <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="name">Name of the option to set.</param>
/// <param name="value">Value to set.</param>
/// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
public void SetOption<T>(string name, T value)
{
if (string.IsNullOrWhiteSpace(name))
throw new ArgumentNullException(nameof(name));
Options[name] = value;
}

/// <summary>
/// Gets an option by <paramref name="name"/> and returns <see langword="true"/> if that has been added and <see langword="false"/> otherwise.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="name">Name of the option to get.</param>
/// <param name="value">Options value of type <typeparamref name="T"/>.</param>
/// <returns><see langword="true"/> if the option was able to be retrieved, else <see langword="false"/></returns>
/// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
public bool TryGetOption<T>(string name, out T value)
{
if (string.IsNullOrWhiteSpace(name))
throw new ArgumentNullException(nameof(name));

if (!Options.TryGetValue(name, out var val) || val is not T)
{
value = default;
return false;
}
value = (T)val;
return true;
}

/// <summary>
/// Gets either the option stored by that <paramref name="name"/>, or adds the default value of <typeparamref name="T"/> with that <paramref name="name"/> and returns it.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="name">Name of the option to get.</param>
/// <returns>Options value of type <typeparamref name="T"/>.</returns>
/// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
public T GetOptionOrDefault<T>(string name)
{
if (string.IsNullOrWhiteSpace(name))
throw new ArgumentNullException(nameof(name));

if (!Options.TryGetValue(name, out object value))
SetOption<T>(name, default);
else
return (T)value;
return (T)Options[name];
}

/// <summary>
/// Removes an option.
/// </summary>
/// <param name="name">Name of the option to remove.</param>
/// <returns><see langword="true"/> if successfully removed, else <see langword="false"/>.</returns>
/// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
public bool RemoveOption(string name)
{
if (string.IsNullOrWhiteSpace(name))
throw new ArgumentNullException(nameof(name));
return Options.Remove(name);
}
}
6 changes: 6 additions & 0 deletions src/Microsoft.ML.Data/MLContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -191,5 +191,11 @@ private static bool InitializeOneDalDispatchingEnabled()
return false;
}
}

public bool TryAddOption<T>(string name, T value) => _env.TryAddOption(name, value);
public void SetOption<T>(string name, T value) => _env.SetOption(name, value);
public bool TryGetOption<T>(string name, out T value) => _env.TryGetOption<T>(name, out value);
public T GetOptionOrDefault<T>(string name) => _env.GetOptionOrDefault<T>(name);
public bool RemoveOption(string name) => _env.RemoveOption(name);
}
}
150 changes: 150 additions & 0 deletions src/Microsoft.ML.OnnxTransformer/OnnxSessionOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.ML.Data;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms.Onnx;

namespace Microsoft.ML.Transforms.Onnx
{
public static class OnnxSessionOptionsExtensions
{
private const string OnnxSessionOptionsName = "OnnxSessionOptions";

public static OnnxSessionOptions GetOnnxSessionOption(this IHostEnvironment env)
{
if (env is IHostEnvironmentInternal localEnvironment)
{
return localEnvironment.GetOptionOrDefault<OnnxSessionOptions>(OnnxSessionOptionsName);
}

throw new ArgumentException("No Onnx Session Options");
}

public static void SetOnnxSessionOption(this IHostEnvironment env, OnnxSessionOptions onnxSessionOptions)
{
if (env is IHostEnvironmentInternal localEnvironment)
{
localEnvironment.SetOption(OnnxSessionOptionsName, onnxSessionOptions);
}
else
throw new ArgumentException("No Onnx Session Options");
}
}

public sealed class OnnxSessionOptions
{
internal void CopyTo(SessionOptions sessionOptions)
{
sessionOptions.EnableMemoryPattern = EnableMemoryPattern;
sessionOptions.ProfileOutputPathPrefix = ProfileOutputPathPrefix;
sessionOptions.EnableProfiling = EnableProfiling;
sessionOptions.OptimizedModelFilePath = OptimizedModelFilePath;
sessionOptions.EnableCpuMemArena = EnableCpuMemArena;
if (!PerSessionThreads)
sessionOptions.DisablePerSessionThreads();
sessionOptions.LogId = LogId;
sessionOptions.LogSeverityLevel = LogSeverityLevel;
sessionOptions.LogVerbosityLevel = LogVerbosityLevel;
sessionOptions.InterOpNumThreads = InterOpNumThreads;
sessionOptions.IntraOpNumThreads = IntraOpNumThreads;
sessionOptions.GraphOptimizationLevel = GraphOptimizationLevel;
sessionOptions.ExecutionMode = ExecutionMode;
}

/// <summary>
/// Enables the use of the memory allocation patterns in the first Run() call for subsequent runs. Default = true.
/// </summary>
#pragma warning disable MSML_NoInstanceInitializers // No initializers on instance fields or properties
public bool EnableMemoryPattern { get; set; } = true;

/// <summary>
/// Path prefix to use for output of profiling data
/// </summary>
public string ProfileOutputPathPrefix { get; set; } = "onnxruntime_profile_"; // this is the same default in C++ implementation

/// <summary>
/// Enables profiling of InferenceSession.Run() calls. Default is false
/// </summary>
public bool EnableProfiling { get; set; } = false;

/// <summary>
/// Set filepath to save optimized model after graph level transformations. Default is empty, which implies saving is disabled.
/// </summary>
public string OptimizedModelFilePath { get; set; } = string.Empty;

/// <summary>
/// Enables Arena allocator for the CPU memory allocations. Default is true.
/// </summary>
public bool EnableCpuMemArena { get; set; } = true;

/// <summary>
/// Per session threads. Default is true.
/// If false this makes all sessions in the process use a global TP.
/// </summary>
public bool PerSessionThreads { get; set; } = true;

/// <summary>
/// Sets the number of threads used to parallelize the execution within nodes
/// A value of 0 means ORT will pick a default. Only used when <see cref="PerSessionThreads"/> is false.
/// </summary>
public int GlobalIntraOpNumThreads { get; set; } = 0;

/// <summary>
/// Sets the number of threads used to parallelize the execution of the graph (across nodes)
/// If sequential execution is enabled this value is ignored
/// A value of 0 means ORT will pick a default. Only used when <see cref="PerSessionThreads"/> is false.
/// </summary>
public int GlobalInterOpNumThreads { get; set; } = 0;

/// <summary>
/// Log Id to be used for the session. Default is empty string.
/// </summary>
public string LogId { get; set; } = string.Empty;

/// <summary>
/// Log Severity Level for the session logs. Default = ORT_LOGGING_LEVEL_WARNING
/// </summary>
public OrtLoggingLevel LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING;

/// <summary>
/// Log Verbosity Level for the session logs. Default = 0. Valid values are >=0.
/// This takes into effect only when the LogSeverityLevel is set to ORT_LOGGING_LEVEL_VERBOSE.
/// </summary>
public int LogVerbosityLevel { get; set; } = 0;

/// <summary>
/// Sets the number of threads used to parallelize the execution within nodes
/// A value of 0 means ORT will pick a default
/// </summary>
public int IntraOpNumThreads { get; set; } = 0;

/// <summary>
/// Sets the number of threads used to parallelize the execution of the graph (across nodes)
/// If sequential execution is enabled this value is ignored
/// A value of 0 means ORT will pick a default
/// </summary>
public int InterOpNumThreads { get; set; } = 0;

/// <summary>
/// Sets the graph optimization level for the session. Default is set to ORT_ENABLE_ALL.
/// </summary>
public GraphOptimizationLevel GraphOptimizationLevel { get; set; } = GraphOptimizationLevel.ORT_ENABLE_ALL;

/// <summary>
/// Sets the execution mode for the session. Default is set to ORT_SEQUENTIAL.
/// See [ONNX_Runtime_Perf_Tuning.md] for more details.
/// </summary>
public ExecutionMode ExecutionMode { get; set; } = ExecutionMode.ORT_SEQUENTIAL;
#pragma warning restore MSML_NoInstanceInitializers // No initializers on instance fields or properties

public delegate SessionOptions CreateOnnxSessionOptions();

public CreateOnnxSessionOptions CreateSessionOptions { get; set; }
}
}
2 changes: 1 addition & 1 deletion src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
Host.CheckNonWhiteSpace(options.ModelFile, nameof(options.ModelFile));
Host.CheckIO(File.Exists(options.ModelFile), "Model file {0} does not exists.", options.ModelFile);
// Because we cannot delete the user file, ownModelFile should be false.
Model = new OnnxModel(options.ModelFile, options.GpuDeviceId, options.FallbackToCpu, ownModelFile: false, shapeDictionary: shapeDictionary, options.RecursionLimit,
Model = new OnnxModel(env, options.ModelFile, options.GpuDeviceId, options.FallbackToCpu, ownModelFile: false, shapeDictionary: shapeDictionary, options.RecursionLimit,
options.InterOpNumThreads, options.IntraOpNumThreads);
}
else
Expand Down
Loading

0 comments on commit a7a6d88

Please sign in to comment.