Skip to content

Commit

Permalink
feature: handle RediSearch Dialect 3 in repositories (add @UseDialect…
Browse files Browse the repository at this point in the history
… annotation for repo methods) - (resolves gh-476)
  • Loading branch information
bsbodden committed Jul 9, 2024
1 parent 77c3599 commit cfc8a00
Show file tree
Hide file tree
Showing 10 changed files with 271 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.redis.om.spring.annotations;

public enum Dialect {
ONE(1),
TWO(2),
THREE(3);

private final int value;
Dialect(int value) {
this.value = value;
}

public int getValue() {
return value;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.redis.om.spring.annotations;



import java.lang.annotation.*;

@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ ElementType.METHOD, ElementType.ANNOTATION_TYPE })
public @interface UseDialect {
Dialect dialect() default Dialect.ONE;
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.github.f4b6a3.ulid.Ulid;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonArray;
import com.redis.om.spring.RedisOMProperties;
import com.redis.om.spring.annotations.*;
import com.redis.om.spring.indexing.RediSearchIndexer;
Expand Down Expand Up @@ -85,6 +86,7 @@ public class RediSearchQuery implements RepositoryQuery {
private Boolean aggregationVerbatim;
private Gson gson;
private boolean isNullParamQuery;
private Dialect dialect = Dialect.ONE;

@SuppressWarnings("unchecked")
public RediSearchQuery(//
Expand Down Expand Up @@ -123,6 +125,13 @@ public RediSearchQuery(//

try {
java.lang.reflect.Method method = repoClass.getMethod(queryMethod.getName(), params);

// set dialect if @UseDialect is present
if (method.isAnnotationPresent(UseDialect.class)) {
UseDialect dialectAnnotation = method.getAnnotation(UseDialect.class);
this.dialect = dialectAnnotation.dialect();
}

if (method.isAnnotationPresent(com.redis.om.spring.annotations.Query.class)) {
com.redis.om.spring.annotations.Query queryAnnotation = method.getAnnotation(
com.redis.om.spring.annotations.Query.class);
Expand Down Expand Up @@ -475,10 +484,14 @@ private Object executeQuery(Object[] parameters) {
}
}

// Set query dialect
query.dialect(dialect.getValue());

SearchResult searchResult = ops.search(query);

// what to return
Object result = null;

if (queryMethod.getReturnedObjectType() == SearchResult.class) {
result = searchResult;
} else if (queryMethod.isPageQuery()) {
Expand Down Expand Up @@ -509,7 +522,11 @@ private Object parseDocumentResult(redis.clients.jedis.search.Document doc) {

Gson gsonInstance = getGson();

return gsonInstance.fromJson(SafeEncoder.encode((byte[]) doc.get("$")), domainType);
return switch (dialect) {
case ONE, TWO -> gsonInstance.fromJson(SafeEncoder.encode((byte[]) doc.get("$")), domainType);
case THREE -> gsonInstance.fromJson(
gsonInstance.fromJson(SafeEncoder.encode((byte[]) doc.get("$")), JsonArray.class).get(0), domainType);
};
}

private Object executeDeleteQuery(Object[] parameters) {
Expand Down Expand Up @@ -538,6 +555,9 @@ private Object executeDeleteQuery(Object[] parameters) {
aggregation.sortBy(aggregationSortedFields.toArray(new SortedField[] {}));
aggregation.limit(0, redisOMProperties.getRepository().getQuery().getLimit());

// Set query dialect
aggregation.dialect(dialect.getValue());

// Execute the aggregation query
AggregationResult aggregationResult = ops.aggregate(aggregation);

Expand Down Expand Up @@ -655,6 +675,9 @@ private Object executeAggregation(Object[] parameters) {
}
}

// Set query dialect
aggregation.dialect(dialect.getValue());

// execute the aggregation
AggregationResult aggregationResult = ops.aggregate(aggregation);

Expand Down Expand Up @@ -836,6 +859,9 @@ private Object executeNullQuery(Object[] parameters) {
}
}

// Set query dialect
aggregation.dialect(dialect.getValue());

// Execute the aggregation query
AggregationResult aggregationResult = ops.aggregate(aggregation);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ public class RedisEnhancedQuery implements RepositoryQuery {
private Long aggregationTimeout;
private Boolean aggregationVerbatim;
private boolean isNullParamQuery;
private Dialect dialect = Dialect.ONE;

@SuppressWarnings("unchecked")
public RedisEnhancedQuery(QueryMethod queryMethod, //
Expand Down Expand Up @@ -126,6 +127,13 @@ public RedisEnhancedQuery(QueryMethod queryMethod, //

try {
java.lang.reflect.Method method = repoClass.getDeclaredMethod(queryMethod.getName(), params);

// set dialect if @UseDialect is present
if (method.isAnnotationPresent(UseDialect.class)) {
UseDialect dialectAnnotation = method.getAnnotation(UseDialect.class);
this.dialect = dialectAnnotation.dialect();
}

if (method.isAnnotationPresent(com.redis.om.spring.annotations.Query.class)) {
com.redis.om.spring.annotations.Query queryAnnotation = method.getAnnotation(
com.redis.om.spring.annotations.Query.class);
Expand Down Expand Up @@ -475,8 +483,12 @@ private Object executeQuery(Object[] parameters) {
}
}

// Set query dialect
query.dialect(dialect.getValue());

SearchResult searchResult = ops.search(query);

// what to return
Object result;

if (queryMethod.getReturnedObjectType() == SearchResult.class) {
Expand Down Expand Up @@ -536,6 +548,9 @@ private Object executeDeleteQuery(Object[] parameters) {
aggregation.sortBy(aggregationSortedFields.toArray(new SortedField[] {}));
aggregation.limit(0, redisOMProperties.getRepository().getQuery().getLimit());

// Set query dialect
aggregation.dialect(dialect.getValue());

// Execute the aggregation query
AggregationResult aggregationResult = ops.aggregate(aggregation);

Expand Down Expand Up @@ -669,6 +684,9 @@ private Object executeAggregation(Object[] parameters) {
}
}

// Set query dialect
aggregation.dialect(dialect.getValue());

// execute the aggregation
AggregationResult aggregationResult = ops.aggregate(aggregation);

Expand Down Expand Up @@ -844,6 +862,9 @@ private Object executeNullQuery(Object[] parameters) {
}
}

// Set query dialect
aggregation.dialect(dialect.getValue());

// Execute the aggregation query
AggregationResult aggregationResult = ops.aggregate(aggregation);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.redis.om.spring.annotations.document;

import com.redis.om.spring.AbstractBaseDocumentTest;
import com.redis.om.spring.fixtures.document.model.TestResultRedisModel;
import com.redis.om.spring.fixtures.document.repository.TestResultRedisRepository;
import com.redis.om.spring.search.stream.EntityStream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.UnifiedJedis;

import java.util.Objects;

import static org.assertj.core.api.Assertions.assertThat;

class DialectThreeTest extends AbstractBaseDocumentTest {

@Autowired
TestResultRedisRepository testResultRedisRepository;

@Autowired
JedisConnectionFactory jedisConnectionFactory;

@Autowired
EntityStream es;

private UnifiedJedis jedis;

@BeforeEach
void cleanUp() {
flushSearchIndexFor(TestResultRedisModel.class);

if (testResultRedisRepository.count() == 0) {
testResultRedisRepository.save(TestResultRedisModel.of(123L, "123-123-123-123-123", "9_TNR290INP\\-WEE2024011124\\.xml", "REJECTED"));
testResultRedisRepository.save(TestResultRedisModel.of(456L, "456-456-456-456-456", "8_TNR290INP\\-WEE2024011124\\.xml", "ACCEPTED"));
}

jedis = new JedisPooled(Objects.requireNonNull(jedisConnectionFactory.getPoolConfig()),
jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort());
}

@Test
void testDialect3() {
var results = testResultRedisRepository.findAllByFilenameIs("9_TNR290INP\\-WEE2024011124\\.xml");
assertThat(results).hasSize(1);
var result = results.iterator().next();
assertThat(result.getUuid()).isEqualTo("123-123-123-123-123");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package com.redis.om.spring.annotations.hash;

import com.redis.om.spring.AbstractBaseDocumentTest;
import com.redis.om.spring.AbstractBaseEnhancedRedisTest;
import com.redis.om.spring.fixtures.hash.model.TestResultRedisModel;
import com.redis.om.spring.fixtures.hash.repository.TestResultRedisRepository;
import com.redis.om.spring.search.stream.EntityStream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.UnifiedJedis;

import java.util.Objects;

import static org.assertj.core.api.Assertions.assertThat;

class DialectThreeTest extends AbstractBaseEnhancedRedisTest {

@Autowired
TestResultRedisRepository testResultRedisRepository;

@Autowired
JedisConnectionFactory jedisConnectionFactory;

@Autowired
EntityStream es;

private UnifiedJedis jedis;

@BeforeEach
void cleanUp() {
flushSearchIndexFor(TestResultRedisModel.class);

if (testResultRedisRepository.count() == 0) {
testResultRedisRepository.save(TestResultRedisModel.of(123L, "123-123-123-123-123", "9_TNR290INP\\-WEE2024011124\\.xml", "REJECTED"));
testResultRedisRepository.save(TestResultRedisModel.of(456L, "456-456-456-456-456", "8_TNR290INP\\-WEE2024011124\\.xml", "ACCEPTED"));
}

jedis = new JedisPooled(Objects.requireNonNull(jedisConnectionFactory.getPoolConfig()),
jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort());
}

@Test
void testDialect3() {
var results = testResultRedisRepository.findAllByFilenameIs("9_TNR290INP\\-WEE2024011124\\.xml");
assertThat(results).hasSize(1);
var result = results.iterator().next();
assertThat(result.getUuid()).isEqualTo("123-123-123-123-123");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package com.redis.om.spring.fixtures.document.model;

import com.redis.om.spring.annotations.Document;
import com.redis.om.spring.annotations.Indexed;
import com.redis.om.spring.annotations.Searchable;
import lombok.*;
import org.springframework.data.annotation.Id;

@Data
@RequiredArgsConstructor(staticName = "of")
@NoArgsConstructor(force = true)
@Document("TestResultRedis")
public class TestResultRedisModel {
@Id
@Indexed
@NonNull
private Long id;

@Indexed
@NonNull
String uuid;

@Searchable
@NonNull
String filename;

@Indexed
@NonNull
String status;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.redis.om.spring.fixtures.document.repository;

import com.redis.om.spring.annotations.UseDialect;
import com.redis.om.spring.annotations.Dialect;
import com.redis.om.spring.fixtures.document.model.TestResultRedisModel;
import com.redis.om.spring.repository.RedisDocumentRepository;

import java.util.List;

public interface TestResultRedisRepository extends RedisDocumentRepository<TestResultRedisModel, Long> {
@UseDialect(dialect = Dialect.THREE)
List<TestResultRedisModel> findAllByFilenameIs(String filename);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package com.redis.om.spring.fixtures.hash.model;

import com.redis.om.spring.annotations.Document;
import com.redis.om.spring.annotations.Indexed;
import com.redis.om.spring.annotations.Searchable;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import org.springframework.data.annotation.Id;
import org.springframework.data.redis.core.RedisHash;

@Data
@RequiredArgsConstructor(staticName = "of")
@NoArgsConstructor(force = true)
@RedisHash("TestResultRedis")
public class TestResultRedisModel {
@Id
@Indexed
@NonNull
private Long id;

@Indexed
@NonNull
String uuid;

@Searchable
@NonNull
String filename;

@Indexed
@NonNull
String status;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.redis.om.spring.fixtures.hash.repository;

import com.redis.om.spring.annotations.Dialect;
import com.redis.om.spring.annotations.UseDialect;

import com.redis.om.spring.fixtures.hash.model.TestResultRedisModel;
import com.redis.om.spring.repository.RedisDocumentRepository;
import com.redis.om.spring.repository.RedisEnhancedRepository;

import java.util.List;

public interface TestResultRedisRepository extends RedisEnhancedRepository<TestResultRedisModel, Long> {
@UseDialect(dialect = Dialect.THREE)
List<TestResultRedisModel> findAllByFilenameIs(String filename);
}

0 comments on commit cfc8a00

Please sign in to comment.