diff --git a/datashare-api/src/main/java/org/icij/datashare/PipelineHelper.java b/datashare-api/src/main/java/org/icij/datashare/PipelineHelper.java index 750cdf7f6..8168dd67c 100644 --- a/datashare-api/src/main/java/org/icij/datashare/PipelineHelper.java +++ b/datashare-api/src/main/java/org/icij/datashare/PipelineHelper.java @@ -2,7 +2,6 @@ import java.util.List; import java.util.Objects; -import java.util.stream.Collectors; import static java.util.Arrays.stream; import static java.util.stream.Collectors.joining; diff --git a/datashare-api/src/main/java/org/icij/datashare/Stage.java b/datashare-api/src/main/java/org/icij/datashare/Stage.java index ab85049d9..4518ff2c5 100644 --- a/datashare-api/src/main/java/org/icij/datashare/Stage.java +++ b/datashare-api/src/main/java/org/icij/datashare/Stage.java @@ -11,6 +11,8 @@ public enum Stage { INDEX(true), ENQUEUEIDX(false), NLP(true), + CREATENLPBATCHESFROMIDX(false), + BATCHNLP(false), ARTIFACT(false); public static final Comparator comparator = Comparator.comparing(Stage::ordinal); diff --git a/datashare-api/src/main/java/org/icij/datashare/text/indexing/Indexer.java b/datashare-api/src/main/java/org/icij/datashare/text/indexing/Indexer.java index e40d6697b..4846abc09 100644 --- a/datashare-api/src/main/java/org/icij/datashare/text/indexing/Indexer.java +++ b/datashare-api/src/main/java/org/icij/datashare/text/indexing/Indexer.java @@ -33,7 +33,7 @@ public interface Indexer extends Closeable { boolean bulkUpdate(String indexName, List entities) throws IOException; void add(String indexName, T obj) throws IOException; void update(String indexName, T obj) throws IOException; - boolean exists(String indexName, String id) throws IOException; + boolean exists(String indexName, String id) throws IOException; T get(String indexName, String id); T get(String indexName, String id, List sourceExcludes); @@ -61,9 +61,11 @@ interface Searcher { Searcher withoutSource(String... 
fields); Searcher withSource(boolean source); Searcher limit(int maxCount); + Searcher sort(String field, SortOrder order); void clearScroll() throws IOException; long totalHits(); Searcher with(int fuzziness, boolean phraseMatches); + enum SortOrder { ASC, DESC } } interface QueryBuilderSearcher extends Searcher { diff --git a/datashare-api/src/test/java/org/icij/datashare/PipelineHelperTest.java b/datashare-api/src/test/java/org/icij/datashare/PipelineHelperTest.java index b1666bad4..1445bb9a7 100644 --- a/datashare-api/src/test/java/org/icij/datashare/PipelineHelperTest.java +++ b/datashare-api/src/test/java/org/icij/datashare/PipelineHelperTest.java @@ -41,6 +41,38 @@ public void test_get_output_queue_name_for_last_pipeline_step() { assertThat(name).isEqualTo("extract:queue:nlp"); } + @Test + public void test_get_queue_names_for_batch_nlp_pipeline() { + PipelineHelper pipelineHelper = new PipelineHelper(new PropertiesProvider(new HashMap<>() {{ + put("stages", "SCAN,INDEX,BATCHNLP"); + }})); + assertThat(pipelineHelper.getQueueNameFor(Stage.BATCHNLP)).isEqualTo("extract:queue:batchnlp"); + } + + @Test + public void test_get_queue_names_for_batch_nlp_pipeline_from_index() { + PipelineHelper pipelineHelper = new PipelineHelper(new PropertiesProvider(new HashMap<>() {{ + put("stages", "CREATENLPBATCHESFROMIDX,BATCHNLP"); + }})); + assertThat(pipelineHelper.getQueueNameFor(Stage.BATCHNLP)).isEqualTo("extract:queue:batchnlp"); + } + + @Test + public void test_get_queue_names_for_batch_nlp() { + PipelineHelper pipelineHelper = new PipelineHelper(new PropertiesProvider(new HashMap<>() {{ + put("stages", "SCAN,INDEX,NLP"); + }})); + assertThat(pipelineHelper.getQueueNameFor(Stage.NLP)).isEqualTo("extract:queue:nlp"); + } + + @Test + public void test_get_queue_names_for_nlp_pipeline_from_index() { + PipelineHelper pipelineHelper = new PipelineHelper(new PropertiesProvider(new HashMap<>() {{ + put("stages", "ENQUEUEIDX,NLP"); + }})); + 
assertThat(pipelineHelper.getQueueNameFor(Stage.NLP)).isEqualTo("extract:queue:nlp"); + } + @Test public void test_get_queue_name_scan_index() { PipelineHelper pipelineHelper = new PipelineHelper(new PropertiesProvider(new HashMap<>() {{ diff --git a/datashare-app/src/main/java/org/icij/datashare/CliApp.java b/datashare-app/src/main/java/org/icij/datashare/CliApp.java index f1997241e..e311e206a 100644 --- a/datashare-app/src/main/java/org/icij/datashare/CliApp.java +++ b/datashare-app/src/main/java/org/icij/datashare/CliApp.java @@ -6,6 +6,8 @@ import org.icij.datashare.cli.spi.CliExtension; import org.icij.datashare.mode.CommonMode; import org.icij.datashare.tasks.ArtifactTask; +import org.icij.datashare.tasks.CreateNlpBatchesFromIndex; +import org.icij.datashare.tasks.BatchNlpTask; import org.icij.datashare.tasks.DatashareTaskFactory; import org.icij.datashare.tasks.DeduplicateTask; import org.icij.datashare.tasks.EnqueueFromIndexTask; @@ -13,6 +15,7 @@ import org.icij.datashare.tasks.IndexTask; import org.icij.datashare.tasks.ScanIndexTask; import org.icij.datashare.tasks.ScanTask; +import org.icij.datashare.tasks.DatashareTaskFactory; import org.icij.datashare.text.indexing.Indexer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -125,6 +128,15 @@ private static void runTaskWorker(CommonMode mode, Properties properties) throws new Task<>(EnqueueFromIndexTask.class.getName(), nullUser(), propertiesToMap(properties))); } + if (pipeline.has(Stage.CREATENLPBATCHESFROMIDX)) { + taskManager.startTask(new Task<>(CreateNlpBatchesFromIndex.class.getName(), nullUser(), propertiesToMap(properties))); + } + + if (pipeline.has(Stage.BATCHNLP)) { + taskManager.startTask( + new Task<>(BatchNlpTask.class.getName(), nullUser(), propertiesToMap(properties))); + } + if (pipeline.has(Stage.NLP)) { taskManager.startTask( new Task<>(ExtractNlpTask.class.getName(), nullUser(), propertiesToMap(properties))); diff --git 
a/datashare-app/src/main/java/org/icij/datashare/nlp/NlpHelper.java b/datashare-app/src/main/java/org/icij/datashare/nlp/NlpHelper.java new file mode 100644 index 000000000..bacf64976 --- /dev/null +++ b/datashare-app/src/main/java/org/icij/datashare/nlp/NlpHelper.java @@ -0,0 +1,13 @@ +package org.icij.datashare.nlp; + +import java.util.Map; +import org.icij.datashare.text.nlp.Pipeline; + +public class NlpHelper { + public static Map pipelineExtras(Pipeline.Type pipeline) { + if (pipeline == Pipeline.Type.SPACY) { + return Map.of("modelSize", "md"); + } + return Map.of(); + } +} diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/ArtifactTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/ArtifactTask.java index aaa8fe8aa..685902517 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/ArtifactTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/ArtifactTask.java @@ -7,7 +7,6 @@ import org.icij.datashare.asynctasks.Task; import org.icij.datashare.asynctasks.TaskGroup; import org.icij.datashare.extract.DocumentCollectionFactory; -import org.icij.datashare.function.Pair; import org.icij.datashare.text.Document; import org.icij.datashare.text.Project; import org.icij.datashare.text.indexing.Indexer; @@ -20,12 +19,12 @@ import java.util.concurrent.TimeUnit; import java.util.function.Function; -import static java.util.Optional.ofNullable; import static org.icij.datashare.cli.DatashareCliOptions.ARTIFACT_DIR_OPT; import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_DEFAULT_PROJECT; import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_PROJECT_OPT; +import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP; -@TaskGroup("Java") +@TaskGroup(JAVA_GROUP) public class ArtifactTask extends PipelineTask { private final Logger logger = LoggerFactory.getLogger(getClass()); private final Indexer indexer; diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/BatchDownloadCleaner.java 
b/datashare-app/src/main/java/org/icij/datashare/tasks/BatchDownloadCleaner.java index edd336f73..474cb21b8 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/BatchDownloadCleaner.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/BatchDownloadCleaner.java @@ -17,8 +17,9 @@ import static java.util.Arrays.stream; import static java.util.Optional.ofNullable; import static java.util.regex.Pattern.compile; +import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP; -@TaskGroup("Java") +@TaskGroup(JAVA_GROUP) public class BatchDownloadCleaner implements Runnable { private final Logger logger = LoggerFactory.getLogger(getClass()); private final Pattern filePattern = compile(BatchDownload.ZIP_FORMAT.replace("%s", "[a-z0-9\\.:|_Z\\-\\[GMT\\]]+")); diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/BatchDownloadRunner.java b/datashare-app/src/main/java/org/icij/datashare/tasks/BatchDownloadRunner.java index 5a55b0d1a..eef3b347b 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/BatchDownloadRunner.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/BatchDownloadRunner.java @@ -51,8 +51,9 @@ import static java.lang.String.valueOf; import static java.util.stream.Collectors.toList; import static org.icij.datashare.cli.DatashareCliOptions.*; +import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP; -@TaskGroup("Java") +@TaskGroup(JAVA_GROUP) public class BatchDownloadRunner implements Callable, Monitorable, UserTask, CancellableTask { private final static Logger logger = LoggerFactory.getLogger(BatchDownloadRunner.class); static final int MAX_SCROLL_SIZE = 3500; diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/BatchNlpTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/BatchNlpTask.java new file mode 100644 index 000000000..77d44916c --- /dev/null +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/BatchNlpTask.java @@ -0,0 +1,109 @@ +package 
org.icij.datashare.tasks; + +import static java.util.Optional.ofNullable; +import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP; + +import com.google.inject.Inject; +import com.google.inject.assistedinject.Assisted; +import java.util.List; +import java.util.Optional; +import java.util.function.Function; +import org.icij.datashare.asynctasks.CancellableTask; +import org.icij.datashare.asynctasks.Task; +import org.icij.datashare.asynctasks.TaskGroup; +import org.icij.datashare.extension.PipelineRegistry; +import org.icij.datashare.text.Document; +import org.icij.datashare.text.Language; +import org.icij.datashare.text.NamedEntity; +import org.icij.datashare.text.indexing.Indexer; +import org.icij.datashare.text.nlp.Pipeline; +import org.icij.datashare.user.User; +import org.icij.datashare.user.UserTask; +import org.icij.task.DefaultTask; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@TaskGroup(JAVA_GROUP) +public class BatchNlpTask extends DefaultTask implements UserTask, CancellableTask { + // TODO: fix the raw use of parameterized type... 
+ private static final List EXCLUDED_SOURCES = List.of("contentTranslated"); + private final Logger logger = LoggerFactory.getLogger(getClass()); + private final User user; + private final Function progress; + private volatile Thread taskThread; + private final Indexer indexer; + private final List docs; + private final Pipeline pipeline; + private final int maxLength; + + @Inject + public BatchNlpTask(Indexer indexer, PipelineRegistry registry, @Assisted Task taskView, + @Assisted final Function progress) { + this(indexer, registry.get(Pipeline.Type.parse((String) taskView.args.get("pipeline"))), taskView, progress); + } + + + BatchNlpTask(Indexer indexer, Pipeline pipeline, @Assisted Task taskView, + @Assisted final Function progress) { + this.user = taskView.getUser(); + this.indexer = indexer; + this.pipeline = pipeline; + this.docs = (List) taskView.args.get("docs"); + this.maxLength = (int) taskView.args.get("maxLength"); + this.progress = progress; + } + + @Override + public Long call() throws Exception { + taskThread = Thread.currentThread(); + if (this.docs.isEmpty()) { + return 0L; + } + int batchSize = this.docs.size(); + int updateRate = Integer.max(batchSize / 10, 1); + Language language = this.docs.get(0).language(); + pipeline.initialize(language); + logger.info("performing NER on {} docs in {}...", batchSize, language); + // TODO: for now, none of the Java NER pipelines seems to support batch processing, we just iterate docs one by one + // TODO: we could improve performance by fetching docs and processing them concurrently... 
+ int nProcessed = 0; + Optional.ofNullable(this.progress).ifPresent(p -> p.apply(0.0)); + for (CreateNlpBatchesFromIndex.BatchDocument doc : this.docs) { + String project = doc.project(); + Document indexDoc = indexer.get(doc.id(), doc.rootDocument(), EXCLUDED_SOURCES); + if (indexDoc.getContentTextLength() < this.maxLength) { + List namedEntities = pipeline.process(indexDoc); + indexer.bulkAdd(project, pipeline.getType(), namedEntities, indexDoc); + } else { + int nbChunks = indexDoc.getContentTextLength() / this.maxLength + 1; + for (int chunkIndex = 0; chunkIndex < nbChunks; chunkIndex++) { + List namedEntities = + pipeline.process(indexDoc, maxLength, chunkIndex * maxLength); + if (chunkIndex < nbChunks - 1) { + indexer.bulkAdd(project, namedEntities); + } else { + indexer.bulkAdd(project, pipeline.getType(), namedEntities, indexDoc); + } + } + } + nProcessed += 1; + if (nProcessed % updateRate == 0) { + Double prog = (double) nProcessed / (double) batchSize; + Optional.ofNullable(this.progress).ifPresent(p -> p.apply(prog)); + } + } + pipeline.terminate(language); + Optional.ofNullable(this.progress).ifPresent(p -> p.apply(1.0)); + return (long) batchSize; + } + + @Override + public void cancel(boolean requeue) { + ofNullable(taskThread).ifPresent(Thread::interrupt); + } + + @Override + public User getUser() { + return user; + } +} diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/BatchSearchRunner.java b/datashare-app/src/main/java/org/icij/datashare/tasks/BatchSearchRunner.java index 709506632..9a328f9e0 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/BatchSearchRunner.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/BatchSearchRunner.java @@ -43,9 +43,10 @@ import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_SCROLL_DURATION; import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_SCROLL_SIZE; import static org.icij.datashare.cli.DatashareCliOptions.SCROLL_SIZE_OPT; +import static 
org.icij.datashare.tasks.GroupHelper.JAVA_GROUP; import static org.icij.datashare.text.ProjectProxy.asCommaConcatNames; -@TaskGroup("Java") +@TaskGroup(JAVA_GROUP) public class BatchSearchRunner implements CancellableTask, UserTask, Callable { private final Logger logger = LoggerFactory.getLogger(getClass()); diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/CreateNlpBatchesFromIndex.java b/datashare-app/src/main/java/org/icij/datashare/tasks/CreateNlpBatchesFromIndex.java new file mode 100644 index 000000000..e1088bf04 --- /dev/null +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/CreateNlpBatchesFromIndex.java @@ -0,0 +1,203 @@ +package org.icij.datashare.tasks; + +import static java.util.Collections.singletonList; +import static java.util.Optional.ofNullable; +import static java.util.stream.Collectors.groupingBy; +import static org.icij.datashare.asynctasks.Task.GROUP_KEY; +import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_DEFAULT_PROJECT; +import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_NLP_BATCH_SIZE; +import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_NLP_MAX_TEXT_LENGTH; +import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_PROJECT_OPT; +import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_SCROLL_DURATION; +import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_SCROLL_SIZE; +import static org.icij.datashare.cli.DatashareCliOptions.NLP_BATCH_SIZE_OPT; +import static org.icij.datashare.cli.DatashareCliOptions.NLP_MAX_TEXT_LENGTH_OPT; +import static org.icij.datashare.cli.DatashareCliOptions.NLP_PIPELINE_OPT; +import static org.icij.datashare.cli.DatashareCliOptions.SCROLL_DURATION_OPT; +import static org.icij.datashare.cli.DatashareCliOptions.SCROLL_SIZE_OPT; +import static org.icij.datashare.cli.DatashareCliOptions.SEARCH_QUERY_OPT; +import static org.icij.datashare.nlp.NlpHelper.pipelineExtras; +import static 
org.icij.datashare.tasks.GroupHelper.JAVA_GROUP; +import static org.icij.datashare.tasks.GroupHelper.nlpGroup; + +import com.google.inject.Inject; +import com.google.inject.assistedinject.Assisted; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import org.icij.datashare.Entity; +import org.icij.datashare.asynctasks.CancellableTask; +import org.icij.datashare.asynctasks.Task; +import org.icij.datashare.asynctasks.TaskGroup; +import org.icij.datashare.asynctasks.TaskManager; +import org.icij.datashare.text.Document; +import org.icij.datashare.text.Language; +import org.icij.datashare.text.indexing.Indexer; +import org.icij.datashare.text.indexing.SearchQuery; +import org.icij.datashare.text.nlp.Pipeline; +import org.icij.datashare.time.DatashareTime; +import org.icij.datashare.user.User; +import org.icij.datashare.user.UserTask; +import org.icij.task.DefaultTask; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@TaskGroup(JAVA_GROUP) +public class CreateNlpBatchesFromIndex extends DefaultTask> implements UserTask, CancellableTask { + Logger logger = LoggerFactory.getLogger(getClass()); + + private final User user; + private volatile Thread taskThread; + private final TaskManager taskManager; + private final String searchQuery; + private final Map batchTaskArgs; + private final Pipeline.Type nlpPipeline; + private final int batchSize; + private final int maxTextLength; + private final String projectName; + private final Indexer indexer; + private final String scrollDuration; + private final int scrollSize; + private Language currentLanguage = null; + + public record BatchDocument(String id, String rootDocument, String project, Language language) { + public static BatchDocument fromDocument(Document document) { + return new BatchDocument(document.getId(), 
document.getRootDocument(), document.getProjectId(), document.getLanguage()); + } + } + + @Inject + public CreateNlpBatchesFromIndex( + final TaskManager taskManager, final Indexer indexer, @Assisted Task taskView, + @Assisted final Function ignored + ) { + this.user = taskView.getUser(); + this.taskManager = taskManager; + this.indexer = indexer; + this.nlpPipeline = Pipeline.Type.parse((String) taskView.args.getOrDefault(NLP_PIPELINE_OPT, Pipeline.Type.CORENLP.name())); + this.batchTaskArgs = batchTaskArgs(); + this.batchSize = (int) taskView.args.getOrDefault(NLP_BATCH_SIZE_OPT, DEFAULT_NLP_BATCH_SIZE); + this.maxTextLength = (int) taskView.args.getOrDefault(NLP_MAX_TEXT_LENGTH_OPT, DEFAULT_NLP_MAX_TEXT_LENGTH); + this.projectName = (String) taskView.args.getOrDefault(DEFAULT_PROJECT_OPT, DEFAULT_DEFAULT_PROJECT); + this.scrollDuration = (String) taskView.args.getOrDefault(SCROLL_DURATION_OPT, DEFAULT_SCROLL_DURATION); + this.scrollSize = (int) taskView.args.getOrDefault(SCROLL_SIZE_OPT, DEFAULT_SCROLL_SIZE); + this.searchQuery = (String) taskView.args.get(SEARCH_QUERY_OPT); + } + + @Override + public List call() throws Exception { + ArrayList taskIds = new ArrayList<>(); + taskThread = Thread.currentThread(); + Indexer.Searcher searcher; + if (searchQuery == null) { + searcher = indexer.search(singletonList(projectName), Document.class).without(nlpPipeline); + } else { + searcher = indexer.search(singletonList(projectName), Document.class, new SearchQuery(searchQuery)); + } + searcher = searcher.limit(scrollSize) + .withoutSource("language", "rootDocument") + .withoutSource("content", "contentTranslated") + .sort("language", Indexer.Searcher.SortOrder.ASC); + Map> scrolledDocsByLanguage = searcher + .scroll(scrollDuration) + .collect(groupingBy(d -> ((Document) d).getLanguage())); + ArrayList batch = new ArrayList<>(this.batchSize); + long totalHits = searcher.totalHits(); + logger.info( + "pushing batches of {} docs ids for index {}, pipeline {} with {} scroll 
and size of {}", + totalHits, projectName, nlpPipeline, scrollDuration, scrollSize + ); + do { + // For each scrolled page, we fill the batch... + taskIds.addAll(this.enqueueScrollBatches(scrolledDocsByLanguage, batch)); + // and keep scrolling... + scrolledDocsByLanguage = searcher + .scroll(scrollDuration) + .collect(groupingBy(d -> ((Document) d).getLanguage())); + // until we reach a page smaller than the scroll size aka the last page of the scroll + } while (scrolledDocsByLanguage.values().stream().map(List::size).mapToInt(Integer::intValue).sum() >= scrollSize); + // Let's fill the batches for that last page + taskIds.addAll(this.enqueueScrollBatches(scrolledDocsByLanguage, batch)); + // ... and enqueue that last batch if not done yet + if (!batch.isEmpty()) { + taskIds.add(this.enqueueBatch(batch)); + } + logger.info("queued batches for {} docs", totalHits); + searcher.clearScroll(); + return taskIds; + } + + private List enqueueScrollBatches(Map> docsByLanguage, ArrayList batch) { + ArrayList batchTaskIds = new ArrayList<>(); + // Make sure we consume the languages in order + Iterator>> docsIt = docsByLanguage.entrySet() + .stream().sorted(Comparator.comparing(e -> e.getKey().name())).iterator(); + while (docsIt.hasNext()) { + Map.Entry> entry = docsIt.next(); + Language language = entry.getKey(); + // If we switch language, we need to queue the batch + if (!language.equals(currentLanguage)) { + if (!batch.isEmpty()) { + batchTaskIds.add(this.enqueueBatch(batch)); + } + currentLanguage = language; + } + // and then we fill the current batch which can already be partially filled + List languageDocs = (List) entry.getValue(); + int start = 0; + int end = 0; + while (end < languageDocs.size()) { + end = start + Integer.min(batchSize - batch.size(), languageDocs.size() - start); + batch.addAll(languageDocs.subList(start, end)); + if (batch.size() >= batchSize) { + batchTaskIds.add(this.enqueueBatch(batch)); + } + start = end; + } + } + return batchTaskIds; + } 
+ + protected String enqueueBatch(List batch) { + String taskId; + HashMap args = new HashMap<>(this.batchTaskArgs); + args.put("docs", batch.stream().map(BatchDocument::fromDocument).toList()); + try { + // TODO: here we bind the task name to the Java class name which is not ideal since it leaks Java inners + // bolts to Python, it could be nice to decouple task names from class names since they can change and + // are bound to languages + logger.info("{} - {}", DatashareTime.getNow().getTime(), ((List)args.get("docs")).get(0).language()); + taskId = this.taskManager.startTask(BatchNlpTask.class, this.user, args); + } catch (IOException e) { + throw new RuntimeException("failed to queue task " + args, e); + } + batch.clear(); + return taskId; + } + + @Override + public void cancel(boolean requeue) { + ofNullable(taskThread).ifPresent(Thread::interrupt); + } + + @Override + public User getUser() { + return user; + } + + private Map batchTaskArgs() { + Map args = new HashMap<>(Map.of( + "pipeline", this.nlpPipeline.name(), + "maxLength", this.maxTextLength, + GROUP_KEY, nlpGroup(this.nlpPipeline) + )); + args.putAll(pipelineExtras(this.nlpPipeline)); + return args; + } + +} diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/DatashareTaskFactory.java b/datashare-app/src/main/java/org/icij/datashare/tasks/DatashareTaskFactory.java index d59343609..c3635aeb6 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/DatashareTaskFactory.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/DatashareTaskFactory.java @@ -1,5 +1,6 @@ package org.icij.datashare.tasks; +import java.util.List; import org.icij.datashare.asynctasks.Task; import org.icij.datashare.user.User; @@ -15,6 +16,8 @@ public interface DatashareTaskFactory extends org.icij.datashare.asynctasks.Task ScanIndexTask createScanIndexTask(Task taskView, Function updateCallback); ExtractNlpTask createExtractNlpTask(Task taskView, Function updateCallback); EnqueueFromIndexTask 
createEnqueueFromIndexTask(Task taskView, Function updateCallback); + CreateNlpBatchesFromIndex createBatchEnqueueFromIndexTask(Task> taskView, Function updateCallback); + BatchNlpTask createBatchNlpTask(Task taskView, Function updateCallback); DeduplicateTask createDeduplicateTask(Task taskView, Function updateCallback); ArtifactTask createArtifactTask(Task taskView, Function updateCallback); diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/DeduplicateTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/DeduplicateTask.java index 481dfaaef..19fe0475a 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/DeduplicateTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/DeduplicateTask.java @@ -1,5 +1,7 @@ package org.icij.datashare.tasks; +import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP; + import com.google.inject.Inject; import com.google.inject.assistedinject.Assisted; import org.icij.datashare.PropertiesProvider; @@ -18,7 +20,7 @@ /** * filters the document queue with extracted docs */ -@TaskGroup("Java") +@TaskGroup(JAVA_GROUP) public class DeduplicateTask extends PipelineTask { private final Logger logger = LoggerFactory.getLogger(getClass()); private final DocumentCollectionFactory factory; diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/DelApiKeyTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/DelApiKeyTask.java index b42a31291..9445af920 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/DelApiKeyTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/DelApiKeyTask.java @@ -1,5 +1,7 @@ package org.icij.datashare.tasks; +import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP; + import com.google.inject.Inject; import com.google.inject.assistedinject.Assisted; import org.icij.datashare.asynctasks.TaskGroup; @@ -10,7 +12,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -@TaskGroup("Java") +@TaskGroup(JAVA_GROUP) 
public class DelApiKeyTask extends DefaultTask implements UserTask { private final Logger logger = LoggerFactory.getLogger(getClass()); private final ApiKeyRepository apiKeyRepository; diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/EnqueueFromIndexTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/EnqueueFromIndexTask.java index 7f41538ff..e32113803 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/EnqueueFromIndexTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/EnqueueFromIndexTask.java @@ -3,7 +3,6 @@ import com.google.inject.Inject; import com.google.inject.assistedinject.Assisted; -import java.util.Optional; import java.util.function.Function; import org.icij.datashare.Entity; import org.icij.datashare.PropertiesProvider; @@ -12,7 +11,6 @@ import org.icij.datashare.asynctasks.TaskGroup; import org.icij.datashare.extract.DocumentCollectionFactory; import org.icij.datashare.text.Document; -import org.icij.datashare.text.ProjectProxy; import org.icij.datashare.text.indexing.Indexer; import org.icij.datashare.text.indexing.SearchQuery; import org.icij.datashare.text.nlp.Pipeline; @@ -33,8 +31,9 @@ import static org.icij.datashare.cli.DatashareCliOptions.SCROLL_DURATION_OPT; import static org.icij.datashare.cli.DatashareCliOptions.SCROLL_SIZE_OPT; import static org.icij.datashare.cli.DatashareCliOptions.SEARCH_QUERY_OPT; +import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP; -@TaskGroup("Java") +@TaskGroup(JAVA_GROUP) public class EnqueueFromIndexTask extends PipelineTask { private final DocumentCollectionFactory factory; private final String searchQuery; @@ -47,7 +46,7 @@ public class EnqueueFromIndexTask extends PipelineTask { @Inject public EnqueueFromIndexTask(final DocumentCollectionFactory factory, final Indexer indexer, - @Assisted Task taskView, @Assisted final Function updateCallback) { + @Assisted Task taskView, @Assisted final Function ignored) { super(Stage.ENQUEUEIDX, 
taskView.getUser(), factory, new PropertiesProvider(taskView.args), String.class); this.factory = factory; this.indexer = indexer; @@ -69,6 +68,7 @@ public Long call() throws Exception { searcher = indexer.search(singletonList(projectName), Document.class, new SearchQuery(searchQuery)) .withoutSource("content", "contentTranslated").limit(scrollSize); } + searcher.sort("language", Indexer.Searcher.SortOrder.ASC); logger.info("enqueuing doc ids finding for index {} and {} with {} scroll and size of {} : {} documents found", projectName, nlpPipeline, scrollDuration, scrollSize, searcher.totalHits()); List docsToProcess = searcher.scroll(scrollDuration).collect(toList()); diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/ExtractNlpTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/ExtractNlpTask.java index a182a163e..54bfd7a92 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/ExtractNlpTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/ExtractNlpTask.java @@ -29,9 +29,10 @@ import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_PROJECT_OPT; import static org.icij.datashare.cli.DatashareCliOptions.MAX_CONTENT_LENGTH_OPT; import static org.icij.datashare.cli.DatashareCliOptions.NLP_PIPELINE_OPT; +import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP; import static org.icij.extract.document.Identifier.shorten; -@TaskGroup("Java") +@TaskGroup(JAVA_GROUP) public class ExtractNlpTask extends PipelineTask implements Monitorable { private static final int DEFAULT_MAX_CONTENT_LENGTH = 1024 * 1024; private final Logger logger = LoggerFactory.getLogger(getClass()); diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/GenApiKeyTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/GenApiKeyTask.java index e2e2e4a7b..1287a9dd1 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/GenApiKeyTask.java +++ 
b/datashare-app/src/main/java/org/icij/datashare/tasks/GenApiKeyTask.java @@ -1,5 +1,7 @@ package org.icij.datashare.tasks; +import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP; + import com.google.inject.Inject; import com.google.inject.assistedinject.Assisted; import org.icij.datashare.asynctasks.TaskGroup; @@ -13,7 +15,7 @@ import javax.crypto.SecretKey; -@TaskGroup("Java") +@TaskGroup(JAVA_GROUP) public class GenApiKeyTask extends DefaultTask implements UserTask { private final Logger logger = LoggerFactory.getLogger(getClass()); private final ApiKeyRepository apiKeyRepository; diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/GetApiKeyTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/GetApiKeyTask.java index c8cecf7b0..837c39fae 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/GetApiKeyTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/GetApiKeyTask.java @@ -1,5 +1,7 @@ package org.icij.datashare.tasks; +import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP; + import com.google.inject.assistedinject.Assisted; import org.icij.datashare.asynctasks.TaskGroup; import org.icij.datashare.user.ApiKey; @@ -10,7 +12,7 @@ import javax.inject.Inject; -@TaskGroup("Java") +@TaskGroup(JAVA_GROUP) public class GetApiKeyTask extends DefaultTask implements UserTask { private final ApiKeyRepository apiKeyRepository; private final User user; diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/GroupHelper.java b/datashare-app/src/main/java/org/icij/datashare/tasks/GroupHelper.java new file mode 100644 index 000000000..e7a9c616f --- /dev/null +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/GroupHelper.java @@ -0,0 +1,17 @@ +package org.icij.datashare.tasks; + +import java.util.Objects; +import org.icij.datashare.text.nlp.Pipeline; + +public class GroupHelper { + // Later we could use enums if we have more groups + public static final String PYTHON_GROUP = "Python"; + public 
static final String JAVA_GROUP = "Java"; + + public static String nlpGroup(Pipeline.Type pipeline) { + if (Objects.requireNonNull(pipeline) == Pipeline.Type.SPACY) { + return PYTHON_GROUP; + } + return JAVA_GROUP; + } +} diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/IndexTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/IndexTask.java index d9d6c7980..b45f9c608 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/IndexTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/IndexTask.java @@ -28,13 +28,14 @@ import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.concurrent.TimeUnit.SECONDS; import static org.icij.datashare.cli.DatashareCliOptions.*; +import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP; @OptionsClass(Extractor.class) @OptionsClass(DocumentFactory.class) @OptionsClass(DocumentQueueDrainer.class) @Option(name = DEFAULT_PROJECT_OPT, description = "the default project name") @Option(name = "projectName", description = "task project name") -@TaskGroup("Java") +@TaskGroup(JAVA_GROUP) public class IndexTask extends PipelineTask implements Monitorable{ private final Logger logger = LoggerFactory.getLogger(getClass()); private final DocumentQueueDrainer drainer; diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/PipelineTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/PipelineTask.java index 640649c84..02e079288 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/PipelineTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/PipelineTask.java @@ -1,5 +1,6 @@ package org.icij.datashare.tasks; +import java.util.List; import org.icij.datashare.PipelineHelper; import org.icij.datashare.PropertiesProvider; import org.icij.datashare.Stage; diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/ScanIndexTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/ScanIndexTask.java index 
08e87dc75..d466cf7bd 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/ScanIndexTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/ScanIndexTask.java @@ -39,9 +39,10 @@ import static org.icij.datashare.cli.DatashareCliOptions.SCROLL_DURATION_OPT; import static org.icij.datashare.cli.DatashareCliOptions.SCROLL_SIZE_OPT; import static org.icij.datashare.cli.DatashareCliOptions.SCROLL_SLICES_OPT; +import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP; import static org.icij.datashare.text.indexing.ScrollQueryBuilder.createScrollQuery; -@TaskGroup("Java") +@TaskGroup(JAVA_GROUP) public class ScanIndexTask extends PipelineTask { private final Logger logger = LoggerFactory.getLogger(getClass()); private final Indexer indexer; @@ -53,7 +54,7 @@ public class ScanIndexTask extends PipelineTask { @Inject public ScanIndexTask(DocumentCollectionFactory factory, final Indexer indexer, - @Assisted Task taskView, @Assisted Function updateCallback) { + @Assisted Task taskView, @Assisted Function ignored) { super(Stage.SCANIDX, taskView.getUser(), factory, new PropertiesProvider(taskView.args), Path.class); this.scrollDuration = propertiesProvider.get(SCROLL_DURATION_OPT).orElse(DEFAULT_SCROLL_DURATION); this.scrollSize = parseInt(propertiesProvider.get(SCROLL_SIZE_OPT).orElse(valueOf(DEFAULT_SCROLL_SIZE))); diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/ScanTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/ScanTask.java index ed76db18e..537d06f0a 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/ScanTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/ScanTask.java @@ -1,5 +1,7 @@ package org.icij.datashare.tasks; +import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP; + import com.google.inject.Inject; import com.google.inject.assistedinject.Assisted; import java.util.function.Function; @@ -18,7 +20,7 @@ import java.nio.file.Paths; @OptionsClass(Scanner.class) 
-@TaskGroup("Java") +@TaskGroup(JAVA_GROUP) public class ScanTask extends PipelineTask { private final Scanner scanner; private final Path path; diff --git a/datashare-app/src/test/java/org/icij/datashare/tasks/BatchNlpTest.java b/datashare-app/src/test/java/org/icij/datashare/tasks/BatchNlpTest.java new file mode 100644 index 000000000..c1d500234 --- /dev/null +++ b/datashare-app/src/test/java/org/icij/datashare/tasks/BatchNlpTest.java @@ -0,0 +1,77 @@ +package org.icij.datashare.tasks; + +import static org.icij.datashare.test.ElasticsearchRule.TEST_INDEX; +import static org.icij.datashare.text.DocumentBuilder.createDoc; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.mockito.MockitoAnnotations.openMocks; + +import java.util.List; +import java.util.Map; +import org.icij.datashare.asynctasks.Task; +import org.icij.datashare.text.Document; +import org.icij.datashare.text.Language; +import org.icij.datashare.text.indexing.Indexer; +import org.icij.datashare.text.nlp.AbstractPipeline; +import org.icij.datashare.text.nlp.Pipeline; +import org.icij.datashare.user.User; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; + +public class BatchNlpTest { + @Mock + private Indexer indexer; + @Mock + private AbstractPipeline pipeline; + private AutoCloseable mocks; + + @Before + public void setUp() { + this.mocks = openMocks(this); + } + + @After + public void tearDown() throws Exception { + this.mocks.close(); + } + + @Test(timeout = 2000) + public void test_batch_nlp() throws Exception { + // Given + int maxLength = 20; + String rootId = "rootId"; + Language language = Language.ENGLISH; + Document doc0 = createDoc("doc0").with(language).withRootId(rootId) + .with("hello world").build(); + Document doc1 =
createDoc("doc1").with(language).withRootId(rootId) + .with("this is too long to be processed all at once").build(); + when(pipeline.getType()).thenReturn(Pipeline.Type.CORENLP); + when(pipeline.initialize(any())).thenReturn(true); + + when(indexer.get(anyString(), anyString(), any(List.class))).thenReturn(doc0, doc1); + List batchDocs = List.of( + new CreateNlpBatchesFromIndex.BatchDocument(doc0.getId(), doc0.getRootDocument(), TEST_INDEX, language), + new CreateNlpBatchesFromIndex.BatchDocument(doc1.getId(), doc1.getRootDocument(), TEST_INDEX, language) + ); + Map properties = Map.of( + "docs", batchDocs, + "pipeline", "OPENNLP", + "maxLength", maxLength, + "group", "JAVA" + ); + BatchNlpTask nlpTask = new BatchNlpTask( + indexer, pipeline, new Task<>(BatchNlpTask.class.getName(), new User("test"), properties), null + ); + // When + nlpTask.call(); + // Then + verify(pipeline).process(eq(doc0)); + verify(pipeline).process(eq(doc1), eq(maxLength), eq(0)); + verify(pipeline).process(eq(doc1), eq(maxLength), eq(maxLength)); + } +} diff --git a/datashare-app/src/test/java/org/icij/datashare/tasks/CreateNlpBatchesFromIndexParametrizedTest.java b/datashare-app/src/test/java/org/icij/datashare/tasks/CreateNlpBatchesFromIndexParametrizedTest.java new file mode 100644 index 000000000..065380d35 --- /dev/null +++ b/datashare-app/src/test/java/org/icij/datashare/tasks/CreateNlpBatchesFromIndexParametrizedTest.java @@ -0,0 +1,146 @@ +package org.icij.datashare.tasks; + +import static org.fest.assertions.Assertions.assertThat; +import static org.icij.datashare.test.ElasticsearchRule.TEST_INDEX; +import static org.icij.datashare.text.DocumentBuilder.createDoc; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import co.elastic.clients.elasticsearch._types.Refresh; +import java.io.IOException; +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import 
java.util.Map; +import java.util.function.Function; +import org.icij.datashare.PropertiesProvider; +import org.icij.datashare.asynctasks.Task; +import org.icij.datashare.asynctasks.TaskManager; +import org.icij.datashare.test.DatashareTimeRule; +import org.icij.datashare.test.ElasticsearchRule; +import org.icij.datashare.text.Document; +import org.icij.datashare.text.Language; +import org.icij.datashare.text.indexing.Indexer; +import org.icij.datashare.text.indexing.elasticsearch.ElasticsearchIndexer; +import org.icij.datashare.text.nlp.Pipeline; +import org.icij.datashare.time.DatashareTime; +import org.icij.datashare.user.User; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class CreateNlpBatchesFromIndexParametrizedTest { + private final int batchSize; + private final int scrollSize; + private final List> expectedLanguages; + + @Rule + public DatashareTimeRule time = new DatashareTimeRule(); + + static class TestableCreateNlpBatchesFromIndex extends CreateNlpBatchesFromIndex { + public TestableCreateNlpBatchesFromIndex( + TaskManager taskManager, Indexer indexer, Task taskView, Function ignored) { + super(taskManager, indexer, taskView, ignored); + } + + protected String enqueueBatch(List batch) { + DatashareTime.getInstance().addMilliseconds(1); + return super.enqueueBatch(batch); + } + } + + @ClassRule + public static ElasticsearchRule es = new ElasticsearchRule(); + private static final ElasticsearchIndexer indexer = new ElasticsearchIndexer(es.client, new PropertiesProvider()) + .withRefresh(Refresh.True); + private static TaskManager taskManager; + + @Before + public void setUp() { + DatashareTaskFactory factory = mock(DatashareTaskFactory.class); + when(factory.createBatchNlpTask(any(), any())).thenReturn(mock(BatchNlpTask.class)); + taskManager = new 
TaskManagerMemory(factory, new PropertiesProvider()); + } + + @After + public void tearDown() throws IOException { + es.removeAll(); + taskManager.close(); + } + + + public CreateNlpBatchesFromIndexParametrizedTest(int batchSize, int scrollSize, + List> expectedLanguages) { + this.batchSize = batchSize; + this.scrollSize = scrollSize; + this.expectedLanguages = expectedLanguages; + } + + @Parameterized.Parameters + public static Collection taskParams() { + return List.of( + new Object[] {7, 3, List.of( + List.of(Language.ENGLISH, Language.ENGLISH, Language.ENGLISH, Language.ENGLISH, Language.ENGLISH, + Language.ENGLISH, Language.ENGLISH), + List.of(Language.ENGLISH, Language.ENGLISH, Language.ENGLISH), + List.of(Language.FRENCH, Language.FRENCH, Language.FRENCH, Language.FRENCH, Language.FRENCH), + List.of(Language.SPANISH, Language.SPANISH, Language.SPANISH, Language.SPANISH, Language.SPANISH) + )}, + new Object[] {3, 7, List.of( + List.of(Language.ENGLISH, Language.ENGLISH, Language.ENGLISH), + List.of(Language.ENGLISH, Language.ENGLISH, Language.ENGLISH), + List.of(Language.ENGLISH, Language.ENGLISH, Language.ENGLISH), + List.of(Language.ENGLISH), + List.of(Language.FRENCH, Language.FRENCH, Language.FRENCH), + List.of(Language.FRENCH, Language.FRENCH), + List.of(Language.SPANISH, Language.SPANISH, Language.SPANISH), + List.of(Language.SPANISH, Language.SPANISH) + )} + ); + } + + @Test + public void test_queue_for_batch_nlp_by_batch() throws Exception { + // Given + int numDocs = 20; + for (int i = 0; i < numDocs; i++) { + Language language = switch (i % 4) { + case 2 -> Language.FRENCH; + case 3 -> Language.SPANISH; + default -> Language.ENGLISH; + }; + indexer.add(TEST_INDEX, createDoc("doc" + i).with(language).with(Pipeline.Type.OPENNLP).build()); + } + // Already processed + indexer.add(TEST_INDEX, + createDoc("docAlreadyProcessed").with(Language.ITALIAN).with(Pipeline.Type.CORENLP).build()); + indexer.add(TEST_INDEX, + 
createDoc("docAlsoAlreadyProcessed").with(Language.ITALIAN).with(Pipeline.Type.CORENLP).build()); + Map properties = Map.of( + "defaultProject", "test-datashare", + "stages", "BATCHENQUEUEIDX", + "queueName", "test:queue", + "nlpPipeline", "CORENLP", + "batchSize", this.batchSize, + "scrollSize", this.scrollSize + ); + TestableCreateNlpBatchesFromIndex enqueueFromIndex = new TestableCreateNlpBatchesFromIndex(taskManager, indexer, + new Task<>(CreateNlpBatchesFromIndex.class.getName(), new User("test"), properties), null); + // When + List taskIds = enqueueFromIndex.call(); + List> queued = taskManager.getTasks().stream() + .sorted(Comparator.comparing(t -> t.createdAt)) + .map(t -> ((List) t.args.get("docs")).stream().map( + CreateNlpBatchesFromIndex.BatchDocument::language).toList()) + .toList(); + // Then + assertThat(queued).isEqualTo(this.expectedLanguages); + assertThat(taskIds.size()).isEqualTo(expectedLanguages.size()); + } +} diff --git a/datashare-app/src/test/java/org/icij/datashare/tasks/CreateNlpBatchesFromIndexTest.java b/datashare-app/src/test/java/org/icij/datashare/tasks/CreateNlpBatchesFromIndexTest.java new file mode 100644 index 000000000..aa0a68585 --- /dev/null +++ b/datashare-app/src/test/java/org/icij/datashare/tasks/CreateNlpBatchesFromIndexTest.java @@ -0,0 +1,89 @@ +package org.icij.datashare.tasks; + +import static org.fest.assertions.Assertions.assertThat; +import static org.icij.datashare.test.ElasticsearchRule.TEST_INDEX; +import static org.icij.datashare.text.DocumentBuilder.createDoc; +import static org.icij.datashare.text.Project.project; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import co.elastic.clients.elasticsearch._types.Refresh; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import org.icij.datashare.PropertiesProvider; +import org.icij.datashare.asynctasks.Task; +import 
org.icij.datashare.asynctasks.TaskManager; +import org.icij.datashare.test.DatashareTimeRule; +import org.icij.datashare.test.ElasticsearchRule; +import org.icij.datashare.text.indexing.elasticsearch.ElasticsearchIndexer; +import org.icij.datashare.text.nlp.Pipeline; +import org.icij.datashare.user.User; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; + +public class CreateNlpBatchesFromIndexTest { + @Rule + public DatashareTimeRule time = new DatashareTimeRule(); + + @ClassRule + public static ElasticsearchRule es = new ElasticsearchRule(); + private static final ElasticsearchIndexer indexer = new ElasticsearchIndexer(es.client, new PropertiesProvider()) + .withRefresh(Refresh.True); + private static TaskManager taskManager; + + @Before + public void setUp() { + DatashareTaskFactory factory = mock(DatashareTaskFactory.class); + when(factory.createBatchNlpTask(any(), any())).thenReturn(mock(BatchNlpTask.class)); + taskManager = new TaskManagerMemory(factory, new PropertiesProvider()); + } + + @After + public void tearDown() throws IOException { + es.removeAll(); + taskManager.close(); + } + + @Test + public void test_queue_for_batch_nlp_by_batch_with_body() throws Exception { + // Given + int batchSize = 3; + int scrollSize = 5; + indexer.add(TEST_INDEX, createDoc("my_id").with("this is my precious doc") + .with(Pipeline.Type.CORENLP).with(project(TEST_INDEX)).build()); + indexer.add(TEST_INDEX, createDoc("my_other_id").with("this is not my precious doc") + .withExtractionLevel((short) 1) + .with(Pipeline.Type.CORENLP).with(project(TEST_INDEX)).build()); + Map properties = Map.of( + "defaultProject", "test-datashare", + "stages", "BATCHENQUEUEIDX", + "queueName", "test:queue", + "nlpPipeline", "OPENNLP", + "batchSize", batchSize, + "scrollSize", scrollSize, + "searchQuery", """ + { + "match": { + "extractionLevel": 0 + } + } + """ + ); + CreateNlpBatchesFromIndex enqueueFromIndex = new 
CreateNlpBatchesFromIndex(taskManager, indexer, + new Task<>(CreateNlpBatchesFromIndex.class.getName(), new User("test"), properties), null); + // When + enqueueFromIndex.call(); + List> queued = taskManager.getTasks().stream() + .map(t -> ((List) t.args.get("docs")).stream().map( + CreateNlpBatchesFromIndex.BatchDocument::id).toList()) + .toList(); + // Then + List> expected = List.of(List.of("my_id")); + assertThat(queued).isEqualTo(expected); + } +} diff --git a/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCli.java b/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCli.java index 883eba473..102474810 100644 --- a/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCli.java +++ b/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCli.java @@ -125,6 +125,8 @@ OptionParser createParser() { DatashareCliOptions.language(parser); DatashareCliOptions.ocrLanguage(parser); DatashareCliOptions.nlpPipeline(parser); + DatashareCliOptions.nlpMaxTextLength(parser); + DatashareCliOptions.nlpBatchSize(parser); DatashareCliOptions.resume(parser); DatashareCliOptions.scroll(parser); DatashareCliOptions.scrollSize(parser); diff --git a/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCliOptions.java b/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCliOptions.java index 973440e26..d720182b1 100644 --- a/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCliOptions.java +++ b/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCliOptions.java @@ -77,7 +77,9 @@ public final class DatashareCliOptions { public static final String MESSAGE_BUS_OPT = "messageBusAddress"; public static final String MODE_ABBR_OPT = "m"; public static final String MODE_OPT = "mode"; + public static final String NLP_BATCH_SIZE_OPT = "batchSize"; public static final String NLP_PARALLELISM_ABBR_OPT = "np"; + public static final String NLP_MAX_TEXT_LENGTH_OPT = "maxTextLength"; public static final String NLP_PARALLELISM_OPT = 
"nlpParallelism"; public static final String NLP_PIPELINE_ABBR_OPT = "nlpp"; public static final String NLP_PIPELINE_OPT = "nlpPipeline"; @@ -156,6 +158,8 @@ public final class DatashareCliOptions { public static final String DEFAULT_LOG_LEVEL = Level.INFO.toString(); public static final String DEFAULT_MESSAGE_BUS_ADDRESS = "redis://redis:6379"; public static final String DEFAULT_NLP_PIPELINE = "CORENLP"; + public static final int DEFAULT_NLP_BATCH_SIZE = 1024; + public static final int DEFAULT_NLP_MAX_TEXT_LENGTH = 1024; public static final String DEFAULT_PROTECTED_URI_PREFIX = "/api/"; public static final String DEFAULT_QUEUE_NAME = "extract:queue"; public static final String DEFAULT_REDIS_ADDRESS = "redis://redis:6379"; @@ -229,7 +233,7 @@ static void followSymlinks(OptionParser parser) { singletonList(FOLLOW_SYMLINKS_OPT), "Follow symlinks while scanning documents") .withRequiredArg() .ofType(Boolean.class) - .defaultsTo(DEFAULT_FOLLOW_SYMLINKS);; + .defaultsTo(DEFAULT_FOLLOW_SYMLINKS); } static void cors(OptionParser parser) { @@ -439,6 +443,24 @@ static void nlpParallelism(OptionParser parser) { .defaultsTo(DEFAULT_NLP_PARALLELISM); } + static void nlpBatchSize(OptionParser parser) { + parser.acceptsAll( + List.of(NLP_BATCH_SIZE_OPT), + "Batch size of NLP extraction task in number of documents.") + .withRequiredArg() + .ofType( Integer.class ) + .defaultsTo(DEFAULT_NLP_BATCH_SIZE); + } + + static void nlpMaxTextLength(OptionParser parser) { + parser.acceptsAll( + List.of(NLP_MAX_TEXT_LENGTH_OPT), + "Max length of document text sent to the NLP pipeline in one pass.") + .withRequiredArg() + .ofType( Integer.class ) + .defaultsTo(DEFAULT_NLP_MAX_TEXT_LENGTH); + } + + public static void batchSearchMaxTime(OptionParser parser) { parser.acceptsAll( singletonList(BATCH_SEARCH_MAX_TIME_OPT), "Max time for batch search in seconds") @@ -829,7 +851,7 @@ public static void searchQuery(OptionParser parser) { } public static ValueConverter toAbsolute() { - return new
ValueConverter() { + return new ValueConverter<>() { @Override public String convert(String value) { Path path = Paths.get(value); diff --git a/datashare-index/src/main/java/org/icij/datashare/text/indexing/elasticsearch/ElasticsearchSearcher.java b/datashare-index/src/main/java/org/icij/datashare/text/indexing/elasticsearch/ElasticsearchSearcher.java index ff2cb795d..2254bf048 100644 --- a/datashare-index/src/main/java/org/icij/datashare/text/indexing/elasticsearch/ElasticsearchSearcher.java +++ b/datashare-index/src/main/java/org/icij/datashare/text/indexing/elasticsearch/ElasticsearchSearcher.java @@ -9,10 +9,10 @@ import co.elastic.clients.elasticsearch.core.SearchResponse; import co.elastic.clients.elasticsearch.core.search.Hit; import co.elastic.clients.elasticsearch.core.search.ResponseBody; -import co.elastic.clients.json.JsonpMappingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ObjectNode; import jakarta.json.JsonException; +import java.util.Objects; import org.icij.datashare.Entity; import org.icij.datashare.json.JsonObjectMapper; import org.icij.datashare.text.indexing.Indexer; @@ -64,7 +64,7 @@ static Stream resultStream(Class cls, Iterable T hitToObject(Hit searchHit, Class cls) { - return (T) JsonObjectMapper.getObject(searchHit.id(), searchHit.index(), JsonUtils.nodeToMap(searchHit.source()), cls); + return JsonObjectMapper.getObject(searchHit.id(), searchHit.index(), JsonUtils.nodeToMap(searchHit.source()), cls); } @Override @@ -123,7 +123,7 @@ public Stream scroll(ScrollQuery scrollQuery) throws IOExcepti } scrollSearchRequest = sourceBuilder.scroll(Time.of(t -> t.time(scrollQuery.getDuration()))).build(); response = client.search(scrollSearchRequest, ObjectNode.class); - totalHits = response.hits().total().value(); + totalHits = Objects.requireNonNull(response.hits().total()).value(); } else if (scrollQuery.getStringQuery() == null) { response = client.scroll(ScrollRequest.of(s -> 
s.scroll(Time.of(t -> t.time(scrollQuery.getDuration()))) .scrollId(ofNullable(scrollId) @@ -177,6 +177,11 @@ public Indexer.Searcher limit(int maxCount) { return this; } + @Override + public Searcher sort(String field, SortOrder order) { + sourceBuilder.sort(builder -> builder.field(fieldBuilder -> fieldBuilder.field(field).order(esSortOrder(order)))); + return this; + } @Override public void clearScroll() throws IOException { @@ -194,4 +199,11 @@ public long totalHits() { public String toString() { return "query : " + jsonBoolQuery; } + + private co.elastic.clients.elasticsearch._types.SortOrder esSortOrder(SortOrder sortOrder) { + return switch (sortOrder) { + case ASC -> co.elastic.clients.elasticsearch._types.SortOrder.Asc; + case DESC -> co.elastic.clients.elasticsearch._types.SortOrder.Desc; + }; + } } \ No newline at end of file diff --git a/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/Task.java b/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/Task.java index 3ff984cf2..d81056b4e 100644 --- a/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/Task.java +++ b/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/Task.java @@ -12,7 +12,6 @@ import org.icij.datashare.asynctasks.bus.amqp.UriResult; import org.icij.datashare.user.User; -import java.io.Serial; import java.io.Serializable; import java.util.Collections; import java.util.HashMap;