Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Terminate stream with error on null values returned by RedisElementReader for top-level elements #2672

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<groupId>org.springframework.data</groupId>
<artifactId>spring-data-redis</artifactId>
<version>3.0.9-SNAPSHOT</version>
<version>3.0.x-GH-2655-SNAPSHOT</version>

<name>Spring Data Redis</name>
<description>Spring Data module for Redis</description>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ public static Object parse(Object source, String sourcePath, Map<String, Class<?
* @return
* @since 2.6
*/
public static <K, V> Map.Entry<K, V> entryOf(K key, V value) {
public static <K, V> Map.Entry<K, V> entryOf(@Nullable K key, @Nullable V value) {
return new AbstractMap.SimpleImmutableEntry<>(key, value);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import java.util.stream.Collectors;

import org.reactivestreams.Publisher;

import org.springframework.data.geo.Circle;
import org.springframework.data.geo.Distance;
import org.springframework.data.geo.GeoResult;
Expand All @@ -40,6 +39,7 @@
import org.springframework.data.redis.domain.geo.GeoReference.GeoMemberReference;
import org.springframework.data.redis.domain.geo.GeoShape;
import org.springframework.data.redis.serializer.RedisSerializationContext;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me the geoDist(..) methods in ReactiveGeoOperations for the GEODIST command variations should be protected in the case of returning null.

While the arguments passed to the ReactiveGeoOperations.distance(..) methods used in the computation of the geo distance calculation may not be null, the member argument may also not have been geocoded in Redis with the geoAdd(..) command properly (which I think is a prerequisite based on the docs).

Therefore, in the case of a "missing" member (or element) the GEODIST command returns null as documented.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, perhaps, ReactiveGeoOperations needs further refinements down the road, particularly with guards around possible null values.

/**
Expand Down Expand Up @@ -321,13 +321,14 @@ private ByteBuffer rawValue(V value) {
return serializationContext.getValueSerializationPair().write(value);
}

@Nullable
private V readValue(ByteBuffer buffer) {
return serializationContext.getValueSerializationPair().read(buffer);
}

private GeoResult<GeoLocation<V>> readGeoResult(GeoResult<GeoLocation<ByteBuffer>> source) {

return new GeoResult<>(new GeoLocation(readValue(source.getContent().getName()), source.getContent().getPoint()),
return new GeoResult<>(new GeoLocation<>(readValue(source.getContent().getName()), source.getContent().getPoint()),
source.getDistance());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
import java.util.function.Function;

import org.reactivestreams.Publisher;

import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.data.redis.connection.ReactiveHashCommands;
import org.springframework.data.redis.connection.convert.Converters;
import org.springframework.data.redis.serializer.RedisSerializationContext;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
Expand Down Expand Up @@ -127,7 +128,7 @@ public Mono<HK> randomKey(H key) {
Assert.notNull(key, "Key must not be null");

return template.doCreateMono(connection -> connection //
.hashCommands().hRandField(rawKey(key))).map(this::readHashKey);
.hashCommands().hRandField(rawKey(key))).map(this::readRequiredHashKey);
}

@Override
Expand All @@ -145,7 +146,7 @@ public Flux<HK> randomKeys(H key, long count) {
Assert.notNull(key, "Key must not be null");

return template.doCreateFlux(connection -> connection //
.hashCommands().hRandField(rawKey(key), count)).map(this::readHashKey);
.hashCommands().hRandField(rawKey(key), count)).map(this::readRequiredHashKey);
}

@Override
Expand All @@ -163,7 +164,7 @@ public Flux<HK> keys(H key) {
Assert.notNull(key, "Key must not be null");

return createFlux(connection -> connection.hKeys(rawKey(key)) //
.map(this::readHashKey));
.map(this::readRequiredHashKey));
}

@Override
Expand Down Expand Up @@ -211,7 +212,7 @@ public Flux<HV> values(H key) {
Assert.notNull(key, "Key must not be null");

return createFlux(connection -> connection.hVals(rawKey(key)) //
.map(this::readHashValue));
.map(this::readRequiredHashValue));
}

@Override
Expand Down Expand Up @@ -268,13 +269,37 @@ private ByteBuffer rawHashValue(HV key) {
}

@SuppressWarnings("unchecked")
@Nullable
private HK readHashKey(ByteBuffer value) {
return (HK) serializationContext.getHashKeySerializationPair().read(value);
}

private HK readRequiredHashKey(ByteBuffer buffer) {

HK hashKey = readHashKey(buffer);

if (hashKey == null) {
throw new InvalidDataAccessApiUsageException("Deserialized hash key is null");
}

return hashKey;
}

@SuppressWarnings("unchecked")
private HV readHashValue(ByteBuffer value) {
return (HV) (value == null ? value : serializationContext.getHashValueSerializationPair().read(value));
@Nullable
private HV readHashValue(@Nullable ByteBuffer value) {
return (HV) (value == null ? null : serializationContext.getHashValueSerializationPair().read(value));
}

private HV readRequiredHashValue(ByteBuffer buffer) {

HV hashValue = readHashValue(buffer);

if (hashValue == null) {
throw new InvalidDataAccessApiUsageException("Deserialized hash value is null");
}

return hashValue;
}

private Map.Entry<HK, HV> deserializeHashEntry(Map.Entry<ByteBuffer, ByteBuffer> source) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
import java.util.function.Function;

import org.reactivestreams.Publisher;
import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.data.redis.connection.ReactiveListCommands;
import org.springframework.data.redis.connection.ReactiveListCommands.Direction;
import org.springframework.data.redis.connection.ReactiveListCommands.LPosCommand;
import org.springframework.data.redis.connection.RedisListCommands.Position;
import org.springframework.data.redis.serializer.RedisSerializationContext;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
Expand All @@ -58,7 +60,7 @@ public Flux<V> range(K key, long start, long end) {

Assert.notNull(key, "Key must not be null");

return createFlux(connection -> connection.lRange(rawKey(key), start, end).map(this::readValue));
return createFlux(connection -> connection.lRange(rawKey(key), start, end).map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -170,7 +172,8 @@ public Mono<V> move(K sourceKey, Direction from, K destinationKey, Direction to)
Assert.notNull(to, "To direction must not be null");

return createMono(
connection -> connection.lMove(rawKey(sourceKey), rawKey(destinationKey), from, to).map(this::readValue));
connection -> connection.lMove(rawKey(sourceKey), rawKey(destinationKey), from, to)
.map(this::readRequiredValue));
}

@Override
Expand All @@ -183,7 +186,7 @@ public Mono<V> move(K sourceKey, Direction from, K destinationKey, Direction to,
Assert.notNull(timeout, "Timeout must not be null");

return createMono(connection -> connection.bLMove(rawKey(sourceKey), rawKey(destinationKey), from, to, timeout)
.map(this::readValue));
.map(this::readRequiredValue));
}

@Override
Expand All @@ -208,7 +211,7 @@ public Mono<V> index(K key, long index) {

Assert.notNull(key, "Key must not be null");

return createMono(connection -> connection.lIndex(rawKey(key), index).map(this::readValue));
return createMono(connection -> connection.lIndex(rawKey(key), index).map(this::readRequiredValue));
}

@Override
Expand All @@ -232,7 +235,7 @@ public Mono<V> leftPop(K key) {

Assert.notNull(key, "Key must not be null");

return createMono(connection -> connection.lPop(rawKey(key)).map(this::readValue));
return createMono(connection -> connection.lPop(rawKey(key)).map(this::readRequiredValue));

}

Expand All @@ -244,15 +247,15 @@ public Mono<V> leftPop(K key, Duration timeout) {
Assert.isTrue(isZeroOrGreater1Second(timeout), "Duration must be either zero or greater or equal to 1 second");

return createMono(connection -> connection.blPop(Collections.singletonList(rawKey(key)), timeout)
.map(popResult -> readValue(popResult.getValue())));
.mapNotNull(popResult -> readValue(popResult.getValue())));
}

@Override
public Mono<V> rightPop(K key) {

Assert.notNull(key, "Key must not be null");

return createMono(connection -> connection.rPop(rawKey(key)).map(this::readValue));
return createMono(connection -> connection.rPop(rawKey(key)).map(this::readRequiredValue));
}

@Override
Expand All @@ -263,7 +266,7 @@ public Mono<V> rightPop(K key, Duration timeout) {
Assert.isTrue(isZeroOrGreater1Second(timeout), "Duration must be either zero or greater or equal to 1 second");

return createMono(connection -> connection.brPop(Collections.singletonList(rawKey(key)), timeout)
.map(popResult -> readValue(popResult.getValue())));
.mapNotNull(popResult -> readValue(popResult.getValue())));
}

@Override
Expand All @@ -273,7 +276,7 @@ public Mono<V> rightPopAndLeftPush(K sourceKey, K destinationKey) {
Assert.notNull(destinationKey, "Destination key must not be null");

return createMono(
connection -> connection.rPopLPush(rawKey(sourceKey), rawKey(destinationKey)).map(this::readValue));
connection -> connection.rPopLPush(rawKey(sourceKey), rawKey(destinationKey)).map(this::readRequiredValue));
}

@Override
Expand All @@ -285,7 +288,8 @@ public Mono<V> rightPopAndLeftPush(K sourceKey, K destinationKey, Duration timeo
Assert.isTrue(isZeroOrGreater1Second(timeout), "Duration must be either zero or greater or equal to 1 second");

return createMono(
connection -> connection.bRPopLPush(rawKey(sourceKey), rawKey(destinationKey), timeout).map(this::readValue));
connection -> connection.bRPopLPush(rawKey(sourceKey), rawKey(destinationKey), timeout)
.map(this::readRequiredValue));
Copy link
Contributor

@jxblum jxblum Sep 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be .mapNotNull(..) in this case given the Duration parameter?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous variant was mapNotNull that effectively filters null values turning RedisElementReader into something that has an ability of filtering. That's why we chose to use readRequiredValue throwing a proper exception.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This previous variant was .map(this::readValue), line 288.

I was thinking more along the lines of the rightPop(key, :Duration) overloaded method above on line 262.

But, perhaps the choice of simply .map(this:readRequiredValue) vs. .mapNotNull(this:readRequiredValue) is due to the right pop followed by left push compound operation??

}

@Override
Expand Down Expand Up @@ -322,7 +326,19 @@ private ByteBuffer rawValue(V value) {
return serializationContext.getValueSerializationPair().write(value);
}

@Nullable
private V readValue(ByteBuffer buffer) {
return serializationContext.getValueSerializationPair().read(buffer);
}

private V readRequiredValue(ByteBuffer buffer) {

V v = readValue(buffer);

if (v == null) {
throw new InvalidDataAccessApiUsageException("Deserialized list value is null");
}

return v;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
import java.util.function.Function;

import org.reactivestreams.Publisher;
import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.data.redis.connection.ReactiveSetCommands;
import org.springframework.data.redis.serializer.RedisSerializationContext;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
Expand Down Expand Up @@ -88,15 +90,15 @@ public Mono<V> pop(K key) {

Assert.notNull(key, "Key must not be null");

return createMono(connection -> connection.sPop(rawKey(key)).map(this::readValue));
return createMono(connection -> connection.sPop(rawKey(key)).map(this::readRequiredValue));
}

@Override
public Flux<V> pop(K key, long count) {

Assert.notNull(key, "Key must not be null");

return createFlux(connection -> connection.sPop(rawKey(key), count).map(this::readValue));
return createFlux(connection -> connection.sPop(rawKey(key), count).map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -176,7 +178,7 @@ public Flux<V> intersect(Collection<K> keys) {
.map(this::rawKey) //
.collectList() //
.flatMapMany(connection::sInter) //
.map(this::readValue));
.map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -238,7 +240,7 @@ public Flux<V> union(Collection<K> keys) {
.map(this::rawKey) //
.collectList() //
.flatMapMany(connection::sUnion) //
.map(this::readValue));
.map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -300,7 +302,7 @@ public Flux<V> difference(Collection<K> keys) {
.map(this::rawKey) //
.collectList() //
.flatMapMany(connection::sDiff) //
.map(this::readValue));
.map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -340,7 +342,7 @@ public Flux<V> members(K key) {

Assert.notNull(key, "Key must not be null");

return createFlux(connection -> connection.sMembers(rawKey(key)).map(this::readValue));
return createFlux(connection -> connection.sMembers(rawKey(key)).map(this::readRequiredValue));
}

@Override
Expand All @@ -349,31 +351,31 @@ public Flux<V> scan(K key, ScanOptions options) {
Assert.notNull(key, "Key must not be null");
Assert.notNull(options, "ScanOptions must not be null");

return createFlux(connection -> connection.sScan(rawKey(key), options).map(this::readValue));
return createFlux(connection -> connection.sScan(rawKey(key), options).map(this::readRequiredValue));
}

@Override
public Mono<V> randomMember(K key) {

Assert.notNull(key, "Key must not be null");

return createMono(connection -> connection.sRandMember(rawKey(key)).map(this::readValue));
return createMono(connection -> connection.sRandMember(rawKey(key)).map(this::readRequiredValue));
}

@Override
public Flux<V> distinctRandomMembers(K key, long count) {

Assert.isTrue(count > 0, "Negative count not supported; Use randomMembers to allow duplicate elements");

return createFlux(connection -> connection.sRandMember(rawKey(key), count).map(this::readValue));
return createFlux(connection -> connection.sRandMember(rawKey(key), count).map(this::readRequiredValue));
}

@Override
public Flux<V> randomMembers(K key, long count) {

Assert.isTrue(count > 0, "Use a positive number for count; This method is already allowing duplicate elements");

return createFlux(connection -> connection.sRandMember(rawKey(key), -count).map(this::readValue));
return createFlux(connection -> connection.sRandMember(rawKey(key), -count).map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -416,7 +418,19 @@ private ByteBuffer rawValue(V value) {
return serializationContext.getValueSerializationPair().write(value);
}

@Nullable
private V readValue(ByteBuffer buffer) {
return serializationContext.getValueSerializationPair().read(buffer);
}

private V readRequiredValue(ByteBuffer buffer) {

V v = readValue(buffer);

if (v == null) {
throw new InvalidDataAccessApiUsageException("Deserialized set value is null");
}

return v;
}
}
Loading