Showing 2 changed files with 172 additions and 0 deletions.
96 changes: 96 additions & 0 deletions
datashare-app/src/main/java/org/icij/datashare/tasks/BatchNlpTask.java
@@ -0,0 +1,96 @@
package org.icij.datashare.tasks;

import static java.util.Optional.ofNullable;
import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP;

import com.google.inject.Inject;
import com.google.inject.assistedinject.Assisted;
import java.util.List;
import java.util.function.Function;
import org.icij.datashare.asynctasks.CancellableTask;
import org.icij.datashare.asynctasks.Task;
import org.icij.datashare.asynctasks.TaskGroup;
import org.icij.datashare.extension.PipelineRegistry;
import org.icij.datashare.text.Document;
import org.icij.datashare.text.Language;
import org.icij.datashare.text.NamedEntity;
import org.icij.datashare.text.indexing.Indexer;
import org.icij.datashare.text.nlp.Pipeline;
import org.icij.datashare.user.User;
import org.icij.datashare.user.UserTask;
import org.icij.task.DefaultTask;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@TaskGroup(JAVA_GROUP)
public class BatchNlpTask extends DefaultTask<Long> implements UserTask, CancellableTask {
    // TODO: fix the raw use of parametrized types...
    // source fields excluded when fetching documents from the index
    private static final List<String> EXCLUDED_SOURCES = List.of("contentTranslated");
    private final Logger logger = LoggerFactory.getLogger(getClass());
    private final User user;
    private volatile Thread taskThread;
    private final Indexer indexer;
    private final List<BatchEnqueueFromIndexTask.BatchDocument> docs;
    private final Pipeline pipeline;
    private final int maxLength;

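    // the injected constructor resolves the pipeline from the registry using the task's "pipeline"
    // argument; the package-private constructor below lets tests pass a pipeline directly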
||
@Inject | ||
public BatchNlpTask(Indexer indexer, PipelineRegistry registry, @Assisted Task<Long> taskView, | ||
@Assisted final Function<Double, Void> updateCallback) { | ||
this(indexer, registry.get(Pipeline.Type.parse((String) taskView.args.get("pipeline"))), taskView, updateCallback); | ||
} | ||
|
||
|
||
BatchNlpTask(Indexer indexer, Pipeline pipeline, @Assisted Task<Long> taskView, | ||
@Assisted final Function<Double, Void> ignored) { | ||
this.user = taskView.getUser(); | ||
this.indexer = indexer; | ||
this.pipeline = pipeline; | ||
this.docs = (List<BatchEnqueueFromIndexTask.BatchDocument>) taskView.args.get("docs"); | ||
this.maxLength = (int) taskView.args.get("maxLength"); | ||
} | ||
|
||
    @Override
    public Long call() throws Exception {
        taskThread = Thread.currentThread();
        if (this.docs.isEmpty()) {
            return 0L;
        }
        Language language = this.docs.get(0).language();
        pipeline.initialize(language);
        logger.info("performing NER on {} docs in {}...", this.docs.size(), language);
        // TODO: for now none of the Java NER pipelines seems to support batch processing, so we iterate over docs one by one
        // TODO: we could improve performance by fetching and processing docs concurrently...
        for (BatchEnqueueFromIndexTask.BatchDocument doc : this.docs) {
            String project = doc.project();
            Document indexDoc = indexer.get(doc.id(), doc.rootDocument(), EXCLUDED_SOURCES);
            if (indexDoc.getContentTextLength() < this.maxLength) {
                List<NamedEntity> namedEntities = pipeline.process(indexDoc);
                indexer.bulkAdd(project, pipeline.getType(), namedEntities, indexDoc);
            } else {
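                // the document exceeds maxLength: process it in maxLength-sized chunks, indexing
                // intermediate chunks as they come; only the last chunk's bulkAdd also records the
                // pipeline type on the document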
                int nbChunks = indexDoc.getContentTextLength() / this.maxLength + 1;
                for (int chunkIndex = 0; chunkIndex < nbChunks; chunkIndex++) {
                    List<NamedEntity> namedEntities =
                        pipeline.process(indexDoc, maxLength, chunkIndex * maxLength);
                    if (chunkIndex < nbChunks - 1) {
                        indexer.bulkAdd(project, namedEntities);
                    } else {
                        indexer.bulkAdd(project, pipeline.getType(), namedEntities, indexDoc);
                    }
                }
            }
        }
        pipeline.terminate(language);
        return (long) this.docs.size();
    }

    @Override
    public void cancel(boolean requeue) {
        ofNullable(taskThread).ifPresent(Thread::interrupt);
    }

    @Override
    public User getUser() {
        return user;
    }
}
76 changes: 76 additions & 0 deletions
datashare-app/src/test/java/org/icij/datashare/tasks/BatchNlpTaskTest.java
@@ -0,0 +1,76 @@
package org.icij.datashare.tasks;

import static org.icij.datashare.test.ElasticsearchRule.TEST_INDEX;
import static org.icij.datashare.text.DocumentBuilder.createDoc;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.MockitoAnnotations.openMocks;

import java.util.List;
import java.util.Map;
import org.icij.datashare.asynctasks.Task;
import org.icij.datashare.text.Document;
import org.icij.datashare.text.Language;
import org.icij.datashare.text.indexing.Indexer;
import org.icij.datashare.text.nlp.AbstractPipeline;
import org.icij.datashare.text.nlp.Pipeline;
import org.icij.datashare.user.User;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;

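// verifies that BatchNlpTask processes short documents in one call and long documents chunk by
// chunk, using a mocked pipeline and indexer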
public class BatchNlpTaskTest {
    @Mock
    private Indexer indexer;
    @Mock
    private AbstractPipeline pipeline;
    private AutoCloseable mocks;

    @Before
    public void setUp() {
        this.mocks = openMocks(this);
    }

    @After
    public void tearDown() throws Exception {
        this.mocks.close();
    }

    @Test(timeout = 2000)
    public void test_batch_nlp() throws Exception {
        // Given
        int maxLength = 20;
        String rootId = "rootId";
        Language language = Language.ENGLISH;
        Document doc0 = createDoc("doc0").with(language).withRootId(rootId)
            .with("hello world").build();
        Document doc1 = createDoc("doc1").with(language).withRootId(rootId)
            .with("this is too long to be processed all at once").build();
        when(pipeline.getType()).thenReturn(Pipeline.Type.CORENLP);
        when(pipeline.initialize(any())).thenReturn(true);

        when(indexer.get(anyString(), anyString(), any(List.class))).thenReturn(doc0, doc1);
        List<BatchEnqueueFromIndexTask.BatchDocument> batchDocs = List.of(
            new BatchEnqueueFromIndexTask.BatchDocument(doc0.getId(), doc0.getRootDocument(), TEST_INDEX, language),
            new BatchEnqueueFromIndexTask.BatchDocument(doc1.getId(), doc1.getRootDocument(), TEST_INDEX, language)
        );
        Map<String, Object> properties = Map.of(
            "docs", batchDocs,
            "pipeline", "OPENNLP",
            "maxLength", maxLength,
            "group", "JAVA"
        );
        BatchNlpTask nlpTask = new BatchNlpTask(
            indexer, pipeline, new Task<>(BatchNlpTask.class.getName(), new User("test"), properties), null
        );
        // When
        nlpTask.call();
        // Then
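        // doc0 ("hello world") fits within maxLength and is processed in a single call, while
        // doc1 exceeds maxLength and is processed chunk by chunk at successive offsets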
        verify(pipeline).process(eq(doc0));
        verify(pipeline).process(eq(doc1), eq(maxLength), eq(0));
        verify(pipeline).process(eq(doc1), eq(maxLength), eq(maxLength));
    }
}