Skip to content

Commit

Permalink
Added tests
Browse files Browse the repository at this point in the history
Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Apr 12, 2024
1 parent b33c37a commit 312a6a9
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ private CommonValue() {}
public static final String CONFIGURATIONS = "configurations";
/** Guardrails field */
public static final String GUARDRAILS_FIELD = "guardrails";

/** The reindex field for created resources */
public static final String RE_INDEX_FIELD = "reindex";
/** The source index field for reindex */
public static final String SOURCE_INDEX = "source_index";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.flowframework.workflow.DeleteModelStep;
import org.opensearch.flowframework.workflow.DeployModelStep;
import org.opensearch.flowframework.workflow.NoOpStep;
import org.opensearch.flowframework.workflow.ReIndexStep;
import org.opensearch.flowframework.workflow.RegisterAgentStep;
import org.opensearch.flowframework.workflow.RegisterLocalCustomModelStep;
import org.opensearch.flowframework.workflow.RegisterLocalPretrainedModelStep;
Expand Down Expand Up @@ -58,6 +59,8 @@ public enum WorkflowResources {
CREATE_SEARCH_PIPELINE(CreateSearchPipelineStep.NAME, WorkflowResources.PIPELINE_ID, null), // TODO delete step
/** Workflow steps for creating an index and associated created resource */
CREATE_INDEX(CreateIndexStep.NAME, WorkflowResources.INDEX_NAME, NoOpStep.NAME),
/** Workflow steps for reindex a source index to destination index and associated created resource */
RE_INDEX(ReIndexStep.NAME, CommonValue.DESTINATION_INDEX, NoOpStep.NAME),
/** Workflow steps for registering/deleting an agent and the associated created resource */
REGISTER_AGENT(RegisterAgentStep.NAME, WorkflowResources.AGENT_ID, DeleteAgentStep.NAME);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
import static org.opensearch.flowframework.common.CommonValue.RE_INDEX_FIELD;
import static org.opensearch.flowframework.common.CommonValue.SOURCE_INDEX;

/**
* Step to reindex
*/
public class ReIndexStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(ReIndexStep.class);
Expand Down Expand Up @@ -126,7 +129,7 @@ public void onResponse(BulkByScrollResponse bulkByScrollResponse) {

@Override
public void onFailure(Exception e) {
String errorMessage = "Failed to reindex from source" + sourceIndices + "to" + destinationIndex;
String errorMessage = "Failed to reindex from source " + sourceIndices + " to " + destinationIndex;
logger.error(errorMessage, e);
reIndexFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e)));
}
Expand All @@ -143,6 +146,6 @@ public void onFailure(Exception e) {

@Override
public String getName() {
return null;
return NAME;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public WorkflowStepFactory(
) {
stepMap.put(NoOpStep.NAME, NoOpStep::new);
stepMap.put(CreateIndexStep.NAME, () -> new CreateIndexStep(client, flowFrameworkIndicesHandler));
stepMap.put(ReIndexStep.NAME, () -> new CreateIndexStep(client, flowFrameworkIndicesHandler));
stepMap.put(ReIndexStep.NAME, () -> new ReIndexStep(client, flowFrameworkIndicesHandler));
stepMap.put(
RegisterLocalCustomModelStep.NAME,
() -> new RegisterLocalCustomModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public void testParseWorkflowValidator() throws IOException {

WorkflowValidator validator = new WorkflowValidator(workflowStepValidators);

assertEquals(17, validator.getWorkflowStepValidators().size());
assertEquals(18, validator.getWorkflowStepValidators().size());

assertTrue(validator.getWorkflowStepValidators().keySet().contains("create_connector"));
assertEquals(7, validator.getWorkflowStepValidators().get("create_connector").getInputs().size());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.lucene.tests.util.LuceneTestCase;
import org.opensearch.OpenSearchException;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.common.Randomness;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.BulkByScrollTask;
import org.opensearch.index.reindex.ReindexRequest;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;

import org.mockito.ArgumentCaptor;
import org.mockito.MockitoAnnotations;

import static java.lang.Math.abs;
import static java.util.stream.Collectors.toList;
import static org.opensearch.action.DocWriteResponse.Result.UPDATED;
import static org.opensearch.common.unit.TimeValue.timeValueMillis;
import static org.opensearch.flowframework.common.CommonValue.DESTINATION_INDEX;
import static org.opensearch.flowframework.common.CommonValue.RE_INDEX_FIELD;
import static org.opensearch.flowframework.common.CommonValue.SOURCE_INDEX;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;
import static org.apache.lucene.tests.util.TestUtil.randomSimpleString;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

public class ReIndexStepTests extends OpenSearchTestCase {
private WorkflowData inputData = WorkflowData.EMPTY;
private Client client;
private ReIndexStep reIndexStep;

private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;

@Override
public void setUp() throws Exception {
super.setUp();
this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);
MockitoAnnotations.openMocks(this);

inputData = new WorkflowData(
Map.ofEntries(Map.entry(SOURCE_INDEX, "demo"), Map.entry(DESTINATION_INDEX, "dest")),
"test-id",
"test-node-id"
);

client = mock(Client.class);
reIndexStep = new ReIndexStep(client, flowFrameworkIndicesHandler);
}

public void testReIndexStep() throws ExecutionException, InterruptedException, IOException {

doAnswer(invocation -> {
ActionListener<UpdateResponse> updateResponseListener = invocation.getArgument(4);
updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED));
return null;
}).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any());

@SuppressWarnings({ "unchecked" })
ArgumentCaptor<ActionListener<BulkByScrollResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);
PlainActionFuture<WorkflowData> future = reIndexStep.execute(
inputData.getNodeId(),
inputData,
Collections.emptyMap(),
Collections.emptyMap(),
Collections.emptyMap()
);

verify(client, times(1)).execute(any(), any(ReindexRequest.class), actionListenerCaptor.capture());
actionListenerCaptor.getValue()
.onResponse(
new BulkByScrollResponse(
timeValueMillis(randomNonNegativeLong()),
randomStatus(),
Collections.emptyList(),
Collections.emptyList(),
randomBoolean()
)
);

assertTrue(future.isDone());

Map<String, Object> outputData = Map.of(RE_INDEX_FIELD, Map.of("demo", "dest"));
assertEquals(outputData, future.get().getContent());

}

public void testReIndexStepFailure() throws ExecutionException, InterruptedException {
@SuppressWarnings({ "unchecked" })
ArgumentCaptor<ActionListener<BulkByScrollResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);
PlainActionFuture<WorkflowData> future = reIndexStep.execute(
inputData.getNodeId(),
inputData,
Collections.emptyMap(),
Collections.emptyMap(),
Collections.emptyMap()
);
assertFalse(future.isDone());
verify(client, times(1)).execute(any(), any(ReindexRequest.class), actionListenerCaptor.capture());

actionListenerCaptor.getValue().onFailure(new Exception("Failed to reindex from source demo to dest"));

assertTrue(future.isDone());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof Exception);
assertEquals("Failed to reindex from source demo to dest", ex.getCause().getMessage());
}

