Skip to content

Commit

Permalink
Java: Add UDS read/write error handling and tests with glide core mock (
Browse files Browse the repository at this point in the history
valkey-io#949)

Add UDS read/write error handling and tests with glide core mock. (#76)

* Add UDS read/write error handling and tests with glide core mock.

Signed-off-by: Yury-Fridlyand <[email protected]>
  • Loading branch information
Yury-Fridlyand authored Feb 27, 2024
1 parent 3fb2ebd commit 5d5a36c
Show file tree
Hide file tree
Showing 5 changed files with 469 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.unix.DomainSocketAddress;
import java.util.concurrent.CompletableFuture;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import redis_request.RedisRequestOuterClass.RedisRequest;
import response.ResponseOuterClass.Response;

Expand All @@ -17,8 +20,6 @@
*/
public class ChannelHandler {

private static final String THREAD_POOL_NAME = "glide-channel";

protected final Channel channel;
protected final CallbackDispatcher callbackDispatcher;

Expand All @@ -41,6 +42,8 @@ public ChannelHandler(
.channel(threadPoolResource.getDomainSocketChannelClass())
.handler(new ProtobufSocketChannelInitializer(callbackDispatcher))
.connect(new DomainSocketAddress(socketPath))
// TODO .addListener(new NettyFutureErrorHandler())
// we need to use connection promise here for that ^
.sync()
.channel();
this.callbackDispatcher = callbackDispatcher;
Expand All @@ -58,9 +61,11 @@ public CompletableFuture<Response> write(RedisRequest.Builder request, boolean f
request.setCallbackIdx(commandId.getKey());

if (flush) {
channel.writeAndFlush(request.build());
channel
.writeAndFlush(request.build())
.addListener(new NettyFutureErrorHandler(commandId.getValue()));
} else {
channel.write(request.build());
channel.write(request.build()).addListener(new NettyFutureErrorHandler(commandId.getValue()));
}
return commandId.getValue();
}
Expand All @@ -73,7 +78,7 @@ public CompletableFuture<Response> write(RedisRequest.Builder request, boolean f
*/
public CompletableFuture<Response> connect(ConnectionRequest request) {
var future = callbackDispatcher.registerConnection();
channel.writeAndFlush(request);
channel.writeAndFlush(request).addListener(new NettyFutureErrorHandler(future));
return future;
}

Expand All @@ -82,4 +87,25 @@ public ChannelFuture close() {
callbackDispatcher.shutdownGracefully();
return channel.close();
}

/**
* Propagate an error from Netty's {@link ChannelFuture} and complete the {@link
* CompletableFuture} promise.
*/
@RequiredArgsConstructor
private static class NettyFutureErrorHandler implements ChannelFutureListener {

private final CompletableFuture<Response> promise;

@Override
public void operationComplete(@NonNull ChannelFuture channelFuture) throws Exception {
if (channelFuture.isCancelled()) {
promise.cancel(false);
}
var cause = channelFuture.cause();
if (cause != null) {
promise.completeExceptionally(cause);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ public void channelRead(@NonNull ChannelHandlerContext ctx, @NonNull Object msg)
/** Handles uncaught exceptions from {@link #channelRead(ChannelHandlerContext, Object)}. */
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
// TODO: log thru logger
System.out.printf("=== exceptionCaught %s %s %n", ctx, cause);
cause.printStackTrace(System.err);
super.exceptionCaught(ctx, cause);

callbackDispatcher.distributeClosingException(
"An unhandled error while reading from UDS channel: " + cause);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
/** Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 */
package glide.connection;

import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import connection_request.ConnectionRequestOuterClass.ConnectionRequest;
import connection_request.ConnectionRequestOuterClass.NodeAddress;
import glide.api.RedisClient;
import glide.api.models.exceptions.ClosingException;
import glide.connectors.handlers.CallbackDispatcher;
import glide.connectors.handlers.ChannelHandler;
import glide.connectors.resources.Platform;
import glide.managers.CommandManager;
import glide.managers.ConnectionManager;
import glide.utils.RustCoreLibMockTestBase;
import glide.utils.RustCoreMock;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeoutException;
import lombok.SneakyThrows;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import redis_request.RedisRequestOuterClass.RedisRequest;
import response.ResponseOuterClass.Response;

public class ConnectionWithGlideMockTests extends RustCoreLibMockTestBase {

private ChannelHandler channelHandler = null;

@BeforeEach
@SneakyThrows
public void createTestClient() {
channelHandler =
new ChannelHandler(
new CallbackDispatcher(), socketPath, Platform.getThreadPoolResourceSupplier().get());
}

@AfterEach
public void closeTestClient() {
channelHandler.close();
}

private Future<Response> testConnection() {
return channelHandler.connect(createConnectionRequest());
}

private static ConnectionRequest createConnectionRequest() {
return ConnectionRequest.newBuilder()
.addAddresses(NodeAddress.newBuilder().setHost("dummyhost").setPort(42).build())
.build();
}

@BeforeAll
public static void init() {
startRustCoreLibMock(null);
}

@Test
@SneakyThrows
// as of #710 https://github.com/aws/babushka/pull/710 - connection response is empty
public void can_connect_with_empty_response() {
RustCoreMock.updateGlideMock(
new RustCoreMock.GlideMockProtobuf() {
@Override
public Response connection(ConnectionRequest request) {
return Response.newBuilder().build();
}

@Override
public Response.Builder redisRequest(RedisRequest request) {
return null;
}
});

var connectionResponse = testConnection().get();
assertAll(
() -> assertFalse(connectionResponse.hasClosingError()),
() -> assertFalse(connectionResponse.hasRequestError()),
() -> assertFalse(connectionResponse.hasRespPointer()));
}

@Test
@SneakyThrows
public void can_connect_with_ok_response() {
RustCoreMock.updateGlideMock(
new RustCoreMock.GlideMockProtobuf() {
@Override
public Response connection(ConnectionRequest request) {
return OK().build();
}

@Override
public Response.Builder redisRequest(RedisRequest request) {
return null;
}
});

var connectionResponse = testConnection().get();
assertAll(
() -> assertTrue(connectionResponse.hasConstantResponse()),
() -> assertFalse(connectionResponse.hasClosingError()),
() -> assertFalse(connectionResponse.hasRequestError()),
() -> assertFalse(connectionResponse.hasRespPointer()));
}

@Test
public void cant_connect_when_no_response() {
RustCoreMock.updateGlideMock(
new RustCoreMock.GlideMockProtobuf() {
@Override
public Response connection(ConnectionRequest request) {
return null;
}

@Override
public Response.Builder redisRequest(RedisRequest request) {
return null;
}
});

assertThrows(TimeoutException.class, () -> testConnection().get(1, SECONDS));
}

@Test
@SneakyThrows
public void cant_connect_when_negative_response() {
RustCoreMock.updateGlideMock(
new RustCoreMock.GlideMockProtobuf() {
@Override
public Response connection(ConnectionRequest request) {
return Response.newBuilder().setClosingError("You shall not pass!").build();
}

@Override
public Response.Builder redisRequest(RedisRequest request) {
return null;
}
});

var exception = assertThrows(ExecutionException.class, () -> testConnection().get(1, SECONDS));
assertAll(
() -> assertTrue(exception.getCause() instanceof ClosingException),
() -> assertEquals("You shall not pass!", exception.getCause().getMessage()));
}

@Test
@SneakyThrows
public void rethrow_error_on_read_when_malformed_packet_received() {
RustCoreMock.updateGlideMock(request -> new byte[] {-1});

var exception = assertThrows(ExecutionException.class, () -> testConnection().get(1, SECONDS));
assertAll(
() -> assertTrue(exception.getCause() instanceof ClosingException),
() ->
assertTrue(
exception
.getCause()
.getMessage()
.contains("An unhandled error while reading from UDS channel")));
}

@Test
@SneakyThrows
public void rethrow_error_if_UDS_channel_closed() {
var client = new TestClient(channelHandler);
stopRustCoreLibMock();
try {
var exception =
assertThrows(
ExecutionException.class, () -> client.customCommand(new String[0]).get(1, SECONDS));
assertTrue(exception.getCause() instanceof RuntimeException);

// Not a public class, can't import
assertEquals(
"io.netty.channel.StacklessClosedChannelException",
exception.getCause().getCause().getClass().getName());
} finally {
// restart mock to let other tests pass if this one failed
startRustCoreLibMock(null);
}
}

private static class TestClient extends RedisClient {

public TestClient(ChannelHandler channelHandler) {
super(new ConnectionManager(channelHandler), new CommandManager(channelHandler));
}
}
}
49 changes: 49 additions & 0 deletions java/client/src/test/java/glide/utils/RustCoreLibMockTestBase.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/** Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 */
package glide.utils;

import glide.connectors.handlers.ChannelHandler;
import glide.ffi.resolvers.SocketListenerResolver;
import lombok.SneakyThrows;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;

public class RustCoreLibMockTestBase {

/**
* Pass this socket path to {@link ChannelHandler} or mock {@link
* SocketListenerResolver#getSocket()} to return it.
*/
protected static String socketPath = null;

@SneakyThrows
public static void startRustCoreLibMock(RustCoreMock.GlideMock rustCoreLibMock) {
assert socketPath == null
: "Previous `RustCoreMock` wasn't stopped. Ensure that your test class inherits"
+ " `RustCoreLibMockTestBase`.";

socketPath = RustCoreMock.start(rustCoreLibMock);
}

@BeforeEach
public void preTestCheck() {
assert socketPath != null
: "You missed to call `startRustCoreLibMock` in a `@BeforeAll` method of your test class"
+ " inherited from `RustCoreLibMockTestBase`.";
}

@AfterEach
public void afterTestCheck() {
assert !RustCoreMock.failed() : "Error occurred in `RustCoreMock`";
}

@AfterAll
@SneakyThrows
public static void stopRustCoreLibMock() {
assert socketPath != null
: "You missed to call `startRustCoreLibMock` in a `@AfterAll` method of your test class"
+ " inherited from `RustCoreLibMockTestBase`.";
RustCoreMock.stop();
socketPath = null;
}
}
Loading

0 comments on commit 5d5a36c

Please sign in to comment.