Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: implement CreateNlpBatchesFromIndexTask and BatchNlpTask #1597

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import static java.util.Arrays.stream;
import static java.util.stream.Collectors.joining;
Expand Down
2 changes: 2 additions & 0 deletions datashare-api/src/main/java/org/icij/datashare/Stage.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ public enum Stage {
INDEX(true),
ENQUEUEIDX(false),
NLP(true),
CREATENLPBATCHESFROMIDX(false),
BATCHNLP(false),
ARTIFACT(false);

public static final Comparator<Stage> comparator = Comparator.comparing(Stage::ordinal);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public interface Indexer extends Closeable {
<T extends Entity> boolean bulkUpdate(String indexName, List<T> entities) throws IOException;
<T extends Entity> void add(String indexName, T obj) throws IOException;
<T extends Entity> void update(String indexName, T obj) throws IOException;
<T extends Entity> boolean exists(String indexName, String id) throws IOException;
boolean exists(String indexName, String id) throws IOException;

<T extends Entity> T get(String indexName, String id);
<T extends Entity> T get(String indexName, String id, List<String> sourceExcludes);
Expand Down Expand Up @@ -61,9 +61,11 @@ interface Searcher {
Searcher withoutSource(String... fields);
Searcher withSource(boolean source);
Searcher limit(int maxCount);
Searcher sort(String field, SortOrder order);
void clearScroll() throws IOException;
long totalHits();
Searcher with(int fuzziness, boolean phraseMatches);
/** Ordering direction accepted by {@code sort(String field, SortOrder order)}. */
enum SortOrder {
    ASC,
    DESC
}
}

interface QueryBuilderSearcher extends Searcher {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,38 @@ public void test_get_output_queue_name_for_last_pipeline_step() {
assertThat(name).isEqualTo("extract:queue:nlp");
}

@Test
public void test_get_queue_names_for_batch_nlp_pipeline() {
    // Stages SCAN -> INDEX -> BATCHNLP: check the queue name resolved for the BATCHNLP stage.
    PropertiesProvider propertiesProvider = new PropertiesProvider(new HashMap<>() {{
        put("stages", "SCAN,INDEX,BATCHNLP");
    }});
    String queueName = new PipelineHelper(propertiesProvider).getQueueNameFor(Stage.BATCHNLP);
    assertThat(queueName).isEqualTo("extract:queue:batchnlp");
}

@Test
public void test_get_queue_names_for_batch_nlp_pipeline_from_index() {
    // Stages CREATENLPBATCHESFROMIDX -> BATCHNLP: check the queue name resolved for the BATCHNLP stage.
    PropertiesProvider propertiesProvider = new PropertiesProvider(new HashMap<>() {{
        put("stages", "CREATENLPBATCHESFROMIDX,BATCHNLP");
    }});
    String queueName = new PipelineHelper(propertiesProvider).getQueueNameFor(Stage.BATCHNLP);
    assertThat(queueName).isEqualTo("extract:queue:batchnlp");
}

@Test
public void test_get_queue_names_for_batch_nlp() {
    // NOTE(review): despite the "batch_nlp" in its name, this test exercises the plain NLP stage
    // (stages SCAN -> INDEX -> NLP); consider renaming it to avoid confusion with the BATCHNLP tests.
    PropertiesProvider propertiesProvider = new PropertiesProvider(new HashMap<>() {{
        put("stages", "SCAN,INDEX,NLP");
    }});
    String queueName = new PipelineHelper(propertiesProvider).getQueueNameFor(Stage.NLP);
    assertThat(queueName).isEqualTo("extract:queue:nlp");
}

@Test
public void test_get_queue_names_for_nlp_pipeline_from_index() {
    // Stages ENQUEUEIDX -> NLP: check the queue name resolved for the NLP stage.
    PropertiesProvider propertiesProvider = new PropertiesProvider(new HashMap<>() {{
        put("stages", "ENQUEUEIDX,NLP");
    }});
    String queueName = new PipelineHelper(propertiesProvider).getQueueNameFor(Stage.NLP);
    assertThat(queueName).isEqualTo("extract:queue:nlp");
}

@Test
public void test_get_queue_name_scan_index() {
PipelineHelper pipelineHelper = new PipelineHelper(new PropertiesProvider(new HashMap<>() {{
Expand Down
12 changes: 12 additions & 0 deletions datashare-app/src/main/java/org/icij/datashare/CliApp.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@
import org.icij.datashare.cli.spi.CliExtension;
import org.icij.datashare.mode.CommonMode;
import org.icij.datashare.tasks.ArtifactTask;
import org.icij.datashare.tasks.CreateNlpBatchesFromIndex;
import org.icij.datashare.tasks.BatchNlpTask;
import org.icij.datashare.tasks.DatashareTaskFactory;
import org.icij.datashare.tasks.DeduplicateTask;
import org.icij.datashare.tasks.EnqueueFromIndexTask;
import org.icij.datashare.tasks.ExtractNlpTask;
import org.icij.datashare.tasks.IndexTask;
import org.icij.datashare.tasks.ScanIndexTask;
import org.icij.datashare.tasks.ScanTask;
import org.icij.datashare.tasks.DatashareTaskFactory;
import org.icij.datashare.text.indexing.Indexer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -125,6 +128,15 @@ private static void runTaskWorker(CommonMode mode, Properties properties) throws
new Task<>(EnqueueFromIndexTask.class.getName(), nullUser(), propertiesToMap(properties)));
}

if (pipeline.has(Stage.CREATENLPBATCHESFROMIDX)) {
taskManager.startTask(new Task<>(CreateNlpBatchesFromIndex.class.getName(), nullUser(), propertiesToMap(properties)));
}

if (pipeline.has(Stage.BATCHNLP)) {
taskManager.startTask(
new Task<>(BatchNlpTask.class.getName(), nullUser(), propertiesToMap(properties)));
}

if (pipeline.has(Stage.NLP)) {
taskManager.startTask(
new Task<>(ExtractNlpTask.class.getName(), nullUser(), propertiesToMap(properties)));
Expand Down
13 changes: 13 additions & 0 deletions datashare-app/src/main/java/org/icij/datashare/nlp/NlpHelper.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package org.icij.datashare.nlp;

import java.util.Map;
import org.icij.datashare.text.nlp.Pipeline;

/**
 * Static helpers related to NLP pipelines.
 *
 * <p>This class only exposes static methods and cannot be instantiated.
 */
public final class NlpHelper {
    private NlpHelper() {
        // static utility holder: no instances
    }

    /**
     * Returns the extra options associated with a pipeline type.
     *
     * @param pipeline the pipeline type to look up
     * @return an immutable map holding {@code "modelSize" -> "md"} for {@code SPACY},
     *         or an empty immutable map for any other value (including {@code null})
     */
    public static Map<String, Object> pipelineExtras(Pipeline.Type pipeline) {
        if (pipeline == Pipeline.Type.SPACY) {
            return Map.of("modelSize", "md");
        }
        return Map.of();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import org.icij.datashare.asynctasks.Task;
import org.icij.datashare.asynctasks.TaskGroup;
import org.icij.datashare.extract.DocumentCollectionFactory;
import org.icij.datashare.function.Pair;
import org.icij.datashare.text.Document;
import org.icij.datashare.text.Project;
import org.icij.datashare.text.indexing.Indexer;
Expand All @@ -20,12 +19,12 @@
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

import static java.util.Optional.ofNullable;
import static org.icij.datashare.cli.DatashareCliOptions.ARTIFACT_DIR_OPT;
import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_DEFAULT_PROJECT;
import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_PROJECT_OPT;
import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP;

@TaskGroup("Java")
@TaskGroup(JAVA_GROUP)
public class ArtifactTask extends PipelineTask<String> {
private final Logger logger = LoggerFactory.getLogger(getClass());
private final Indexer indexer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import static java.util.Arrays.stream;
import static java.util.Optional.ofNullable;
import static java.util.regex.Pattern.compile;
import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP;

@TaskGroup("Java")
@TaskGroup(JAVA_GROUP)
public class BatchDownloadCleaner implements Runnable {
private final Logger logger = LoggerFactory.getLogger(getClass());
private final Pattern filePattern = compile(BatchDownload.ZIP_FORMAT.replace("%s", "[a-z0-9\\.:|_Z\\-\\[GMT\\]]+"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@
import static java.lang.String.valueOf;
import static java.util.stream.Collectors.toList;
import static org.icij.datashare.cli.DatashareCliOptions.*;
import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP;

@TaskGroup("Java")
@TaskGroup(JAVA_GROUP)
public class BatchDownloadRunner implements Callable<UriResult>, Monitorable, UserTask, CancellableTask {
private final static Logger logger = LoggerFactory.getLogger(BatchDownloadRunner.class);
static final int MAX_SCROLL_SIZE = 3500;
Expand Down
109 changes: 109 additions & 0 deletions datashare-app/src/main/java/org/icij/datashare/tasks/BatchNlpTask.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package org.icij.datashare.tasks;

import static java.util.Optional.ofNullable;
import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP;

import com.google.inject.Inject;
import com.google.inject.assistedinject.Assisted;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import org.icij.datashare.asynctasks.CancellableTask;
import org.icij.datashare.asynctasks.Task;
import org.icij.datashare.asynctasks.TaskGroup;
import org.icij.datashare.extension.PipelineRegistry;
import org.icij.datashare.text.Document;
import org.icij.datashare.text.Language;
import org.icij.datashare.text.NamedEntity;
import org.icij.datashare.text.indexing.Indexer;
import org.icij.datashare.text.nlp.Pipeline;
import org.icij.datashare.user.User;
import org.icij.datashare.user.UserTask;
import org.icij.task.DefaultTask;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Task running a named-entity-recognition pipeline over a batch of indexed documents.
 *
 * <p>The batch and its settings are read from the task arguments: {@code "pipeline"} (public
 * constructor only), {@code "docs"} and {@code "maxLength"}. Documents are fetched from the index
 * one by one, processed by the pipeline (in chunks of {@code maxLength} when their content is at
 * least that long) and the resulting named entities are written back through the indexer.
 */
@TaskGroup(JAVA_GROUP)
public class BatchNlpTask extends DefaultTask<Long> implements UserTask, CancellableTask {
    // TODO: fix the raw use of parameterized types...
    // Source fields left out when fetching documents from the index.
    private static final List<String> EXCLUDED_SOURCES = List.of("contentTranslated");
    private final Logger logger = LoggerFactory.getLogger(getClass());
    private final User user;
    private final Function<Double, Void> progress;
    // Thread currently executing call(); used by cancel() to interrupt it.
    private volatile Thread taskThread;
    private final Indexer indexer;
    private final List<CreateNlpBatchesFromIndex.BatchDocument> docs;
    private final Pipeline pipeline;
    private final int maxLength;

    /**
     * Creates the task, resolving the pipeline from the {@code "pipeline"} task argument.
     *
     * @param indexer  indexer used to read documents and store named entities
     * @param registry registry the pipeline is looked up in
     * @param taskView task description carrying the user and the arguments
     * @param progress optional progress callback receiving values in [0, 1]; may be {@code null}
     */
    @Inject
    public BatchNlpTask(Indexer indexer, PipelineRegistry registry, @Assisted Task<Long> taskView,
                        @Assisted final Function<Double, Void> progress) {
        this(indexer, registry.get(Pipeline.Type.parse((String) taskView.args.get("pipeline"))), taskView, progress);
    }

    /**
     * Creates the task with an already resolved pipeline.
     *
     * @param indexer  indexer used to read documents and store named entities
     * @param pipeline pipeline used to extract named entities
     * @param taskView task description carrying the user and the {@code "docs"} and
     *                 {@code "maxLength"} arguments
     * @param progress optional progress callback receiving values in [0, 1]; may be {@code null}
     */
    BatchNlpTask(Indexer indexer, Pipeline pipeline, @Assisted Task<Long> taskView,
                 @Assisted final Function<Double, Void> progress) {
        this.user = taskView.getUser();
        this.indexer = indexer;
        this.pipeline = pipeline;
        // NOTE(review): unchecked cast, assumes the "docs" argument already holds BatchDocument
        // instances (and not e.g. raw deserialized maps) - TODO confirm.
        this.docs = (List<CreateNlpBatchesFromIndex.BatchDocument>) taskView.args.get("docs");
        // Go through Number so that any integral representation (Integer, Long, ...) is accepted.
        this.maxLength = ((Number) taskView.args.get("maxLength")).intValue();
        this.progress = progress;
    }

    /**
     * Processes every document of the batch and stores the extracted named entities.
     *
     * <p>The language of the first document is used to initialize (and later terminate) the
     * pipeline for the whole batch.
     *
     * @return the number of documents in the batch, {@code 0} when the batch is empty
     * @throws Exception when fetching, processing or indexing a document fails
     */
    @Override
    public Long call() throws Exception {
        taskThread = Thread.currentThread();
        if (docs.isEmpty()) {
            return 0L;
        }
        int batchSize = docs.size();
        // Report progress roughly every 10% of the batch, and at least once per document.
        int updateRate = Integer.max(batchSize / 10, 1);
        Language language = docs.get(0).language();
        pipeline.initialize(language);
        try {
            logger.info("performing NER on {} docs in {}...", batchSize, language);
            // TODO: for now none of the Java NER pipelines seems to support batch processing,
            //  so we just iterate over docs one by one
            // TODO: we could improve perfs by fetching docs and processing them concurrently...
            int nProcessed = 0;
            reportProgress(0.0);
            for (CreateNlpBatchesFromIndex.BatchDocument doc : docs) {
                processDocument(doc);
                nProcessed += 1;
                if (nProcessed % updateRate == 0) {
                    reportProgress((double) nProcessed / (double) batchSize);
                }
            }
        } finally {
            // Always terminate the initialized pipeline, even when a document fails.
            pipeline.terminate(language);
        }
        reportProgress(1.0);
        return (long) batchSize;
    }

    /** Fetches one document from the index, runs the pipeline on it and indexes the entities. */
    private void processDocument(CreateNlpBatchesFromIndex.BatchDocument doc) throws Exception {
        String project = doc.project();
        // The project is the index name; the root document id is passed separately from the id.
        Document indexDoc = indexer.get(project, doc.id(), doc.rootDocument(), EXCLUDED_SOURCES);
        int contentLength = indexDoc.getContentTextLength();
        if (contentLength < maxLength) {
            List<NamedEntity> namedEntities = pipeline.process(indexDoc);
            indexer.bulkAdd(project, pipeline.getType(), namedEntities, indexDoc);
            return;
        }
        int nbChunks = contentLength / maxLength + 1;
        for (int chunkIndex = 0; chunkIndex < nbChunks; chunkIndex++) {
            List<NamedEntity> namedEntities = pipeline.process(indexDoc, maxLength, chunkIndex * maxLength);
            if (chunkIndex < nbChunks - 1) {
                indexer.bulkAdd(project, namedEntities);
            } else {
                // Only the last chunk uses the overload that also receives the pipeline type and
                // the document - presumably so the document is flagged as processed once all its
                // chunks are indexed; TODO confirm.
                indexer.bulkAdd(project, pipeline.getType(), namedEntities, indexDoc);
            }
        }
    }

    /** Forwards a progress value to the callback, when one was provided. */
    private void reportProgress(double value) {
        if (progress != null) {
            progress.apply(value);
        }
    }

    /**
     * Interrupts the thread running {@link #call()}, if any.
     *
     * @param requeue not used by this implementation
     */
    @Override
    public void cancel(boolean requeue) {
        ofNullable(taskThread).ifPresent(Thread::interrupt);
    }

    /** Returns the user the task was created for. */
    @Override
    public User getUser() {
        return user;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,10 @@
import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_SCROLL_DURATION;
import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_SCROLL_SIZE;
import static org.icij.datashare.cli.DatashareCliOptions.SCROLL_SIZE_OPT;
import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP;
import static org.icij.datashare.text.ProjectProxy.asCommaConcatNames;

@TaskGroup("Java")
@TaskGroup(JAVA_GROUP)
public class BatchSearchRunner implements CancellableTask, UserTask, Callable<Integer> {
private final Logger logger = LoggerFactory.getLogger(getClass());

Expand Down
Loading