Skip to content

Commit

Permalink
feature: return created task IDs
Browse files Browse the repository at this point in the history
  • Loading branch information
ClemDoum committed Dec 10, 2024
1 parent 0123dc4 commit e7f5239
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
import org.slf4j.LoggerFactory;

@TaskGroup(JAVA_GROUP)
public class CreateNlpBatchesFromIndex extends DefaultTask<Long> implements UserTask, CancellableTask {
public class CreateNlpBatchesFromIndex extends DefaultTask<List<String>> implements UserTask, CancellableTask {
Logger logger = LoggerFactory.getLogger(getClass());

private final User user;
Expand All @@ -63,6 +63,7 @@ public class CreateNlpBatchesFromIndex extends DefaultTask<Long> implements User
private final Indexer indexer;
private final String scrollDuration;
private final int scrollSize;
private Language currentLanguage = null;

public record BatchDocument(String id, String rootDocument, String project, Language language) {
public static BatchDocument fromDocument(Document document) {
Expand All @@ -89,7 +90,8 @@ public CreateNlpBatchesFromIndex(
}

@Override
public Long call() throws Exception {
public List<String> call() throws Exception {
ArrayList<String> taskIds = new ArrayList<>();
taskThread = Thread.currentThread();
Indexer.Searcher searcher;
if (searchQuery == null) {
Expand All @@ -110,32 +112,28 @@ public Long call() throws Exception {
"pushing batches of {} docs ids for index {}, pipeline {} with {} scroll and size of {}",
totalHits, projectName, nlpPipeline, scrollDuration, scrollSize
);
Language currentLanguage = null;
do {
// For each scrolled page, we fill the batch...
currentLanguage = this.enqueueScrollBatches(scrolledDocsByLanguage, currentLanguage, batch);
taskIds.addAll(this.enqueueScrollBatches(scrolledDocsByLanguage, batch));
// and keep scrolling...
scrolledDocsByLanguage = searcher
.scroll(scrollDuration)
.collect(groupingBy(d -> ((Document) d).getLanguage()));
// until we reach a page smaller than the scroll size aka the last page of the scrol
// until we reach a page smaller than the scroll size aka the last page of the scroll
} while (scrolledDocsByLanguage.values().stream().map(List::size).mapToInt(Integer::intValue).sum() >= scrollSize);
// Let's fill the batches for that last page
this.enqueueScrollBatches(scrolledDocsByLanguage, currentLanguage, batch);
taskIds.addAll(this.enqueueScrollBatches(scrolledDocsByLanguage, batch));
// ... and enqueue that last batch if not done yet
if (!batch.isEmpty()) {
this.enqueueBatch(batch);
taskIds.add(this.enqueueBatch(batch));
}
logger.info("queued batches for {} docs", totalHits);
searcher.clearScroll();
return totalHits;
return taskIds;
}

private Language enqueueScrollBatches(
Map<Language, ? extends List<? extends Entity>> docsByLanguage,
Language currentLanguage,
ArrayList<Document> batch
) {
private List<String> enqueueScrollBatches(Map<Language, ? extends List<? extends Entity>> docsByLanguage, ArrayList<Document> batch) {
ArrayList<String> batchTaskIds = new ArrayList<>();
// Make sure we consume the languages in order
Iterator<? extends Map.Entry<Language, ? extends List<? extends Entity>>> docsIt = docsByLanguage.entrySet()
.stream().sorted(Comparator.comparing(e -> e.getKey().name())).iterator();
Expand All @@ -145,7 +143,7 @@ private Language enqueueScrollBatches(
// If we switch language, we need to queue the batch
if (!language.equals(currentLanguage)) {
if (!batch.isEmpty()) {
this.enqueueBatch(batch);
batchTaskIds.add(this.enqueueBatch(batch));
}
currentLanguage = language;
}
Expand All @@ -157,27 +155,29 @@ private Language enqueueScrollBatches(
end = start + Integer.min(batchSize - batch.size(), languageDocs.size() - start);
batch.addAll(languageDocs.subList(start, end));
if (batch.size() >= batchSize) {
this.enqueueBatch(batch);
batchTaskIds.add(this.enqueueBatch(batch));
}
start = end;
}
}
return currentLanguage;
return batchTaskIds;
}

void enqueueBatch(List<Document> batch) {
protected String enqueueBatch(List<Document> batch) {
String taskId;
HashMap<String, Object> args = new HashMap<>(this.batchTaskArgs);
args.put("docs", batch.stream().map(BatchDocument::fromDocument).toList());
try {
// TODO: here we bind the task name to the Java class name which is not ideal since it leaks Java inners
// bolts to Python, it could be nice to decouple task names from class names since they can change and
// are bound to languages
logger.info("{} - {}", DatashareTime.getNow().getTime(), ((List<BatchDocument>)args.get("docs")).get(0).language());
this.taskManager.startTask(BatchNlpTask.class, this.user, args);
taskId = this.taskManager.startTask(BatchNlpTask.class, this.user, args);
} catch (IOException e) {
throw new RuntimeException("failed to queue task " + args, e);
}
batch.clear();
return taskId;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.icij.datashare.tasks;

import java.util.List;
import org.icij.datashare.asynctasks.Task;
import org.icij.datashare.user.User;

Expand All @@ -15,7 +16,7 @@ public interface DatashareTaskFactory extends org.icij.datashare.asynctasks.Task
ScanIndexTask createScanIndexTask(Task<Long> taskView, Function<Double, Void> updateCallback);
ExtractNlpTask createExtractNlpTask(Task<Long> taskView, Function<Double, Void> updateCallback);
EnqueueFromIndexTask createEnqueueFromIndexTask(Task<Long> taskView, Function<Double, Void> updateCallback);
CreateNlpBatchesFromIndex createBatchEnqueueFromIndexTask(Task<Long> taskView, Function<Double, Void> updateCallback);
CreateNlpBatchesFromIndex createBatchEnqueueFromIndexTask(Task<List<String>> taskView, Function<Double, Void> updateCallback);
BatchNlpTask createBatchNlpTask(Task<Long> taskView, Function<Double, Void> updateCallback);
DeduplicateTask createDeduplicateTask(Task<Long> taskView, Function<Double, Void> updateCallback);
ArtifactTask createArtifactTask(Task<Long> taskView, Function<Double, Void> updateCallback);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.Function;
import org.icij.datashare.PropertiesProvider;
import org.icij.datashare.asynctasks.Task;
Expand Down Expand Up @@ -50,9 +49,9 @@ public TestableCreateNlpBatchesFromIndex(
super(taskManager, indexer, taskView, ignored);
}

void enqueueBatch(List<Document> batch) {
protected String enqueueBatch(List<Document> batch) {
DatashareTime.getInstance().addMilliseconds(1);
super.enqueueBatch(batch);
return super.enqueueBatch(batch);
}
}

Expand All @@ -66,7 +65,7 @@ void enqueueBatch(List<Document> batch) {
public void setUp() {
DatashareTaskFactory factory = mock(DatashareTaskFactory.class);
when(factory.createBatchNlpTask(any(), any())).thenReturn(mock(BatchNlpTask.class));
taskManager = new TaskManagerMemory(new LinkedBlockingQueue<>(), factory, new PropertiesProvider());
taskManager = new TaskManagerMemory(factory, new PropertiesProvider());
}

@After
Expand Down Expand Up @@ -131,17 +130,17 @@ public void test_queue_for_batch_nlp_by_batch() throws Exception {
"batchSize", this.batchSize,
"scrollSize", this.scrollSize
);
TestableCreateNlpBatchesFromIndex enqueueFromIndex =
new TestableCreateNlpBatchesFromIndex(taskManager, indexer,
TestableCreateNlpBatchesFromIndex enqueueFromIndex = new TestableCreateNlpBatchesFromIndex(taskManager, indexer,
new Task<>(CreateNlpBatchesFromIndex.class.getName(), new User("test"), properties), null);
// When
enqueueFromIndex.call();
List<String> taskIds = enqueueFromIndex.call();
List<List<Language>> queued = taskManager.getTasks().stream()
.sorted(Comparator.comparing(t -> t.createdAt))
.map(t -> ((List<CreateNlpBatchesFromIndex.BatchDocument>) t.args.get("docs")).stream().map(
CreateNlpBatchesFromIndex.BatchDocument::language).toList())
.toList();
// Then
assertThat(queued).isEqualTo(this.expectedLanguages);
assertThat(taskIds.size()).isEqualTo(expectedLanguages.size());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.LinkedBlockingQueue;
import org.icij.datashare.PropertiesProvider;
import org.icij.datashare.asynctasks.Task;
import org.icij.datashare.asynctasks.TaskManager;
Expand Down Expand Up @@ -41,7 +40,7 @@ public class CreateNlpBatchesFromIndexTest {
public void setUp() {
DatashareTaskFactory factory = mock(DatashareTaskFactory.class);
when(factory.createBatchNlpTask(any(), any())).thenReturn(mock(BatchNlpTask.class));
taskManager = new TaskManagerMemory(new LinkedBlockingQueue<>(), factory, new PropertiesProvider());
taskManager = new TaskManagerMemory(factory, new PropertiesProvider());
}

@After
Expand Down

0 comments on commit e7f5239

Please sign in to comment.