Skip to content

Commit

Permalink
execution and sync context preservation
Browse files Browse the repository at this point in the history
  • Loading branch information
maksimkim committed Mar 14, 2018
1 parent 9e5d322 commit d0d827a
Showing 1 changed file with 84 additions and 17 deletions.
101 changes: 84 additions & 17 deletions src/DotNetty.Common/Concurrency/AbstractPromise.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,41 @@ namespace DotNetty.Common.Concurrency
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Sources;

public abstract class AbstractPromise : IPromise, IValueTaskSource
{
struct CompletionData
{
public Action<object> Continuation { get; }
public object State { get; }
public ExecutionContext ExecutionContext { get; }
public SynchronizationContext SynchronizationContext { get; }

public CompletionData(Action<object> continuation, object state, ExecutionContext executionContext, SynchronizationContext synchronizationContext)
{
this.Continuation = continuation;
this.State = state;
this.ExecutionContext = executionContext;
this.SynchronizationContext = synchronizationContext;
}
}

const short SourceToken = 0;

static readonly ContextCallback ExecutionContextCallback = Execute;
static readonly SendOrPostCallback SyncContextCallbackWithExecutionContext = ExecuteWithExecutionContext;
static readonly SendOrPostCallback SyncContextCallback = Execute;

static readonly Exception CanceledException = new OperationCanceledException();
static readonly Exception CompletedNoException = new Exception();

protected Exception exception;

int callbackCount;
(Action<object>, object)[] callbacks;
CompletionData[] completions;

public bool TryComplete() => this.TryComplete0(CompletedNoException);

Expand All @@ -34,7 +55,7 @@ protected virtual bool TryComplete0(Exception exception)
{
// Set the exception object to the exception passed in or a sentinel value
this.exception = exception;
this.TryExecuteCallbacks();
this.TryExecuteCompletions();
return true;
}

Expand Down Expand Up @@ -75,27 +96,31 @@ public virtual void GetResult(short token)

public virtual void OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags)
{
//todo: context preservation
if (this.callbacks == null)
if (this.completions == null)
{
this.callbacks = new (Action<object>, object)[1];
this.completions = new CompletionData[1];
}

int newIndex = this.callbackCount;
this.callbackCount++;

if (newIndex == this.callbacks.Length)
if (newIndex == this.completions.Length)
{
var newArray = new (Action<object>, object)[this.callbacks.Length * 2];
Array.Copy(this.callbacks, newArray, this.callbacks.Length);
this.callbacks = newArray;
var newArray = new CompletionData[this.completions.Length * 2];
Array.Copy(this.completions, newArray, this.completions.Length);
this.completions = newArray;
}

this.callbacks[newIndex] = (continuation, state);
this.completions[newIndex] = new CompletionData(
continuation,
state,
(flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) != 0 ? ExecutionContext.Capture() : null,
(flags & ValueTaskSourceOnCompletedFlags.UseSchedulingContext) != 0 ? SynchronizationContext.Current : null
);

if (this.exception != null)
{
this.TryExecuteCallbacks();
this.TryExecuteCompletions();
}
}

Expand All @@ -120,9 +145,9 @@ bool IsCompletedOrThrow()
[MethodImpl(MethodImplOptions.NoInlining)]
void ThrowLatchedException() => ExceptionDispatchInfo.Capture(this.exception).Throw();

bool TryExecuteCallbacks()
bool TryExecuteCompletions()
{
if (this.callbackCount == 0 || this.callbacks == null)
if (this.callbackCount == 0 || this.completions == null)
{
return false;
}
Expand All @@ -133,8 +158,8 @@ bool TryExecuteCallbacks()
{
try
{
(Action<object> callback, object state) = this.callbacks[i];
callback(state);
CompletionData completion = this.completions[i];
ExecuteCompletion(completion);
}
catch (Exception ex)
{
Expand All @@ -154,15 +179,57 @@ bool TryExecuteCallbacks()

throw new AggregateException(exceptions);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
protected void ClearCallbacks()
{
if (this.callbackCount > 0)
{
this.callbackCount = 0;
Array.Clear(this.callbacks, 0, this.callbacks.Length);
Array.Clear(this.completions, 0, this.completions.Length);
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
static void ExecuteCompletion(CompletionData completion)
{
if (completion.SynchronizationContext == null)
{
if (completion.ExecutionContext == null)
{
completion.Continuation(completion.State);
}
else
{
//boxing
ExecutionContext.Run(completion.ExecutionContext, ExecutionContextCallback, completion);
}
}
else
{
if (completion.ExecutionContext == null)
{
//boxing
completion.SynchronizationContext.Post(SyncContextCallback, completion);
}
else
{
//boxing
completion.SynchronizationContext.Post(SyncContextCallbackWithExecutionContext, completion);
}
}
}

static void Execute(object state)
{
CompletionData completion = (CompletionData)state;
completion.Continuation(completion.State);
}

static void ExecuteWithExecutionContext(object state)
{
CompletionData completion = (CompletionData)state;
ExecutionContext.Run(completion.ExecutionContext, ExecutionContextCallback, state);
}
}
}

0 comments on commit d0d827a

Please sign in to comment.