-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds in a way to add settings for the MLContext. (#7273)
* API, no tests
* Updates from PR
* Fixed rebase errors and PR comments
* Updates based on ONNX team feedback
- Loading branch information
1 parent
869dc9f
commit a7a6d88
Showing
7 changed files
with
347 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Text; | ||
using Microsoft.ML.Data; | ||
using Microsoft.ML.OnnxRuntime; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Transforms.Onnx; | ||
|
||
namespace Microsoft.ML.Transforms.Onnx | ||
{ | ||
public static class OnnxSessionOptionsExtensions | ||
{ | ||
private const string OnnxSessionOptionsName = "OnnxSessionOptions"; | ||
|
||
public static OnnxSessionOptions GetOnnxSessionOption(this IHostEnvironment env) | ||
{ | ||
if (env is IHostEnvironmentInternal localEnvironment) | ||
{ | ||
return localEnvironment.GetOptionOrDefault<OnnxSessionOptions>(OnnxSessionOptionsName); | ||
} | ||
|
||
throw new ArgumentException("No Onnx Session Options"); | ||
} | ||
|
||
public static void SetOnnxSessionOption(this IHostEnvironment env, OnnxSessionOptions onnxSessionOptions) | ||
{ | ||
if (env is IHostEnvironmentInternal localEnvironment) | ||
{ | ||
localEnvironment.SetOption(OnnxSessionOptionsName, onnxSessionOptions); | ||
} | ||
else | ||
throw new ArgumentException("No Onnx Session Options"); | ||
} | ||
} | ||
|
||
public sealed class OnnxSessionOptions | ||
{ | ||
internal void CopyTo(SessionOptions sessionOptions) | ||
{ | ||
sessionOptions.EnableMemoryPattern = EnableMemoryPattern; | ||
sessionOptions.ProfileOutputPathPrefix = ProfileOutputPathPrefix; | ||
sessionOptions.EnableProfiling = EnableProfiling; | ||
sessionOptions.OptimizedModelFilePath = OptimizedModelFilePath; | ||
sessionOptions.EnableCpuMemArena = EnableCpuMemArena; | ||
if (!PerSessionThreads) | ||
sessionOptions.DisablePerSessionThreads(); | ||
sessionOptions.LogId = LogId; | ||
sessionOptions.LogSeverityLevel = LogSeverityLevel; | ||
sessionOptions.LogVerbosityLevel = LogVerbosityLevel; | ||
sessionOptions.InterOpNumThreads = InterOpNumThreads; | ||
sessionOptions.IntraOpNumThreads = IntraOpNumThreads; | ||
sessionOptions.GraphOptimizationLevel = GraphOptimizationLevel; | ||
sessionOptions.ExecutionMode = ExecutionMode; | ||
} | ||
|
||
/// <summary> | ||
/// Enables the use of the memory allocation patterns in the first Run() call for subsequent runs. Default = true. | ||
/// </summary> | ||
#pragma warning disable MSML_NoInstanceInitializers // No initializers on instance fields or properties | ||
public bool EnableMemoryPattern { get; set; } = true; | ||
|
||
/// <summary> | ||
/// Path prefix to use for output of profiling data | ||
/// </summary> | ||
public string ProfileOutputPathPrefix { get; set; } = "onnxruntime_profile_"; // this is the same default in C++ implementation | ||
|
||
/// <summary> | ||
/// Enables profiling of InferenceSession.Run() calls. Default is false | ||
/// </summary> | ||
public bool EnableProfiling { get; set; } = false; | ||
|
||
/// <summary> | ||
/// Set filepath to save optimized model after graph level transformations. Default is empty, which implies saving is disabled. | ||
/// </summary> | ||
public string OptimizedModelFilePath { get; set; } = string.Empty; | ||
|
||
/// <summary> | ||
/// Enables Arena allocator for the CPU memory allocations. Default is true. | ||
/// </summary> | ||
public bool EnableCpuMemArena { get; set; } = true; | ||
|
||
/// <summary> | ||
/// Per session threads. Default is true. | ||
/// If false this makes all sessions in the process use a global TP. | ||
/// </summary> | ||
public bool PerSessionThreads { get; set; } = true; | ||
|
||
/// <summary> | ||
/// Sets the number of threads used to parallelize the execution within nodes | ||
/// A value of 0 means ORT will pick a default. Only used when <see cref="PerSessionThreads"/> is false. | ||
/// </summary> | ||
public int GlobalIntraOpNumThreads { get; set; } = 0; | ||
|
||
/// <summary> | ||
/// Sets the number of threads used to parallelize the execution of the graph (across nodes) | ||
/// If sequential execution is enabled this value is ignored | ||
/// A value of 0 means ORT will pick a default. Only used when <see cref="PerSessionThreads"/> is false. | ||
/// </summary> | ||
public int GlobalInterOpNumThreads { get; set; } = 0; | ||
|
||
/// <summary> | ||
/// Log Id to be used for the session. Default is empty string. | ||
/// </summary> | ||
public string LogId { get; set; } = string.Empty; | ||
|
||
/// <summary> | ||
/// Log Severity Level for the session logs. Default = ORT_LOGGING_LEVEL_WARNING | ||
/// </summary> | ||
public OrtLoggingLevel LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING; | ||
|
||
/// <summary> | ||
/// Log Verbosity Level for the session logs. Default = 0. Valid values are >=0. | ||
/// This takes into effect only when the LogSeverityLevel is set to ORT_LOGGING_LEVEL_VERBOSE. | ||
/// </summary> | ||
public int LogVerbosityLevel { get; set; } = 0; | ||
|
||
/// <summary> | ||
/// Sets the number of threads used to parallelize the execution within nodes | ||
/// A value of 0 means ORT will pick a default | ||
/// </summary> | ||
public int IntraOpNumThreads { get; set; } = 0; | ||
|
||
/// <summary> | ||
/// Sets the number of threads used to parallelize the execution of the graph (across nodes) | ||
/// If sequential execution is enabled this value is ignored | ||
/// A value of 0 means ORT will pick a default | ||
/// </summary> | ||
public int InterOpNumThreads { get; set; } = 0; | ||
|
||
/// <summary> | ||
/// Sets the graph optimization level for the session. Default is set to ORT_ENABLE_ALL. | ||
/// </summary> | ||
public GraphOptimizationLevel GraphOptimizationLevel { get; set; } = GraphOptimizationLevel.ORT_ENABLE_ALL; | ||
|
||
/// <summary> | ||
/// Sets the execution mode for the session. Default is set to ORT_SEQUENTIAL. | ||
/// See [ONNX_Runtime_Perf_Tuning.md] for more details. | ||
/// </summary> | ||
public ExecutionMode ExecutionMode { get; set; } = ExecutionMode.ORT_SEQUENTIAL; | ||
#pragma warning restore MSML_NoInstanceInitializers // No initializers on instance fields or properties | ||
|
||
public delegate SessionOptions CreateOnnxSessionOptions(); | ||
|
||
public CreateOnnxSessionOptions CreateSessionOptions { get; set; } | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.