Skip to content

Commit

Permalink
Adding multi node IT and update API bug fixes (#416)
Browse files Browse the repository at this point in the history
* add multi node IT, fix update api bugs

Signed-off-by: Amit Galitzky <[email protected]>

* adding test and stashing context

Signed-off-by: Amit Galitzky <[email protected]>

* cleaning up comments and adding tests

Signed-off-by: Amit Galitzky <[email protected]>

---------

Signed-off-by: Amit Galitzky <[email protected]>
  • Loading branch information
amitgalitz authored Jan 20, 2024
1 parent ab5ac39 commit a6ac83a
Show file tree
Hide file tree
Showing 8 changed files with 298 additions and 48 deletions.
46 changes: 25 additions & 21 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,20 @@ jobs:
name: Test JDK${{ matrix.java }}, ${{ matrix.os }}
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- name: Set up JDK ${{ matrix.java }}
uses: actions/setup-java@v4
with:
java-version: ${{ matrix.java }}
distribution: temurin
- name: Build and Run Tests
run: |
- uses: actions/checkout@v4
- name: Set up JDK ${{ matrix.java }}
uses: actions/setup-java@v4
with:
java-version: ${{ matrix.java }}
distribution: temurin
- name: Build and Run Tests
run: |
./gradlew check -x integTest -x yamlRestTest -x spotlessJava
- name: Upload Coverage Report
if: ${{ matrix.codecov }}
uses: codecov/codecov-action@v3
with:
file: ./build/reports/jacoco/test/jacocoTestReport.xml
- name: Upload Coverage Report
if: ${{ matrix.codecov }}
uses: codecov/codecov-action@v3
with:
file: ./build/reports/jacoco/test/jacocoTestReport.xml
integTest:
needs: [spotless, javadoc]
strategy:
Expand All @@ -69,12 +69,16 @@ jobs:
name: Integ Test JDK${{ matrix.java }}, ${{ matrix.os }}
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- name: Set up JDK ${{ matrix.java }}
uses: actions/setup-java@v4
with:
java-version: ${{ matrix.java }}
distribution: temurin
- name: Build and Run Tests
run: |
- uses: actions/checkout@v4
- name: Set up JDK ${{ matrix.java }}
uses: actions/setup-java@v4
with:
java-version: ${{ matrix.java }}
distribution: temurin
- name: Build and Run Tests
run: |
./gradlew integTest yamlRestTest
- name: Multi Nodes Integration Testing
if: matrix.java == 21
run: |
./gradlew integTest -PnumNodes=3
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,12 @@ public Collection<Object> createComponents(
flowFrameworkSettings = new FlowFrameworkSettings(clusterService, settings);
MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
EncryptorUtils encryptorUtils = new EncryptorUtils(clusterService, client);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService, encryptorUtils);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(
client,
clusterService,
encryptorUtils,
xContentRegistry
);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
threadPool,
clusterService,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.action.admin.indices.create.CreateIndexResponse;
import org.opensearch.action.admin.indices.mapping.put.PutMappingRequest;
import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.WriteRequest;
Expand All @@ -29,8 +30,10 @@
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.model.ProvisioningProgress;
import org.opensearch.flowframework.model.ResourceCreated;
Expand All @@ -47,8 +50,10 @@
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;

import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX_MAPPING;
import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX;
import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX_MAPPING;
Expand All @@ -70,20 +75,28 @@ public class FlowFrameworkIndicesHandler {
private final EncryptorUtils encryptorUtils;
private static final Map<String, AtomicBoolean> indexMappingUpdated = new HashMap<>();
private static final Map<String, Object> indexSettings = Map.of("index.auto_expand_replicas", "0-1");
private final NamedXContentRegistry xContentRegistry;

/**
* constructor
* @param client the open search client
* @param clusterService ClusterService
* @param encryptorUtils encryption utility
* @param xContentRegistry contentRegister to parse any response
*/
public FlowFrameworkIndicesHandler(Client client, ClusterService clusterService, EncryptorUtils encryptorUtils) {
public FlowFrameworkIndicesHandler(
Client client,
ClusterService clusterService,
EncryptorUtils encryptorUtils,
NamedXContentRegistry xContentRegistry
) {
this.client = client;
this.clusterService = clusterService;
this.encryptorUtils = encryptorUtils;
for (FlowFrameworkIndex mlIndex : FlowFrameworkIndex.values()) {
indexMappingUpdated.put(mlIndex.getIndexName(), new AtomicBoolean(false));
}
this.xContentRegistry = xContentRegistry;
}

static {
Expand Down Expand Up @@ -395,21 +408,99 @@ public void updateTemplateInGlobalContext(String documentId, Template template,
+ ", global_context index does not exist.";
logger.error(exceptionMessage);
listener.onFailure(new FlowFrameworkException(exceptionMessage, RestStatus.BAD_REQUEST));
} else {
IndexRequest request = new IndexRequest(GLOBAL_CONTEXT_INDEX).id(documentId);
try (
XContentBuilder builder = XContentFactory.jsonBuilder();
ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()
) {
Template encryptedTemplate = encryptorUtils.encryptTemplateCredentials(template);
request.source(encryptedTemplate.toXContent(builder, ToXContent.EMPTY_PARAMS))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.index(request, ActionListener.runBefore(listener, context::restore));
} catch (Exception e) {
String errorMessage = "Failed to update global_context entry : " + documentId;
logger.error(errorMessage, e);
listener.onFailure(new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e)));
return;
}
doesTemplateExist(documentId, templateExists -> {
if (templateExists) {
isWorkflowProvisioned(documentId, workflowIsProvisioned -> {
if (workflowIsProvisioned) {
IndexRequest request = new IndexRequest(GLOBAL_CONTEXT_INDEX).id(documentId);
try (
XContentBuilder builder = XContentFactory.jsonBuilder();
ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()
) {
Template encryptedTemplate = encryptorUtils.encryptTemplateCredentials(template);
request.source(encryptedTemplate.toXContent(builder, ToXContent.EMPTY_PARAMS))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.index(request, ActionListener.runBefore(listener, context::restore));
} catch (Exception e) {
String errorMessage = "Failed to update global_context entry : " + documentId;
logger.error(errorMessage, e);
listener.onFailure(
new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e))
);
}
} else {
String errorMessage = "The template has already been provisioned so it can't be updated: " + documentId;
logger.error(errorMessage);
listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST));
}
}, listener);
} else {
String errorMessage = "Failed to get template: " + documentId;
logger.error(errorMessage);
listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST));
}
}, listener);
}

