Skip to content
This repository has been archived by the owner on Sep 17, 2023. It is now read-only.

Commit

Permalink
Fixes reconnect issues and exception logging when socket read times out (#90)
Browse files: browse the repository at this point in the history
  • Loading branch information
danielwertheim authored May 15, 2021
1 parent f801270 commit 1b1eca2
Show file tree
Hide file tree
Showing 13 changed files with 243 additions and 147 deletions.
5 changes: 2 additions & 3 deletions src/main/MyNatsClient/INatsConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@ public interface INatsConnection : IDisposable
{
INatsServerInfo ServerInfo { get; }
bool IsConnected { get; }
bool CanRead { get; }

IEnumerable<IOp> ReadOp();
IEnumerable<IOp> ReadOps();
void WithWriteLock(Action<INatsStreamWriter> a);
void WithWriteLock<TArg>(Action<INatsStreamWriter, TArg> a, TArg arg);
Task WithWriteLockAsync(Func<INatsStreamWriter, Task> a);
Task WithWriteLockAsync<TArg>(Func<INatsStreamWriter, TArg, Task> a, TArg arg);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ internal static NetworkStream CreateReadWriteStream(this Socket socket)
return ns;
}
}
}
}
57 changes: 35 additions & 22 deletions src/main/MyNatsClient/Internals/NatsConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;

namespace MyNatsClient.Internals
{
internal sealed class NatsConnection : INatsConnection
{
private readonly Func<bool> _socketIsConnected;
private readonly Func<bool> _canRead;
private readonly ILogger<NatsConnection> _logger = LoggerManager.CreateLogger<NatsConnection>();
private readonly CancellationToken _cancellationToken;

private Socket _socket;
Expand All @@ -24,8 +24,12 @@ internal sealed class NatsConnection : INatsConnection
private bool _isDisposed;

public INatsServerInfo ServerInfo { get; }
public bool IsConnected => _socketIsConnected();
public bool CanRead => _canRead();
/// <summary>
/// Indicates whether the underlying socket currently reports itself as connected.
/// </summary>
/// <remarks>
/// <c>_socket</c> is a mutable field, so it is read null-conditionally: the property
/// yields <c>false</c> rather than throwing a <see cref="NullReferenceException"/>
/// when no socket is held (presumably after disposal - confirm against Dispose).
/// </remarks>
public bool IsConnected => _socket?.Connected == true;

/// <summary>
/// True while ops may still be read: the socket reports connected, the stream is
/// readable, and cancellation has not been requested.
/// </summary>
/// <remarks>
/// Dispose sets <c>_stream</c> to null, and this property is re-evaluated on every
/// iteration of the read loop. The fields are therefore read null-conditionally so
/// that a disposed connection ends the loop instead of raising a
/// <see cref="NullReferenceException"/>.
/// </remarks>
private bool CanRead =>
    _socket?.Connected == true &&
    _stream?.CanRead == true &&
    !_cancellationToken.IsCancellationRequested;

internal NatsConnection(
NatsServerInfo serverInfo,
Expand All @@ -45,10 +49,7 @@ internal NatsConnection(
_cancellationToken = cancellationToken;
_writeStreamSync = new SemaphoreSlim(1, 1);
_writer = new NatsStreamWriter(_writeStream, _cancellationToken);
_reader = new NatsOpStreamReader(_readStream);

_socketIsConnected = () => _socket?.Connected == true;
_canRead = () => _socket?.Connected == true && _stream != null && _stream.CanRead && !_cancellationToken.IsCancellationRequested;
_reader = NatsOpStreamReader.Use(_readStream);
}

public void Dispose()
Expand All @@ -71,9 +72,11 @@ void TryDispose(IDisposable disposable)
}
}

TryDispose(_reader);
TryDispose(_writeStream);
TryDispose(_readStream);
TryDispose(_stream);

try
{
_socket.Shutdown(SocketShutdown.Both);
Expand All @@ -82,9 +85,11 @@ void TryDispose(IDisposable disposable)
{
exs.Add(ex);
}

TryDispose(_socket);
TryDispose(_writeStreamSync);

_reader = null;
_writeStream = null;
_readStream = null;
_stream = null;
Expand All @@ -97,13 +102,33 @@ void TryDispose(IDisposable disposable)
throw new AggregateException("Failed while disposing connection. See inner exception(s) for more details.", exs);
}

public IEnumerable<IOp> ReadOp()
/// <summary>
/// Guard that throws <see cref="ObjectDisposedException"/>, named after the runtime
/// type, when this instance has been marked as disposed.
/// </summary>
private void ThrowIfDisposed()
{
    if (!_isDisposed)
    {
        return;
    }

    throw new ObjectDisposedException(GetType().Name);
}

/// <summary>
/// Guard that throws the exception produced by <c>NatsException.NotConnected()</c>
/// when <see cref="IsConnected"/> is false.
/// </summary>
private void ThrowIfNotConnected()
{
    if (IsConnected)
    {
        return;
    }

    throw NatsException.NotConnected();
}

/// <summary>
/// Lazily reads ops from the connection, yielding one op per call to the reader
/// for as long as <see cref="CanRead"/> holds.
/// </summary>
/// <returns>A deferred sequence of ops read from the underlying reader.</returns>
/// <exception cref="ObjectDisposedException">The connection has been disposed.</exception>
/// <exception cref="NatsException">The connection is not connected.</exception>
/// <remarks>
/// This is an iterator method: the guards below run when enumeration starts,
/// not when the method is called.
/// </remarks>
public IEnumerable<IOp> ReadOps()
{
    ThrowIfDisposed();

    ThrowIfNotConnected();

    _logger.LogDebug("Starting OPs read loop");

    while (CanRead)
    {
        // Conditional-compilation symbols are case-sensitive and the one defined for
        // debug builds is DEBUG; "#if Debug" never matched, so this log was dead code.
#if DEBUG
        _logger.LogDebug("Reading OP");
#endif
        yield return _reader.ReadOp();
    }
}

public void WithWriteLock(Action<INatsStreamWriter> a)
Expand Down Expand Up @@ -177,17 +202,5 @@ public async Task WithWriteLockAsync<TArg>(Func<INatsStreamWriter, TArg, Task> a
_writeStreamSync.Release();
}
}