private static BulkByScrollTask.Status randomStatus() {
if (randomBoolean()) {
return randomWorkingStatus(null);
}
boolean canHaveNullStatues = randomBoolean();
List<BulkByScrollTask.StatusOrException> statuses = IntStream.range(0, between(0, 10)).mapToObj(i -> {
if (canHaveNullStatues && LuceneTestCase.rarely()) {
return null;
}
if (randomBoolean()) {
return new BulkByScrollTask.StatusOrException(new OpenSearchException(randomAlphaOfLength(5)));
}
return new BulkByScrollTask.StatusOrException(randomWorkingStatus(i));
}).collect(toList());
return new BulkByScrollTask.Status(statuses, randomBoolean() ? "test" : null);
}

private static BulkByScrollTask.Status randomWorkingStatus(Integer sliceId) {
// These all should be believably small because we sum them if we have multiple workers
int total = between(0, 10000000);
int updated = between(0, total);
int created = between(0, total - updated);
int deleted = between(0, total - updated - created);
int noops = total - updated - created - deleted;
int batches = between(0, 10000);
long versionConflicts = between(0, total);
long bulkRetries = between(0, 10000000);
long searchRetries = between(0, 100000);
// smallest unit of time during toXContent is Milliseconds
TimeUnit[] timeUnits = { TimeUnit.MILLISECONDS, TimeUnit.SECONDS, TimeUnit.MINUTES, TimeUnit.HOURS, TimeUnit.DAYS };
TimeValue throttled = new TimeValue(randomIntBetween(0, 1000), randomFrom(timeUnits));
TimeValue throttledUntil = new TimeValue(randomIntBetween(0, 1000), randomFrom(timeUnits));
return new BulkByScrollTask.Status(
sliceId,
total,
updated,
created,
deleted,
batches,
versionConflicts,
noops,
bulkRetries,
searchRetries,
throttled,
abs(Randomness.get().nextFloat()),
randomBoolean() ? null : randomSimpleString(Randomness.get()),
throttledUntil
);
}
}

0 comments on commit 312a6a9

Please sign in to comment.