diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/CreateNlpBatchesFromIndex.java b/datashare-app/src/main/java/org/icij/datashare/tasks/CreateNlpBatchesFromIndex.java index 562f62fb1..e1088bf04 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/CreateNlpBatchesFromIndex.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/CreateNlpBatchesFromIndex.java @@ -48,7 +48,7 @@ import org.slf4j.LoggerFactory; @TaskGroup(JAVA_GROUP) -public class CreateNlpBatchesFromIndex extends DefaultTask implements UserTask, CancellableTask { +public class CreateNlpBatchesFromIndex extends DefaultTask> implements UserTask, CancellableTask { Logger logger = LoggerFactory.getLogger(getClass()); private final User user; @@ -63,6 +63,7 @@ public class CreateNlpBatchesFromIndex extends DefaultTask implements User private final Indexer indexer; private final String scrollDuration; private final int scrollSize; + private Language currentLanguage = null; public record BatchDocument(String id, String rootDocument, String project, Language language) { public static BatchDocument fromDocument(Document document) { @@ -89,7 +90,8 @@ public CreateNlpBatchesFromIndex( } @Override - public Long call() throws Exception { + public List call() throws Exception { + ArrayList taskIds = new ArrayList<>(); taskThread = Thread.currentThread(); Indexer.Searcher searcher; if (searchQuery == null) { @@ -110,32 +112,28 @@ public Long call() throws Exception { "pushing batches of {} docs ids for index {}, pipeline {} with {} scroll and size of {}", totalHits, projectName, nlpPipeline, scrollDuration, scrollSize ); - Language currentLanguage = null; do { // For each scrolled page, we fill the batch... - currentLanguage = this.enqueueScrollBatches(scrolledDocsByLanguage, currentLanguage, batch); + taskIds.addAll(this.enqueueScrollBatches(scrolledDocsByLanguage, batch)); // and keep scrolling... scrolledDocsByLanguage = searcher .scroll(scrollDuration) .collect(groupingBy(d -> ((Document) d).getLanguage())); - // until we reach a page smaller than the scroll size aka the last page of the scrol + // until we reach a page smaller than the scroll size aka the last page of the scroll } while (scrolledDocsByLanguage.values().stream().map(List::size).mapToInt(Integer::intValue).sum() >= scrollSize); // Let's fill the batches for that last page - this.enqueueScrollBatches(scrolledDocsByLanguage, currentLanguage, batch); + taskIds.addAll(this.enqueueScrollBatches(scrolledDocsByLanguage, batch)); // ... and enqueue that last batch if not done yet if (!batch.isEmpty()) { - this.enqueueBatch(batch); + taskIds.add(this.enqueueBatch(batch)); } logger.info("queued batches for {} docs", totalHits); searcher.clearScroll(); - return totalHits; + return taskIds; } - private Language enqueueScrollBatches( - Map> docsByLanguage, - Language currentLanguage, - ArrayList batch - ) { + private List enqueueScrollBatches(Map> docsByLanguage, ArrayList batch) { + ArrayList batchTaskIds = new ArrayList<>(); // Make sure we consume the languages in order Iterator>> docsIt = docsByLanguage.entrySet() .stream().sorted(Comparator.comparing(e -> e.getKey().name())).iterator(); @@ -145,7 +143,7 @@ private Language enqueueScrollBatches( // If we switch language, we need to queue the batch if (!language.equals(currentLanguage)) { if (!batch.isEmpty()) { - this.enqueueBatch(batch); + batchTaskIds.add(this.enqueueBatch(batch)); } currentLanguage = language; } @@ -157,15 +155,16 @@ private Language enqueueScrollBatches( end = start + Integer.min(batchSize - batch.size(), languageDocs.size() - start); batch.addAll(languageDocs.subList(start, end)); if (batch.size() >= batchSize) { - this.enqueueBatch(batch); + batchTaskIds.add(this.enqueueBatch(batch)); } start = end; } } - return currentLanguage; + return batchTaskIds; } - void enqueueBatch(List batch) { + protected String enqueueBatch(List batch) { + String taskId; HashMap args = new HashMap<>(this.batchTaskArgs); args.put("docs", batch.stream().map(BatchDocument::fromDocument).toList()); try { @@ -173,11 +172,12 @@ void enqueueBatch(List batch) { // bolts to Python, it could be nice to decouple task names from class names since they can change and // are bound to languages logger.info("{} - {}", DatashareTime.getNow().getTime(), ((List)args.get("docs")).get(0).language()); - this.taskManager.startTask(BatchNlpTask.class, this.user, args); + taskId = this.taskManager.startTask(BatchNlpTask.class, this.user, args); } catch (IOException e) { throw new RuntimeException("failed to queue task " + args, e); } batch.clear(); + return taskId; } @Override diff --git a/datashare-app/src/test/java/org/icij/datashare/tasks/CreateNlpBatchesFromIndexParametrizedTest.java b/datashare-app/src/test/java/org/icij/datashare/tasks/CreateNlpBatchesFromIndexParametrizedTest.java index c3ef3bd10..fa7c16c0e 100644 --- a/datashare-app/src/test/java/org/icij/datashare/tasks/CreateNlpBatchesFromIndexParametrizedTest.java +++ b/datashare-app/src/test/java/org/icij/datashare/tasks/CreateNlpBatchesFromIndexParametrizedTest.java @@ -50,9 +50,9 @@ public TestableCreateNlpBatchesFromIndex( super(taskManager, indexer, taskView, ignored); } - void enqueueBatch(List batch) { + protected String enqueueBatch(List batch) { DatashareTime.getInstance().addMilliseconds(1); - super.enqueueBatch(batch); + return super.enqueueBatch(batch); } } @@ -131,11 +131,10 @@ public void test_queue_for_batch_nlp_by_batch() throws Exception { "batchSize", this.batchSize, "scrollSize", this.scrollSize ); - TestableCreateNlpBatchesFromIndex enqueueFromIndex = - new TestableCreateNlpBatchesFromIndex(taskManager, indexer, + TestableCreateNlpBatchesFromIndex enqueueFromIndex = new TestableCreateNlpBatchesFromIndex(taskManager, indexer, new Task<>(CreateNlpBatchesFromIndex.class.getName(), new User("test"), properties), null); // When - enqueueFromIndex.call(); + List taskIds = enqueueFromIndex.call(); List> queued = taskManager.getTasks().stream() .sorted(Comparator.comparing(t -> t.createdAt)) .map(t -> ((List) t.args.get("docs")).stream().map( @@ -143,5 +142,6 @@ public void test_queue_for_batch_nlp_by_batch() throws Exception { .toList(); // Then assertThat(queued).isEqualTo(this.expectedLanguages); + assertThat(taskIds.size()).isEqualTo(expectedLanguages.size()); } }