/// <summary>
/// Guard that throws <see cref="ObjectDisposedException"/>, named after the runtime
/// type, when this instance has been marked as disposed.
/// </summary>
private void ThrowIfDisposed()
{
    if (!_isDisposed)
    {
        return;
    }

    throw new ObjectDisposedException(GetType().Name);
}

/// <summary>
/// Guard that throws the exception produced by <c>NatsException.NotConnected()</c>
/// when <see cref="IsConnected"/> is false.
/// </summary>
private void ThrowIfNotConnected()
{
    if (IsConnected)
    {
        return;
    }

    throw NatsException.NotConnected();
}
}
}
46 changes: 33 additions & 13 deletions src/main/MyNatsClient/Internals/NatsConnectionManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace MyNatsClient.Internals
{
internal class NatsConnectionManager : INatsConnectionManager
{
private static readonly ILogger<NatsConnectionManager> Logger = LoggerManager.CreateLogger<NatsConnectionManager>();
private readonly ILogger<NatsConnectionManager> _logger = LoggerManager.CreateLogger<NatsConnectionManager>();

private readonly ISocketFactory _socketFactory;

Expand Down Expand Up @@ -50,7 +50,7 @@ internal NatsConnectionManager(ISocketFactory socketFactory)
}
catch (Exception ex)
{
Logger.LogError(ex, "Error while connecting to {Host}. Trying with next host (if any).", host);
_logger.LogError(ex, "Error while connecting to {Host}. Trying with next host (if any).", host);

if (!ShouldTryAndConnect())
throw;
Expand All @@ -68,6 +68,7 @@ private static bool DefaultServerCertificateValidation(X509Certificate certifica
ConnectionInfo connectionInfo,
CancellationToken cancellationToken)
{
_logger.LogInformation("Establishing connection to {Host}", host);
var serverCertificateValidation = connectionInfo.ServerCertificateValidation ?? DefaultServerCertificateValidation;

bool RemoteCertificateValidationCallback(object _, X509Certificate certificate, X509Chain chain, SslPolicyErrors errors)
Expand All @@ -76,19 +77,23 @@ bool RemoteCertificateValidationCallback(object _, X509Certificate certificate,
var consumedOps = new List<IOp>();
Socket socket = null;
Stream stream = null;
NatsOpStreamReader reader = null;

try
{
_logger.LogDebug("Creating socket.");
socket = _socketFactory.Create(connectionInfo.SocketOptions);
await socket.ConnectAsync(
host,
connectionInfo.SocketOptions.ConnectTimeoutMs,
cancellationToken).ConfigureAwait(false);

_logger.LogDebug("Creating read write stream.");
stream = socket.CreateReadWriteStream();
var reader = new NatsOpStreamReader(stream);
reader = NatsOpStreamReader.Use(stream);

var op = reader.ReadOneOp();
_logger.LogDebug("Trying to read InfoOp.");
var op = reader.ReadOp();
if (op == null)
throw NatsException.FailedToConnectToHost(host,
"Expected to get INFO after establishing connection. Got nothing.");
Expand All @@ -97,6 +102,7 @@ await socket.ConnectAsync(
throw NatsException.FailedToConnectToHost(host,
$"Expected to get INFO after establishing connection. Got {op.GetType().Name}.");

_logger.LogDebug("Parsing server info.");
var serverInfo = NatsServerInfo.Parse(infoOp.Message);
var credentials = host.HasNonEmptyCredentials() ? host.Credentials : connectionInfo.Credentials;
if (serverInfo.AuthRequired && (credentials == null || credentials == Credentials.Empty))
Expand All @@ -109,6 +115,7 @@ await socket.ConnectAsync(

if (serverInfo.TlsRequired)
{
_logger.LogDebug("Creating SSL Stream.");
stream = new SslStream(stream, false, RemoteCertificateValidationCallback, null, EncryptionPolicy.RequireEncryption);
var ssl = (SslStream) stream;

Expand All @@ -123,29 +130,37 @@ await socket.ConnectAsync(
TargetHost = host.Address
};

_logger.LogDebug("Performing SSL client authentication.");
await ssl.AuthenticateAsClientAsync(clientAuthOptions, cancellationToken).ConfigureAwait(false);

reader = new NatsOpStreamReader(ssl);
reader.SetNewSource(ssl);
}

_logger.LogDebug("Sending Connect.");
stream.Write(ConnectCmd.Generate(connectionInfo.Verbose, credentials, connectionInfo.Name));
_logger.LogDebug("Sending Ping.");
stream.Write(PingCmd.Bytes.Span);
await stream.FlushAsync(cancellationToken).ConfigureAwait(false);

op = reader.ReadOneOp();
if (op == null)
throw NatsException.FailedToConnectToHost(host,
"Expected to read something after CONNECT and PING. Got nothing.");

if (op is ErrOp)
throw NatsException.FailedToConnectToHost(host,
$"Expected to get PONG after sending CONNECT and PING. Got {op.Marker}.");
_logger.LogDebug("Trying to read OP to see if connection was established.");
op = reader.ReadOp();
switch (op)
{
case NullOp:
throw NatsException.FailedToConnectToHost(host,
"Expected to read something after CONNECT and PING. Got nothing.");
case ErrOp:
throw NatsException.FailedToConnectToHost(host,
$"Expected to get PONG after sending CONNECT and PING. Got {op.Marker}.");
}

if (!socket.Connected)
throw NatsException.FailedToConnectToHost(host, "No connection could be established.");

consumedOps.Add(op);

_logger.LogInformation("Connection successfully established to {Host}", host);

return (
new NatsConnection(
serverInfo,
Expand All @@ -157,6 +172,11 @@ await socket.ConnectAsync(
catch
{
Swallow.Everything(
() =>
{
reader?.Dispose();
reader = null;
},
() =>
{
stream?.Dispose();
Expand Down
2 changes: 1 addition & 1 deletion src/main/MyNatsClient/Internals/NatsEncoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ internal static int WriteSingleByteChars(Span<byte> trg, int trgOffset, ReadOnly
return trgOffset;
}
}
}
}
39 changes: 26 additions & 13 deletions src/main/MyNatsClient/NatsClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ public sealed class NatsClient : INatsClient, IDisposable

private readonly string _inboxAddress;
private ISubscription _inboxSubscription;
private readonly ConcurrentDictionary<string, TaskCompletionSource<MsgOp>> _outstandingRequests = new ConcurrentDictionary<string, TaskCompletionSource<MsgOp>>();

private readonly ConcurrentDictionary<string, TaskCompletionSource<MsgOp>> _outstandingRequests =
new ConcurrentDictionary<string, TaskCompletionSource<MsgOp>>();

public string Id { get; }
public INatsObservable<IClientEvent> Events => _eventMediator;
public INatsObservable<IOp> OpStream => _opMediator.AllOpsStream;
public INatsObservable<MsgOp> MsgOpStream => _opMediator.MsgOpsStream;
public bool IsConnected => _connection != null && _connection.IsConnected && _connection.CanRead;
public bool IsConnected => _connection?.IsConnected == true;

public NatsClient(
ConnectionInfo connectionInfo,
Expand Down Expand Up @@ -192,6 +194,8 @@ public async Task ConnectAsync()
DoSafeRelease();
_logger.LogDebug("Emitting ClientDisconnected due to failure");
_eventMediator.Emit(new ClientDisconnected(this, DisconnectReason.DueToFailure));
var ex = t.Exception?.GetBaseException() ?? t.Exception;
Expand Down Expand Up @@ -229,11 +233,15 @@ private void ConsumerWork()
{
bool ShouldDoWork() => !_isDisposed && IsConnected && _cancellation?.IsCancellationRequested == false;

_logger.LogDebug("Starting consumer worker {IsConnected}", IsConnected);

var lastOpReceivedAt = DateTime.UtcNow;
var ping = false;

while (ShouldDoWork())
{
_logger.LogDebug("Consumer tick.");

try
{
if (ping)
Expand All @@ -243,8 +251,11 @@ private void ConsumerWork()
Ping();
}

foreach (var op in _connection.ReadOp())
foreach (var op in _connection.ReadOps())
{
if (op == NullOp.Instance)
throw NatsException.ClientCouldNotConsumeStream();

lastOpReceivedAt = DateTime.UtcNow;

_opMediator.Emit(op);
Expand All @@ -259,7 +270,7 @@ private void ConsumerWork()
}
}
}
catch (NatsException nex) when (nex.ExceptionCode == NatsExceptionCodes.OpParserError)
catch (NatsException nex) when (nex.ExceptionCode is NatsExceptionCodes.OpParserError or NatsExceptionCodes.ClientCouldNotConsumeStream)
{
throw;
}
Expand All @@ -268,26 +279,28 @@ private void ConsumerWork()
if (!ShouldDoWork())
break;

_logger.LogError(ex, "Worker got Exception.");

if (ex.InnerException is SocketException socketEx)
{
_logger.LogError("Worker task got SocketException with SocketErrorCode={SocketErrorCode}", socketEx.SocketErrorCode);
_logger.LogWarning(
"Consumer task got SocketException with error code {SocketErrorCode} Frequency of Timeouts is controlled via ReceiveTimeout.",
socketEx.SocketErrorCode);

if (socketEx.SocketErrorCode == SocketError.Interrupted)
break;

if (socketEx.SocketErrorCode != SocketError.TimedOut)
throw;
}
else
_logger.LogError(ex, "Consumer task failed");
}

var silenceDeltaMs = DateTime.UtcNow.Subtract(lastOpReceivedAt).TotalMilliseconds;
if (silenceDeltaMs >= ConsumerMaxMsSilenceFromServer)
throw NatsException.ConnectionFoundIdling(_connection.ServerInfo.Host, _connection.ServerInfo.Port);
var silenceDeltaMs = DateTime.UtcNow.Subtract(lastOpReceivedAt).TotalMilliseconds;
if (silenceDeltaMs >= ConsumerMaxMsSilenceFromServer)
throw NatsException.ConnectionFoundIdling(_connection.ServerInfo.Host, _connection.ServerInfo.Port);

if (silenceDeltaMs >= ConsumerPingAfterMsSilenceFromServer)
ping = true;
}
if (silenceDeltaMs >= ConsumerPingAfterMsSilenceFromServer)
ping = true;
}
}

Expand Down
5 changes: 4 additions & 1 deletion src/main/MyNatsClient/NatsException.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ public class NatsException : Exception
{
public string ExceptionCode { get; private set; }

protected NatsException(string exceptionCode, string message)
private NatsException(string exceptionCode, string message)
: base(message)
{
ExceptionCode = exceptionCode ?? NatsExceptionCodes.Unknown;
Expand Down Expand Up @@ -40,6 +40,9 @@ internal static NatsException ConnectionFoundIdling(string host, int port)
/// <summary>
/// Builds a <see cref="NatsException"/> with code
/// <c>NatsExceptionCodes.ClientReceivedErrOp</c> whose message embeds the message
/// of the received <paramref name="errOp"/>.
/// </summary>
/// <param name="errOp">The error op received by the client.</param>
internal static NatsException ClientReceivedErrOp(ErrOp errOp)
{
    var message = $"Client received ErrOp with message='{errOp.Message}'.";

    return new NatsException(NatsExceptionCodes.ClientReceivedErrOp, message);
}

/// <summary>
/// Builds a <see cref="NatsException"/> with code
/// <c>NatsExceptionCodes.ClientCouldNotConsumeStream</c> and a fixed message.
/// </summary>
internal static NatsException ClientCouldNotConsumeStream()
{
    const string message = "Client could not consume stream.";

    return new NatsException(NatsExceptionCodes.ClientCouldNotConsumeStream, message);
}

/// <summary>
/// Builds a <see cref="NatsException"/> with code
/// <c>NatsExceptionCodes.OpParserError</c> carrying the supplied message.
/// </summary>
/// <param name="message">The text used as the exception message.</param>
internal static NatsException OpParserError(string message)
{
    return new NatsException(NatsExceptionCodes.OpParserError, message);
}

Expand Down
Loading

0 comments on commit 1b1eca2

Please sign in to comment.