Skip to content

Commit

Permalink
Terminate stream with error on null values returned by `RedisElemen…
Browse files Browse the repository at this point in the history
…tReader` for top-level elements.

We now emit InvalidDataAccessApiUsageException when a RedisElementReader returns null in the context of a top-level stream to indicate invalid API usage although RedisElementReader.read can generally return null values if these are being collected in a container or value wrapper or parent complex object.

Apply consistent wording to operations documentation.
  • Loading branch information
mp911de authored and jxblum committed Oct 11, 2023
1 parent 66b00e2 commit b5f124c
Show file tree
Hide file tree
Showing 22 changed files with 276 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ public static Object parse(Object source, String sourcePath, Map<String, Class<?
* @return
* @since 2.6
*/
public static <K, V> Map.Entry<K, V> entryOf(K key, V value) {
public static <K, V> Map.Entry<K, V> entryOf(@Nullable K key, @Nullable V value) {
return new AbstractMap.SimpleImmutableEntry<>(key, value);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.springframework.data.redis.domain.geo.GeoReference.GeoMemberReference;
import org.springframework.data.redis.domain.geo.GeoShape;
import org.springframework.data.redis.serializer.RedisSerializationContext;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
Expand Down Expand Up @@ -320,6 +321,7 @@ private ByteBuffer rawValue(V value) {
return serializationContext.getValueSerializationPair().write(value);
}

@Nullable
private V readValue(ByteBuffer buffer) {
return serializationContext.getValueSerializationPair().read(buffer);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
import java.util.function.Function;

import org.reactivestreams.Publisher;
import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.data.redis.connection.ReactiveHashCommands;
import org.springframework.data.redis.connection.convert.Converters;
import org.springframework.data.redis.serializer.RedisSerializationContext;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
Expand Down Expand Up @@ -126,7 +128,8 @@ public Mono<HK> randomKey(H key) {

Assert.notNull(key, "Key must not be null");

return createMono(hashCommands -> hashCommands.hRandField(rawKey(key))).map(this::readHashKey);
return template.doCreateMono(connection -> connection //
.hashCommands().hRandField(rawKey(key))).map(this::readRequiredHashKey);
}

@Override
Expand All @@ -142,7 +145,8 @@ public Flux<HK> randomKeys(H key, long count) {

Assert.notNull(key, "Key must not be null");

return createFlux(hashCommands -> hashCommands.hRandField(rawKey(key), count)).map(this::readHashKey);
return template.doCreateFlux(connection -> connection //
.hashCommands().hRandField(rawKey(key), count)).map(this::readRequiredHashKey);
}

@Override
Expand All @@ -159,8 +163,8 @@ public Flux<HK> keys(H key) {

Assert.notNull(key, "Key must not be null");

return createFlux(hashCommands -> hashCommands.hKeys(rawKey(key)) //
.map(this::readHashKey));
return createFlux(connection -> connection.hKeys(rawKey(key)) //
.map(this::readRequiredHashKey));
}

@Override
Expand Down Expand Up @@ -207,8 +211,8 @@ public Flux<HV> values(H key) {

Assert.notNull(key, "Key must not be null");

return createFlux(hashCommands -> hashCommands.hVals(rawKey(key)) //
.map(this::readHashValue));
return createFlux(connection -> connection.hVals(rawKey(key)) //
.map(this::readRequiredHashValue));
}

@Override
Expand Down Expand Up @@ -265,13 +269,37 @@ private ByteBuffer rawHashValue(HV key) {
}

@SuppressWarnings("unchecked")
@Nullable
private HK readHashKey(ByteBuffer value) {
return (HK) serializationContext.getHashKeySerializationPair().read(value);
}

private HK readRequiredHashKey(ByteBuffer buffer) {

HK hashKey = readHashKey(buffer);

if (hashKey == null) {
throw new InvalidDataAccessApiUsageException("Deserialized hash key is null");
}

return hashKey;
}

@SuppressWarnings("unchecked")
private HV readHashValue(ByteBuffer value) {
return (HV) (value == null ? value : serializationContext.getHashValueSerializationPair().read(value));
@Nullable
private HV readHashValue(@Nullable ByteBuffer value) {
return (HV) (value == null ? null : serializationContext.getHashValueSerializationPair().read(value));
}

private HV readRequiredHashValue(ByteBuffer buffer) {

HV hashValue = readHashValue(buffer);

if (hashValue == null) {
throw new InvalidDataAccessApiUsageException("Deserialized hash value is null");
}

return hashValue;
}

private Map.Entry<HK, HV> deserializeHashEntry(Map.Entry<ByteBuffer, ByteBuffer> source) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
import java.util.function.Function;

import org.reactivestreams.Publisher;
import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.data.redis.connection.ReactiveListCommands;
import org.springframework.data.redis.connection.ReactiveListCommands.Direction;
import org.springframework.data.redis.connection.ReactiveListCommands.LPosCommand;
import org.springframework.data.redis.connection.RedisListCommands.Position;
import org.springframework.data.redis.serializer.RedisSerializationContext;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
Expand Down Expand Up @@ -59,7 +61,7 @@ public Flux<V> range(K key, long start, long end) {

Assert.notNull(key, "Key must not be null");

return createFlux(listCommands -> listCommands.lRange(rawKey(key), start, end).map(this::readValue));
return createFlux(connection -> connection.lRange(rawKey(key), start, end).map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -172,8 +174,8 @@ public Mono<V> move(K sourceKey, Direction from, K destinationKey, Direction to)
Assert.notNull(from, "From direction must not be null");
Assert.notNull(to, "To direction must not be null");

return createMono(listCommands ->
listCommands.lMove(rawKey(sourceKey), rawKey(destinationKey), from, to).map(this::readValue));
return createMono(connection -> connection.lMove(rawKey(sourceKey), rawKey(destinationKey), from, to)
.map(this::readRequiredValue));
}

@Override
Expand All @@ -185,8 +187,8 @@ public Mono<V> move(K sourceKey, Direction from, K destinationKey, Direction to,
Assert.notNull(to, "To direction must not be null");
Assert.notNull(timeout, "Timeout must not be null");

return createMono(listCommands ->
listCommands.bLMove(rawKey(sourceKey), rawKey(destinationKey), from, to, timeout).map(this::readValue));
return createMono(connection -> connection.bLMove(rawKey(sourceKey), rawKey(destinationKey), from, to, timeout)
.map(this::readRequiredValue));
}

@Override
Expand All @@ -211,7 +213,7 @@ public Mono<V> index(K key, long index) {

Assert.notNull(key, "Key must not be null");

return createMono(listCommands -> listCommands.lIndex(rawKey(key), index).map(this::readValue));
return createMono(connection -> connection.lIndex(rawKey(key), index).map(this::readRequiredValue));
}

@Override
Expand All @@ -236,7 +238,7 @@ public Mono<V> leftPop(K key) {

Assert.notNull(key, "Key must not be null");

return createMono(listCommands -> listCommands.lPop(rawKey(key)).map(this::readValue));
return createMono(connection -> connection.lPop(rawKey(key)).map(this::readRequiredValue));

}

Expand All @@ -245,7 +247,7 @@ public Flux<V> leftPop(K key, long count) {

Assert.notNull(key, "Key must not be null");

return createFlux(listCommands -> listCommands.lPop(rawKey(key), count).map(this::readValue));
return createFlux(listCommands -> listCommands.lPop(rawKey(key), count).map(this::readRequiredValue));
}

@Override
Expand All @@ -255,25 +257,24 @@ public Mono<V> leftPop(K key, Duration timeout) {
Assert.notNull(timeout, "Duration must not be null");
Assert.isTrue(isZeroOrGreaterOneSecond(timeout), "Duration must be either zero or greater or equal to 1 second");

return createMono(listCommands ->
listCommands.blPop(Collections.singletonList(rawKey(key)), timeout)
.map(popResult -> readValue(popResult.getValue())));
return createMono(connection -> connection.blPop(Collections.singletonList(rawKey(key)), timeout)
.mapNotNull(popResult -> readValue(popResult.getValue())));
}

@Override
public Mono<V> rightPop(K key) {

Assert.notNull(key, "Key must not be null");

return createMono(listCommands -> listCommands.rPop(rawKey(key)).map(this::readValue));
return createMono(listCommands -> listCommands.rPop(rawKey(key)).map(this::readRequiredValue));
}

@Override
public Flux<V> rightPop(K key, long count) {

Assert.notNull(key, "Key must not be null");

return createFlux(listCommands -> listCommands.rPop(rawKey(key), count).map(this::readValue));
return createFlux(listCommands -> listCommands.rPop(rawKey(key), count).map(this::readRequiredValue));
}

@Override
Expand All @@ -283,9 +284,8 @@ public Mono<V> rightPop(K key, Duration timeout) {
Assert.notNull(timeout, "Duration must not be null");
Assert.isTrue(isZeroOrGreaterOneSecond(timeout), "Duration must be either zero or greater or equal to 1 second");

return createMono(listCommands ->
listCommands.brPop(Collections.singletonList(rawKey(key)), timeout)
.map(popResult -> readValue(popResult.getValue())));
return createMono(connection -> connection.brPop(Collections.singletonList(rawKey(key)), timeout)
.mapNotNull(popResult -> readValue(popResult.getValue())));
}

@Override
Expand All @@ -294,8 +294,8 @@ public Mono<V> rightPopAndLeftPush(K sourceKey, K destinationKey) {
Assert.notNull(sourceKey, "Source key must not be null");
Assert.notNull(destinationKey, "Destination key must not be null");

return createMono(listCommands ->
listCommands.rPopLPush(rawKey(sourceKey), rawKey(destinationKey)).map(this::readValue));
return createMono(connection -> connection.rPopLPush(rawKey(sourceKey), rawKey(destinationKey))
.map(this::readRequiredValue));
}

@Override
Expand All @@ -306,8 +306,8 @@ public Mono<V> rightPopAndLeftPush(K sourceKey, K destinationKey, Duration timeo
Assert.notNull(timeout, "Duration must not be null");
Assert.isTrue(isZeroOrGreaterOneSecond(timeout), "Duration must be either zero or greater or equal to 1 second");

return createMono(listCommands ->
listCommands.bRPopLPush(rawKey(sourceKey), rawKey(destinationKey), timeout).map(this::readValue));
return createMono(connection -> connection.bRPopLPush(rawKey(sourceKey), rawKey(destinationKey), timeout)
.map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -344,7 +344,19 @@ private ByteBuffer rawValue(V value) {
return serializationContext.getValueSerializationPair().write(value);
}

@Nullable
private V readValue(ByteBuffer buffer) {
return serializationContext.getValueSerializationPair().read(buffer);
}

private V readRequiredValue(ByteBuffer buffer) {

V v = readValue(buffer);

if (v == null) {
throw new InvalidDataAccessApiUsageException("Deserialized list value is null");
}

return v;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
import java.util.function.Function;

import org.reactivestreams.Publisher;
import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.data.redis.connection.ReactiveSetCommands;
import org.springframework.data.redis.serializer.RedisSerializationContext;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
Expand Down Expand Up @@ -89,15 +91,15 @@ public Mono<V> pop(K key) {

Assert.notNull(key, "Key must not be null");

return createMono(setCommands -> setCommands.sPop(rawKey(key)).map(this::readValue));
return createMono(setCommands -> setCommands.sPop(rawKey(key)).map(this::readRequiredValue));
}

@Override
public Flux<V> pop(K key, long count) {

Assert.notNull(key, "Key must not be null");

return createFlux(setCommands -> setCommands.sPop(rawKey(key), count).map(this::readValue));
return createFlux(setCommands -> setCommands.sPop(rawKey(key), count).map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -175,7 +177,7 @@ public Flux<V> intersect(Collection<K> keys) {
.map(this::rawKey) //
.collectList() //
.flatMapMany(setCommands::sInter) //
.map(this::readValue));
.map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -237,7 +239,7 @@ public Flux<V> union(Collection<K> keys) {
.map(this::rawKey) //
.collectList() //
.flatMapMany(setCommands::sUnion) //
.map(this::readValue));
.map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -299,7 +301,7 @@ public Flux<V> difference(Collection<K> keys) {
.map(this::rawKey) //
.collectList() //
.flatMapMany(setCommands::sDiff) //
.map(this::readValue));
.map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -339,7 +341,7 @@ public Flux<V> members(K key) {

Assert.notNull(key, "Key must not be null");

return createFlux(setCommands -> setCommands.sMembers(rawKey(key)).map(this::readValue));
return createFlux(setCommands -> setCommands.sMembers(rawKey(key)).map(this::readRequiredValue));
}

@Override
Expand All @@ -348,31 +350,31 @@ public Flux<V> scan(K key, ScanOptions options) {
Assert.notNull(key, "Key must not be null");
Assert.notNull(options, "ScanOptions must not be null");

return createFlux(setCommands -> setCommands.sScan(rawKey(key), options).map(this::readValue));
return createFlux(setCommands -> setCommands.sScan(rawKey(key), options).map(this::readRequiredValue));
}

@Override
public Mono<V> randomMember(K key) {

Assert.notNull(key, "Key must not be null");

return createMono(setCommands -> setCommands.sRandMember(rawKey(key)).map(this::readValue));
return createMono(setCommands -> setCommands.sRandMember(rawKey(key)).map(this::readRequiredValue));
}

@Override
public Flux<V> distinctRandomMembers(K key, long count) {

Assert.isTrue(count > 0, "Negative count not supported; Use randomMembers to allow duplicate elements");

return createFlux(setCommands -> setCommands.sRandMember(rawKey(key), count).map(this::readValue));
return createFlux(setCommands -> setCommands.sRandMember(rawKey(key), count).map(this::readRequiredValue));
}

@Override
public Flux<V> randomMembers(K key, long count) {

Assert.isTrue(count > 0, "Use a positive number for count; This method is already allowing duplicate elements");

return createFlux(setCommands -> setCommands.sRandMember(rawKey(key), -count).map(this::readValue));
return createFlux(setCommands -> setCommands.sRandMember(rawKey(key), -count).map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -415,7 +417,19 @@ private ByteBuffer rawValue(V value) {
return serializationContext.getValueSerializationPair().write(value);
}

@Nullable
private V readValue(ByteBuffer buffer) {
return serializationContext.getValueSerializationPair().read(buffer);
}

private V readRequiredValue(ByteBuffer buffer) {

V v = readValue(buffer);

if (v == null) {
throw new InvalidDataAccessApiUsageException("Deserialized set value is null");
}

return v;
}
}
Loading

0 comments on commit b5f124c

Please sign in to comment.