# Review notes (commentary placed ahead of the first file header; patch tools skip it).
#
# What this patch does:
#   * ReadFn: adds getJoinerClause(), which returns " AND " when the query text already
#     contains WHERE and " WHERE " otherwise, and uses it wherever a token-range clause is
#     appended to a user-supplied query.
#   * CassandraIOTest: adds MockQueryProvider plus two tests that pass the query through a
#     ValueProvider, one without and one with an existing WHERE clause.
#
# Changes made in this revision of the patch:
#   * getHighestSplitQuery / getLowestSplitQuery now concatenate spec.query().get() rather
#     than the ValueProvider object itself, so the generated CQL no longer depends on the
#     provider's toString().
#   * MockQueryProvider.query is final instead of volatile (it is never reassigned).
#   * Line breaks, indentation and generic type arguments were reconstructed from a
#     whitespace-collapsed copy. TODO confirm context-line whitespace against the target
#     tree before applying.
#   * The post-image blob ids on the "index" lines predate the edits above and are stale.
#
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ReadFn.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ReadFn.java
index b8fff2297121..44e3808f53f8 100644
--- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ReadFn.java
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ReadFn.java
@@ -107,7 +107,7 @@ private static String getHighestSplitQuery(
     String finalHighQuery =
         (spec.query() == null)
             ? buildInitialQuery(spec, true) + highestClause
-            : spec.query() + " AND " + highestClause;
+            : spec.query().get() + getJoinerClause(spec.query().get()) + highestClause;
     LOG.debug("CassandraIO generated a wrapAround query : {}", finalHighQuery);
     return finalHighQuery;
   }
@@ -117,7 +117,7 @@ private static String getLowestSplitQuery(Read<?> spec, String partitionKey, Big
     String finalLowQuery =
         (spec.query() == null)
             ? buildInitialQuery(spec, true) + lowestClause
-            : spec.query() + " AND " + lowestClause;
+            : spec.query().get() + getJoinerClause(spec.query().get()) + lowestClause;
     LOG.debug("CassandraIO generated a wrapAround query : {}", finalLowQuery);
     return finalLowQuery;
   }
@@ -141,9 +141,10 @@ private static String buildInitialQuery(Read<?> spec, Boolean hasRingRange) {
     return (spec.query() == null)
         ? String.format("SELECT * FROM %s.%s", spec.keyspace().get(), spec.table().get())
             + " WHERE "
-        : spec.query().get()
-            + (hasRingRange
-                ? spec.query().get().toUpperCase().contains("WHERE") ? " AND " : " WHERE "
-                : "");
+        : spec.query().get() + (hasRingRange ? getJoinerClause(spec.query().get()) : "");
+  }
+
+  private static String getJoinerClause(String queryString) {
+    return queryString.toUpperCase().contains("WHERE") ? " AND " : " WHERE ";
   }
 }
diff --git a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java
index a7090d5c7bcc..747f803ea46b 100644
--- a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java
+++ b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java
@@ -61,6 +61,7 @@
 import javax.management.remote.JMXServiceURL;
 import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.io.common.NetworkTestHelper;
+import org.apache.beam.sdk.options.ValueProvider;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Count;
@@ -489,6 +490,94 @@ public void testReadWithQuery() throws Exception {
     pipeline.run();
   }
 
+  /**
+   * Mock {@link ValueProvider} that hands a fixed query string to CassandraIO, so the tests
+   * below exercise how CassandraIO.ReadFn expands a provider-supplied query.
+   */
+  static class MockQueryProvider implements ValueProvider<String> {
+    private final String query;
+
+    MockQueryProvider(String query) {
+      this.query = query;
+    }
+
+    @Override
+    public String get() {
+      return query;
+    }
+
+    @Override
+    public boolean isAccessible() {
+      return !query.isEmpty();
+    }
+  }
+
+  @Test
+  public void testReadWithQueryProvider() throws Exception {
+    String query =
+        String.format(
+            "select person_id, writetime(person_name) from %s.%s",
+            CASSANDRA_KEYSPACE, CASSANDRA_TABLE);
+
+    PCollection<Scientist> output =
+        pipeline.apply(
+            CassandraIO.<Scientist>read()
+                .withHosts(Collections.singletonList(CASSANDRA_HOST))
+                .withPort(cassandraPort)
+                .withKeyspace(CASSANDRA_KEYSPACE)
+                .withTable(CASSANDRA_TABLE)
+                .withMinNumberOfSplits(20)
+                .withQuery(new MockQueryProvider(query))
+                .withCoder(SerializableCoder.of(Scientist.class))
+                .withEntity(Scientist.class));
+
+    PAssert.thatSingleton(output.apply("Count", Count.globally())).isEqualTo(NUM_ROWS);
+    PAssert.that(output)
+        .satisfies(
+            input -> {
+              for (Scientist sci : input) {
+                assertNull(sci.name);
+                assertTrue(sci.nameTs != null && sci.nameTs > 0);
+              }
+              return null;
+            });
+
+    pipeline.run();
+  }
+
+  @Test
+  public void testReadWithQueryProviderWithWhereQuery() throws Exception {
+    String query =
+        String.format(
+            "select person_id, writetime(person_name) from %s.%s where person_id=10 AND person_department='logic'",
+            CASSANDRA_KEYSPACE, CASSANDRA_TABLE);
+
+    PCollection<Scientist> output =
+        pipeline.apply(
+            CassandraIO.<Scientist>read()
+                .withHosts(Collections.singletonList(CASSANDRA_HOST))
+                .withPort(cassandraPort)
+                .withKeyspace(CASSANDRA_KEYSPACE)
+                .withTable(CASSANDRA_TABLE)
+                .withMinNumberOfSplits(20)
+                .withQuery(new MockQueryProvider(query))
+                .withCoder(SerializableCoder.of(Scientist.class))
+                .withEntity(Scientist.class));
+
+    PAssert.thatSingleton(output.apply("Count", Count.globally())).isEqualTo(1L);
+    PAssert.that(output)
+        .satisfies(
+            input -> {
+              for (Scientist sci : input) {
+                assertNull(sci.name);
+                assertTrue(sci.nameTs != null && sci.nameTs > 0);
+              }
+              return null;
+            });
+
+    pipeline.run();
+  }
+
   @Test
   public void testReadWithUnfilteredQuery() throws Exception {
     String query =