/**
* Check if the given template exists in the template index
*
* @param documentId document id
* @param booleanResultConsumer a consumer based on whether the template exist
* @param listener action listener
* @param <T> action listener response type
*/
public <T> void doesTemplateExist(String documentId, Consumer<Boolean> booleanResultConsumer, ActionListener<T> listener) {
GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, documentId);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.wrap(response -> { booleanResultConsumer.accept(response.isExists()); }, exception -> {
context.restore();
logger.error("Failed to get template " + documentId, exception);
listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)));
}));
} catch (Exception e) {
logger.error("Failed to retrieve template from global context.", e);
listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
}

/**
* Check if the workflow has been provisioned and executes the consumer by passing a boolean
*
* @param documentId document id
* @param booleanResultConsumer boolean consumer function based on if workflow is provisioned or not
* @param listener action listener
* @param <T> action listener response type
*/
public <T> void isWorkflowProvisioned(String documentId, Consumer<Boolean> booleanResultConsumer, ActionListener<T> listener) {
GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX, documentId);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.wrap(response -> {
context.restore();
if (!response.isExists()) {
booleanResultConsumer.accept(false);
return;
}
try (
XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())
) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
WorkflowState workflowState = WorkflowState.parse(parser);
booleanResultConsumer.accept(workflowState.getProvisioningProgress().equals(ProvisioningProgress.NOT_STARTED.name()));
} catch (Exception e) {
String message = "Failed to parse workflow state " + documentId;
logger.error(message, e);
listener.onFailure(new FlowFrameworkException(message, RestStatus.INTERNAL_SERVER_ERROR));
}
}, exception -> {
logger.error("Failed to get workflow state " + documentId, exception);
booleanResultConsumer.accept(false);
}));
} catch (Exception e) {
logger.error("Failed to retrieve workflow state to check provisioning status", e);
listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
logger.info("updated workflow {} state to {}", request.getWorkflowId(), State.NOT_STARTED.name());
listener.onResponse(new WorkflowResponse(request.getWorkflowId()));
}, exception -> {
logger.error("Failed to update workflow state : {}", exception.getMessage());
logger.error("Failed to update workflow in template index: ", exception);
if (exception instanceof FlowFrameworkException) {
listener.onFailure(exception);
} else {
Expand Down
15 changes: 15 additions & 0 deletions src/test/java/org/opensearch/flowframework/TestHelpers.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,17 @@
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.model.WorkflowEdge;
import org.opensearch.flowframework.model.WorkflowNode;
import org.opensearch.flowframework.util.ParseUtils;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -153,4 +157,15 @@ public static XContentBuilder builder() throws IOException {
public static Map<String, Object> XContentBuilderToMap(XContentBuilder builder) {
return XContentHelper.convertToMap(BytesReference.bytes(builder), false, builder.contentType()).v2();
}

public static Workflow createSampleWorkflow() {
WorkflowNode nodeA = new WorkflowNode("A", "a-type", Collections.emptyMap(), Map.of("foo", "bar"));
WorkflowNode nodeB = new WorkflowNode("B", "b-type", Collections.emptyMap(), Map.of("baz", "qux"));
WorkflowEdge edgeAB = new WorkflowEdge("A", "B");
List<WorkflowNode> nodes = List.of(nodeA, nodeB);
List<WorkflowEdge> edges = List.of(edgeAB);
Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges);
return workflow;
}

}
Loading

0 comments on commit a6ac83a

Please sign in to comment.