From 04f9e6972621763a574203f8d4ed1200802dcd6e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 31 Oct 2023 21:38:07 +0000 Subject: [PATCH] Modifies use case template format and adds graph validation when provisioning (#119) * Simplifying Template format, removing operations, resources created, user outputs Signed-off-by: Joshua Palis * Initial commit, modifies use case template to seperate workflow inputs into previous_node_inputs and user_inputs, adds graph validation after topologically sorting a workflow into a list of ProcessNode Signed-off-by: Joshua Palis * Adding tests Signed-off-by: Joshua Palis * Adding validate graph test Signed-off-by: Joshua Palis * Addressing PR comments, moving sorting/validating prior to executing async, adding success test case for graph validation Signed-off-by: Joshua Palis * Adding javadocs Signed-off-by: Joshua Palis * Moving validation prior to updating workflow state to provisioning Signed-off-by: Joshua Palis * Addressing PR comments Part 1 Signed-off-by: Joshua Palis * Addressing PR comments Part 2 : Moving field names to common value class and using constants Signed-off-by: Joshua Palis * Adding definition for noop workflow step Signed-off-by: Joshua Palis * Addressing PR comments Part 3 Signed-off-by: Joshua Palis --------- Signed-off-by: Joshua Palis (cherry picked from commit ac76a44927a6e475694f563b2a033aaa498d13e8) Signed-off-by: github-actions[bot] --- .../flowframework/common/CommonValue.java | 16 ++++ .../indices/FlowFrameworkIndex.java | 3 + .../model/ProvisioningProgress.java | 3 + .../opensearch/flowframework/model/State.java | 4 + .../flowframework/model/WorkflowNode.java | 54 +++++++---- .../model/WorkflowStepValidator.java | 91 +++++++++++++++++++ .../model/WorkflowValidator.java | 78 ++++++++++++++++ .../ProvisionWorkflowTransportAction.java | 38 ++++---- .../workflow/CreateIndexStep.java | 9 +- .../workflow/CreateIngestPipelineStep.java | 48 +++++----- .../flowframework/workflow/ProcessNode.java | 13 +++ .../workflow/WorkflowProcessSorter.java | 65 ++++++++++++- .../workflow/WorkflowStepFactory.java | 8 ++ .../resources/mappings/workflow-steps.json | 71 +++++++++++++++ .../model/TemplateTestJsonUtil.java | 2 +- .../flowframework/model/TemplateTests.java | 6 +- .../model/WorkflowNodeTests.java | 20 ++-- .../model/WorkflowStepValidatorTests.java | 46 ++++++++++ .../flowframework/model/WorkflowTests.java | 8 +- .../model/WorkflowValidatorTests.java | 87 ++++++++++++++++++ .../rest/RestCreateWorkflowActionTests.java | 4 +- .../CreateWorkflowTransportActionTests.java | 4 +- ...ProvisionWorkflowTransportActionTests.java | 4 +- .../WorkflowRequestResponseTests.java | 4 +- .../workflow/CreateIndexStepTests.java | 4 +- .../CreateIngestPipelineStepTests.java | 4 +- .../workflow/ProcessNodeTests.java | 7 +- .../workflow/WorkflowProcessSorterTests.java | 67 ++++++++++++++ 28 files changed, 672 insertions(+), 96 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java create mode 100644 src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java create mode 100644 src/main/resources/mappings/workflow-steps.json create mode 100644 src/test/java/org/opensearch/flowframework/model/WorkflowStepValidatorTests.java create mode 100644 src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 32acc9a68..ecce8ec50 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -61,6 +61,22 @@ private CommonValue() {} /** The provision workflow thread pool name */ public static final String PROVISION_THREAD_POOL = "opensearch_workflow_provision"; + /** Index name field */ + public static final String INDEX_NAME = "index_name"; + /** Type field */ + public static final String TYPE = "type"; + /** ID Field */ + public static final String ID = "id"; + /** Pipeline Id field */ + public static final String PIPELINE_ID = "pipeline_id"; + /** Processors field */ + public static final String PROCESSORS = "processors"; + /** Field map field */ + public static final String FIELD_MAP = "field_map"; + /** Input Field Name field */ + public static final String INPUT_FIELD_NAME = "input_field_name"; + /** Output Field Name field */ + public static final String OUTPUT_FIELD_NAME = "output_field_name"; /** Model Id field */ public static final String MODEL_ID = "model_id"; /** Function Name field */ diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java index e23b9ddf0..4b005e45d 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java @@ -29,6 +29,9 @@ public enum FlowFrameworkIndex { ThrowingSupplierWrapper.throwingSupplierWrapper(FlowFrameworkIndicesHandler::getGlobalContextMappings), GLOBAL_CONTEXT_INDEX_VERSION ), + /** + * Workflow State Index + */ WORKFLOW_STATE( WORKFLOW_STATE_INDEX, ThrowingSupplierWrapper.throwingSupplierWrapper(FlowFrameworkIndicesHandler::getWorkflowStateMappings), diff --git a/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java b/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java index 1aefecb4b..d5a2a5734 100644 --- a/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java +++ b/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java @@ -13,7 +13,10 @@ */ // TODO: transfer this to more detailed array for each step public enum ProvisioningProgress { + /** Not Started State */ NOT_STARTED, + /** In Progress State */ IN_PROGRESS, + /** Done State */ DONE } diff --git a/src/main/java/org/opensearch/flowframework/model/State.java b/src/main/java/org/opensearch/flowframework/model/State.java index 3288ed4ab..bb9540c52 100644 --- a/src/main/java/org/opensearch/flowframework/model/State.java +++ b/src/main/java/org/opensearch/flowframework/model/State.java @@ -12,8 +12,12 @@ * Enum relating to the state of a workflow */ public enum State { + /** Not Started state */ NOT_STARTED, + /** Provisioning state */ PROVISIONING, + /** Failed state */ FAILED, + /** Completed state */ COMPLETED } diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index d2046f096..7d04a5a3f 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -39,8 +39,10 @@ public class WorkflowNode implements ToXContentObject { public static final String ID_FIELD = "id"; /** The template field name for node type */ public static final String TYPE_FIELD = "type"; + /** The template field name for previous node inputs */ + public static final String PREVIOUS_NODE_INPUTS_FIELD = "previous_node_inputs"; /** The template field name for node inputs */ - public static final String INPUTS_FIELD = "inputs"; + public static final String USER_INPUTS_FIELD = "user_inputs"; /** The field defining processors in the inputs for search and ingest pipelines */ public static final String PROCESSORS_FIELD = "processors"; /** The field defining the timeout value for this node */ @@ -50,19 +52,22 @@ public class WorkflowNode implements ToXContentObject { private final String id; // unique id private final String type; // maps to a WorkflowStep - private final Map inputs; // maps to WorkflowData + private final Map previousNodeInputs; + private final Map userInputs; // maps to WorkflowData /** * Create this node with the id and type, and any user input. * * @param id A unique string identifying this node * @param type The type of {@link WorkflowStep} to create for the corresponding {@link ProcessNode} - * @param inputs Optional input to populate params in {@link WorkflowData} + * @param previousNodeInputs Optional input to identify inputs coming from predecessor nodes + * @param userInputs Optional input to populate params in {@link WorkflowData} */ - public WorkflowNode(String id, String type, Map inputs) { + public WorkflowNode(String id, String type, Map previousNodeInputs, Map userInputs) { this.id = id; this.type = type; - this.inputs = Map.copyOf(inputs); + this.previousNodeInputs = Map.copyOf(previousNodeInputs); + this.userInputs = Map.copyOf(userInputs); } @Override @@ -71,8 +76,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws xContentBuilder.field(ID_FIELD, this.id); xContentBuilder.field(TYPE_FIELD, this.type); - xContentBuilder.startObject(INPUTS_FIELD); - for (Entry e : inputs.entrySet()) { + xContentBuilder.field(PREVIOUS_NODE_INPUTS_FIELD); + buildStringToStringMap(xContentBuilder, previousNodeInputs); + + xContentBuilder.startObject(USER_INPUTS_FIELD); + for (Entry e : userInputs.entrySet()) { xContentBuilder.field(e.getKey()); if (e.getValue() instanceof String) { xContentBuilder.value(e.getValue()); @@ -107,7 +115,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public static WorkflowNode parse(XContentParser parser) throws IOException { String id = null; String type = null; - Map inputs = new HashMap<>(); + Map previousNodeInputs = new HashMap<>(); + Map userInputs = new HashMap<>(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -120,16 +129,19 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { case TYPE_FIELD: type = parser.text(); break; - case INPUTS_FIELD: + case PREVIOUS_NODE_INPUTS_FIELD: + previousNodeInputs = parseStringToStringMap(parser); + break; + case USER_INPUTS_FIELD: ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String inputFieldName = parser.currentName(); switch (parser.nextToken()) { case VALUE_STRING: - inputs.put(inputFieldName, parser.text()); + userInputs.put(inputFieldName, parser.text()); break; case START_OBJECT: - inputs.put(inputFieldName, parseStringToStringMap(parser)); + userInputs.put(inputFieldName, parseStringToStringMap(parser)); break; case START_ARRAY: if (PROCESSORS_FIELD.equals(inputFieldName)) { @@ -137,13 +149,13 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { while (parser.nextToken() != XContentParser.Token.END_ARRAY) { processorList.add(PipelineProcessor.parse(parser)); } - inputs.put(inputFieldName, processorList.toArray(new PipelineProcessor[0])); + userInputs.put(inputFieldName, processorList.toArray(new PipelineProcessor[0])); } else { List> mapList = new ArrayList<>(); while (parser.nextToken() != XContentParser.Token.END_ARRAY) { mapList.add(parseStringToStringMap(parser)); } - inputs.put(inputFieldName, mapList.toArray(new Map[0])); + userInputs.put(inputFieldName, mapList.toArray(new Map[0])); } break; default: @@ -159,7 +171,7 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { throw new IOException("An node object requires both an id and type field."); } - return new WorkflowNode(id, type, inputs); + return new WorkflowNode(id, type, previousNodeInputs, userInputs); } /** @@ -179,11 +191,19 @@ public String type() { } /** - * Return this node's input data + * Return this node's user input data + * @return the inputs + */ + public Map userInputs() { + return userInputs; + } + + /** + * Return this node's predecessor inputs * @return the inputs */ - public Map inputs() { - return inputs; + public Map previousNodeInputs() { + return previousNodeInputs; } @Override diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java b/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java new file mode 100644 index 000000000..e49d7d68a --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +/** + * This represents the an object of workflow steps json which maps each step to expected inputs and outputs + */ +public class WorkflowStepValidator { + + /** Inputs field name */ + private static final String INPUTS_FIELD = "inputs"; + /** Outputs field name */ + private static final String OUTPUTS_FIELD = "outputs"; + + private List inputs; + private List outputs; + + /** + * Intantiate the object representing a Workflow Step validator + * @param inputs the workflow step inputs + * @param outputs the workflow step outputs + */ + public WorkflowStepValidator(List inputs, List outputs) { + this.inputs = inputs; + this.outputs = outputs; + } + + /** + * Parse raw json content into a WorkflowStepValidator instance + * @param parser json based content parser + * @return an instance of the WorkflowStepValidator + * @throws IOException if the content cannot be parsed correctly + */ + public static WorkflowStepValidator parse(XContentParser parser) throws IOException { + List parsedInputs = new ArrayList<>(); + List parsedOutputs = new ArrayList<>(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case INPUTS_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + parsedInputs.add(parser.text()); + } + break; + case OUTPUTS_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + parsedOutputs.add(parser.text()); + } + break; + default: + throw new IOException("Unable to parse field [" + fieldName + "] in a WorkflowStepValidator object."); + } + } + return new WorkflowStepValidator(parsedInputs, parsedOutputs); + } + + /** + * Get the required inputs + * @return the inputs + */ + public List getInputs() { + return List.copyOf(inputs); + } + + /** + * Get the required outputs + * @return the outputs + */ + public List getOutputs() { + return List.copyOf(outputs); + } +} diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java b/src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java new file mode 100644 index 000000000..506b73ab8 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import com.google.common.base.Charsets; +import com.google.common.io.Resources; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.util.ParseUtils; + +import java.io.IOException; +import java.net.URL; +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +/** + * This represents the workflow steps json which maps each step to expected inputs and outputs + */ +public class WorkflowValidator { + + private Map workflowStepValidators; + + /** + * Intantiate the object representing a Workflow validator + * @param workflowStepValidators a map of {@link WorkflowStepValidator} + */ + public WorkflowValidator(Map workflowStepValidators) { + this.workflowStepValidators = workflowStepValidators; + } + + /** + * Parse raw json content into a WorkflowValidator instance + * @param parser json based content parser + * @return an instance of the WorkflowValidator + * @throws IOException if the content cannot be parsed correctly + */ + public static WorkflowValidator parse(XContentParser parser) throws IOException { + + Map workflowStepValidators = new HashMap<>(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String type = parser.currentName(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + workflowStepValidators.put(type, WorkflowStepValidator.parse(parser)); + } + return new WorkflowValidator(workflowStepValidators); + } + + /** + * Parse a workflow step JSON file into a WorkflowValidator object + * + * @param file the file name of the workflow step json + * @return A {@link WorkflowValidator} represented by the JSON + * @throws IOException on failure to read and parse the json file + */ + public static WorkflowValidator parse(String file) throws IOException { + URL url = WorkflowValidator.class.getClassLoader().getResource(file); + String json = Resources.toString(url, Charsets.UTF_8); + return parse(ParseUtils.jsonToParser(json)); + } + + /** + * Get the map of WorkflowStepValidators + * @return the map of WorkflowStepValidators + */ + public Map getWorkflowStepValidators() { + return Map.copyOf(this.workflowStepValidators); + } + +} diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index f9a9e2dd9..22ac414e5 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -35,9 +35,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; -import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; @@ -109,6 +107,11 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow); + workflowProcessSorter.validateGraph(provisionProcessSequence); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( WORKFLOW_STATE_INDEX, workflowId, @@ -127,10 +130,16 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { - logger.error("Failed to retrieve template from global context.", exception); - listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + if (exception instanceof IllegalArgumentException) { + logger.error("Workflow validation failed for workflow : " + workflowId); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST)); + } else { + logger.error("Failed to retrieve template from global context.", exception); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + } })); } catch (Exception e) { logger.error("Failed to retrieve template from global context.", e); @@ -141,9 +150,9 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener workflowSequence) { // TODO : Update Action listener type to State index Request ActionListener provisionWorkflowListener = ActionListener.wrap(response -> { logger.info("Provisioning completed successuflly for workflow {}", workflowId); @@ -155,25 +164,22 @@ private void executeWorkflowAsync(String workflowId, Workflow workflow) { // TODO : Create State index request to update STATE entry status to FAILED }); try { - threadPool.executor(PROVISION_THREAD_POOL).execute(() -> { executeWorkflow(workflow, provisionWorkflowListener); }); + threadPool.executor(PROVISION_THREAD_POOL).execute(() -> { executeWorkflow(workflowSequence, provisionWorkflowListener); }); } catch (Exception exception) { provisionWorkflowListener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); } } /** - * Topologically sorts a given workflow into a sequence of ProcessNodes and executes the workflow - * @param workflow The workflow to execute + * Executes the given workflow sequence + * @param workflowSequence The topologically sorted workflow to execute * @param workflowListener The listener that updates the status of a workflow execution */ - private void executeWorkflow(Workflow workflow, ActionListener workflowListener) { + private void executeWorkflow(List workflowSequence, ActionListener workflowListener) { try { - // Attempt to topologically sort the workflow graph - List processSequence = workflowProcessSorter.sortProcessNodes(workflow); List> workflowFutureList = new ArrayList<>(); - - for (ProcessNode processNode : processSequence) { + for (ProcessNode processNode : workflowSequence) { List predecessors = processNode.predecessors(); logger.info( @@ -199,7 +205,7 @@ private void executeWorkflow(Workflow workflow, ActionListener workflowL } catch (IllegalArgumentException e) { workflowListener.onFailure(new FlowFrameworkException(e.getMessage(), RestStatus.BAD_REQUEST)); - } catch (CancellationException | CompletionException ex) { + } catch (Exception ex) { workflowListener.onFailure(new FlowFrameworkException(ex.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index 6ee28c82e..5fe47b2b0 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -25,6 +25,9 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; +import static org.opensearch.flowframework.common.CommonValue.INDEX_NAME; +import static org.opensearch.flowframework.common.CommonValue.TYPE; + /** * Step to create an index */ @@ -58,7 +61,7 @@ public CompletableFuture execute(List data) { @Override public void onResponse(CreateIndexResponse createIndexResponse) { logger.info("created index: {}", createIndexResponse.index()); - future.complete(new WorkflowData(Map.of("index-name", createIndexResponse.index()))); + future.complete(new WorkflowData(Map.of(INDEX_NAME, createIndexResponse.index()))); } @Override @@ -74,8 +77,8 @@ public void onFailure(Exception e) { for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); - index = (String) content.get("index-name"); - type = (String) content.get("type"); + index = (String) content.get(INDEX_NAME); + type = (String) content.get(TYPE); if (index != null && type != null && settings != null) { break; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java index 4770b94a9..b8cc83651 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -26,6 +26,16 @@ import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.FIELD_MAP; +import static org.opensearch.flowframework.common.CommonValue.ID; +import static org.opensearch.flowframework.common.CommonValue.INPUT_FIELD_NAME; +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; +import static org.opensearch.flowframework.common.CommonValue.OUTPUT_FIELD_NAME; +import static org.opensearch.flowframework.common.CommonValue.PIPELINE_ID; +import static org.opensearch.flowframework.common.CommonValue.PROCESSORS; +import static org.opensearch.flowframework.common.CommonValue.TYPE; + /** * Workflow step to create an ingest pipeline */ @@ -36,18 +46,6 @@ public class CreateIngestPipelineStep implements WorkflowStep { /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ static final String NAME = "create_ingest_pipeline"; - // Common pipeline configuration fields - private static final String PIPELINE_ID_FIELD = "id"; - private static final String DESCRIPTION_FIELD = "description"; - private static final String PROCESSORS_FIELD = "processors"; - private static final String TYPE_FIELD = "type"; - - // Temporary text embedding processor fields - private static final String FIELD_MAP = "field_map"; - private static final String MODEL_ID_FIELD = "model_id"; - private static final String INPUT_FIELD = "input_field_name"; - private static final String OUTPUT_FIELD = "output_field_name"; - // Client to store a pipeline in the cluster state private final ClusterAdminClient clusterAdminClient; @@ -80,23 +78,23 @@ public CompletableFuture execute(List data) { for (Entry entry : content.entrySet()) { switch (entry.getKey()) { - case PIPELINE_ID_FIELD: - pipelineId = (String) content.get(PIPELINE_ID_FIELD); + case ID: + pipelineId = (String) content.get(ID); break; case DESCRIPTION_FIELD: description = (String) content.get(DESCRIPTION_FIELD); break; - case TYPE_FIELD: - type = (String) content.get(TYPE_FIELD); + case TYPE: + type = (String) content.get(TYPE); break; - case MODEL_ID_FIELD: - modelId = (String) content.get(MODEL_ID_FIELD); + case MODEL_ID: + modelId = (String) content.get(MODEL_ID); break; - case INPUT_FIELD: - inputFieldName = (String) content.get(INPUT_FIELD); + case INPUT_FIELD_NAME: + inputFieldName = (String) content.get(INPUT_FIELD_NAME); break; - case OUTPUT_FIELD: - outputFieldName = (String) content.get(OUTPUT_FIELD); + case OUTPUT_FIELD_NAME: + outputFieldName = (String) content.get(OUTPUT_FIELD_NAME); break; default: break; @@ -127,7 +125,7 @@ public CompletableFuture execute(List data) { logger.info("Created ingest pipeline : " + putPipelineRequest.getId()); // PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead - createIngestPipelineFuture.complete(new WorkflowData(Map.of("pipelineId", putPipelineRequest.getId()))); + createIngestPipelineFuture.complete(new WorkflowData(Map.of(PIPELINE_ID, putPipelineRequest.getId()))); // TODO : Use node client to index response data to global context (pending global context index implementation) @@ -178,10 +176,10 @@ private XContentBuilder buildIngestPipelineRequestContent( return XContentFactory.jsonBuilder() .startObject() .field(DESCRIPTION_FIELD, description) - .startArray(PROCESSORS_FIELD) + .startArray(PROCESSORS) .startObject() .startObject(type) - .field(MODEL_ID_FIELD, modelId) + .field(MODEL_ID, modelId) .startObject(FIELD_MAP) .field(inputFieldName, outputFieldName) .endObject() diff --git a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java index a99e97caa..6e3a7bc6d 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java @@ -16,6 +16,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; @@ -30,6 +31,7 @@ public class ProcessNode { private final String id; private final WorkflowStep workflowStep; + private final Map previousNodeInputs; private final WorkflowData input; private final List predecessors; private final ThreadPool threadPool; @@ -42,6 +44,7 @@ public class ProcessNode { * * @param id A string identifying the workflow step * @param workflowStep A java class implementing {@link WorkflowStep} to be executed when it's this node's turn. + * @param previousNodeInputs A map of expected inputs coming from predecessor nodes used in graph validation * @param input Input required by the node encoded in a {@link WorkflowData} instance. * @param predecessors Nodes preceding this one in the workflow * @param threadPool The OpenSearch thread pool @@ -50,6 +53,7 @@ public class ProcessNode { public ProcessNode( String id, WorkflowStep workflowStep, + Map previousNodeInputs, WorkflowData input, List predecessors, ThreadPool threadPool, @@ -57,6 +61,7 @@ public ProcessNode( ) { this.id = id; this.workflowStep = workflowStep; + this.previousNodeInputs = previousNodeInputs; this.input = input; this.predecessors = predecessors; this.threadPool = threadPool; @@ -79,6 +84,14 @@ public WorkflowStep workflowStep() { return workflowStep; } + /** + * Returns the node's expected predecessor node input + * @return the expected predecessor node inputs + */ + public Map previousNodeInputs() { + return previousNodeInputs; + } + /** * Returns the input data for this node. * @return the input data diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 71c44514e..745de5921 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -14,10 +14,12 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.model.WorkflowValidator; import org.opensearch.threadpool.ThreadPool; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -27,10 +29,11 @@ import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; -import static org.opensearch.flowframework.model.WorkflowNode.INPUTS_FIELD; import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_DEFAULT_VALUE; import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_FIELD; +import static org.opensearch.flowframework.model.WorkflowNode.USER_INPUTS_FIELD; /** * Converts a workflow of nodes and edges into a topologically sorted list of Process Nodes. @@ -65,7 +68,7 @@ public List sortProcessNodes(Workflow workflow) { Map idToNodeMap = new HashMap<>(); for (WorkflowNode node : sortedNodes) { WorkflowStep step = workflowStepFactory.createStep(node.type()); - WorkflowData data = new WorkflowData(node.inputs(), workflow.userParams()); + WorkflowData data = new WorkflowData(node.userInputs(), workflow.userParams()); List predecessorNodes = workflow.edges() .stream() .filter(e -> e.destination().equals(node.id())) @@ -74,7 +77,15 @@ public List sortProcessNodes(Workflow workflow) { .collect(Collectors.toList()); TimeValue nodeTimeout = parseTimeout(node); - ProcessNode processNode = new ProcessNode(node.id(), step, data, predecessorNodes, threadPool, nodeTimeout); + ProcessNode processNode = new ProcessNode( + node.id(), + step, + node.previousNodeInputs(), + data, + predecessorNodes, + threadPool, + nodeTimeout + ); idToNodeMap.put(processNode.id(), processNode); nodes.add(processNode); } @@ -82,9 +93,53 @@ public List sortProcessNodes(Workflow workflow) { return nodes; } + /** + * Validates a sorted workflow, determines if each process node's user inputs and predecessor outputs match the expected workflow step inputs + * @param processNodes A list of process nodes + * @throws Exception on validation failure + */ + public void validateGraph(List processNodes) throws Exception { + + WorkflowValidator validator = WorkflowValidator.parse("mappings/workflow-steps.json"); + + // Iterate through process nodes in graph + for (ProcessNode processNode : processNodes) { + + // Get predecessor nodes types of this processNode + List predecessorNodeTypes = processNode.predecessors() + .stream() + .map(x -> x.workflowStep().getName()) + .collect(Collectors.toList()); + + // Compile a list of outputs from the predecessor nodes based on type + List predecessorOutputs = predecessorNodeTypes.stream() + .map(nodeType -> validator.getWorkflowStepValidators().get(nodeType).getOutputs()) + .flatMap(Collection::stream) + .collect(Collectors.toList()); + + // Retrieve all the user input data from this node + List currentNodeUserInputs = new ArrayList(processNode.input().getContent().keySet()); + + // Combine both predecessor outputs and current node user inputs + List allInputs = Stream.concat(predecessorOutputs.stream(), currentNodeUserInputs.stream()) + .collect(Collectors.toList()); + + // Retrieve list of required inputs from the current process node and compare + List expectedInputs = new ArrayList( + validator.getWorkflowStepValidators().get(processNode.workflowStep().getName()).getInputs() + ); + + if (!allInputs.containsAll(expectedInputs)) { + expectedInputs.removeAll(allInputs); + throw new IllegalArgumentException("Invalid graph, missing the following required inputs : " + expectedInputs.toString()); + } + } + + } + private TimeValue parseTimeout(WorkflowNode node) { - String timeoutValue = (String) node.inputs().getOrDefault(NODE_TIMEOUT_FIELD, NODE_TIMEOUT_DEFAULT_VALUE); - String fieldName = String.join(".", node.id(), INPUTS_FIELD, NODE_TIMEOUT_FIELD); + String timeoutValue = (String) node.userInputs().getOrDefault(NODE_TIMEOUT_FIELD, NODE_TIMEOUT_DEFAULT_VALUE); + String fieldName = String.join(".", node.id(), USER_INPUTS_FIELD, NODE_TIMEOUT_FIELD); TimeValue timeValue = TimeValue.parseTimeValue(timeoutValue, fieldName); if (timeValue.millis() < 0) { throw new IllegalArgumentException( diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 719e4ba10..c30bdf87c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -57,4 +57,12 @@ public WorkflowStep createStep(String type) { } throw new FlowFrameworkException("Workflow step type [" + type + "] is not implemented.", RestStatus.NOT_IMPLEMENTED); } + + /** + * Gets the step map + * @return the step map + */ + public Map getStepMap() { + return Map.copyOf(this.stepMap); + } } diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json new file mode 100644 index 000000000..23eb81c00 --- /dev/null +++ b/src/main/resources/mappings/workflow-steps.json @@ -0,0 +1,71 @@ +{ + "noop": { + "inputs":[], + "outputs":[] + }, + "create_index": { + "inputs":[ + "index_name", + "type" + ], + "outputs":[ + "index_name" + ] + }, + "create_ingest_pipeline": { + "inputs":[ + "id", + "description", + "type", + "model_id", + "input_field_name", + "output_field_name" + ], + "outputs":[ + "pipeline_id" + ] + }, + "create_connector": { + "inputs":[ + "name", + "description", + "version", + "protocol", + "parameters", + "credentials", + "actions" + ], + "outputs":[ + "connector_id" + ] + }, + "register_model": { + "inputs":[ + "function_name", + "name", + "description", + "connector_id" + ], + "outputs":[ + "model_id", + "register_model_status" + ] + }, + "deploy_model": { + "inputs":[ + "model_id" + ], + "outputs":[ + "deploy_model_status" + ] + }, + "model_group": { + "inputs":[ + "name" + ], + "outputs":[ + "model_group_id", + "model_group_status" + ] + } +} diff --git a/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java b/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java index 26e22af9a..ca5ee7a92 100644 --- a/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java +++ b/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java @@ -45,7 +45,7 @@ public static String nodeWithTypeAndTimeout(String id, String type, String timeo + "\": \"" + type + "\", \"" - + WorkflowNode.INPUTS_FIELD + + WorkflowNode.USER_INPUTS_FIELD + "\": {\"" + WorkflowNode.NODE_TIMEOUT_FIELD + "\": \"" diff --git a/src/test/java/org/opensearch/flowframework/model/TemplateTests.java b/src/test/java/org/opensearch/flowframework/model/TemplateTests.java index 89cffaac5..9587109c0 100644 --- a/src/test/java/org/opensearch/flowframework/model/TemplateTests.java +++ b/src/test/java/org/opensearch/flowframework/model/TemplateTests.java @@ -19,7 +19,7 @@ public class TemplateTests extends OpenSearchTestCase { private String expectedTemplate = "{\"name\":\"test\",\"description\":\"a test template\",\"use_case\":\"test use case\",\"version\":{\"template\":\"1.2.3\",\"compatibility\":[\"4.5.6\",\"7.8.9\"]}," - + "\"workflows\":{\"workflow\":{\"user_params\":{\"key\":\"value\"},\"nodes\":[{\"id\":\"A\",\"type\":\"a-type\",\"inputs\":{\"foo\":\"bar\"}},{\"id\":\"B\",\"type\":\"b-type\",\"inputs\":{\"baz\":\"qux\"}}],\"edges\":[{\"source\":\"A\",\"dest\":\"B\"}]}}}"; + + "\"workflows\":{\"workflow\":{\"user_params\":{\"key\":\"value\"},\"nodes\":[{\"id\":\"A\",\"type\":\"a-type\",\"user_inputs\":{\"foo\":\"bar\"}},{\"id\":\"B\",\"type\":\"b-type\",\"user_inputs\":{\"baz\":\"qux\"}}],\"edges\":[{\"source\":\"A\",\"dest\":\"B\"}]}}}"; @Override public void setUp() throws Exception { @@ -29,8 +29,8 @@ public void setUp() throws Exception { public void testTemplate() throws IOException { Version templateVersion = Version.fromString("1.2.3"); List compatibilityVersion = List.of(Version.fromString("4.5.6"), Version.fromString("7.8.9")); - WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); - WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of(), Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of(), Map.of("baz", "qux")); WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); List nodes = List.of(nodeA, nodeB); List edges = List.of(edgeAB); diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java index 46d897b42..700e1d0d2 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java @@ -24,6 +24,7 @@ public void testNode() throws IOException { WorkflowNode nodeA = new WorkflowNode( "A", "a-type", + Map.of("foo", "field"), Map.ofEntries( Map.entry("foo", "a string"), Map.entry("bar", Map.of("key", "value")), @@ -33,7 +34,8 @@ public void testNode() throws IOException { ); assertEquals("A", nodeA.id()); assertEquals("a-type", nodeA.type()); - Map map = nodeA.inputs(); + assertEquals(Map.of("foo", "field"), nodeA.previousNodeInputs()); + Map map = nodeA.userInputs(); assertEquals("a string", (String) map.get("foo")); assertEquals(Map.of("key", "value"), (Map) map.get("bar")); assertArrayEquals(new Map[] { Map.of("A", "a"), Map.of("B", "b") }, (Map[]) map.get("baz")); @@ -43,14 +45,16 @@ public void testNode() throws IOException { assertEquals(Map.of("key2", "value2"), pp[0].params()); // node equality is based only on ID - WorkflowNode nodeA2 = new WorkflowNode("A", "a2-type", Map.of("bar", "baz")); + WorkflowNode nodeA2 = new WorkflowNode("A", "a2-type", Map.of(), Map.of("bar", "baz")); assertEquals(nodeA, nodeA2); - WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("A", "foo"), Map.of("baz", "qux")); assertNotEquals(nodeA, nodeB); String json = TemplateTestJsonUtil.parseToJson(nodeA); - assertTrue(json.startsWith("{\"id\":\"A\",\"type\":\"a-type\",\"inputs\":")); + logger.info("TESTING : " + json); + assertTrue(json.startsWith("{\"id\":\"A\",\"type\":\"a-type\",\"previous_node_inputs\":{\"foo\":\"field\"},")); + assertTrue(json.contains("\"user_inputs\":{")); assertTrue(json.contains("\"foo\":\"a string\"")); assertTrue(json.contains("\"baz\":[{\"A\":\"a\"},{\"B\":\"b\"}]")); assertTrue(json.contains("\"bar\":{\"key\":\"value\"}")); @@ -59,7 +63,9 @@ public void testNode() throws IOException { WorkflowNode nodeX = WorkflowNode.parse(TemplateTestJsonUtil.jsonToParser(json)); assertEquals("A", nodeX.id()); assertEquals("a-type", nodeX.type()); - Map mapX = nodeX.inputs(); + Map previousNodeInputs = nodeX.previousNodeInputs(); + assertEquals("field", previousNodeInputs.get("foo")); + Map mapX = nodeX.userInputs(); assertEquals("a string", mapX.get("foo")); assertEquals(Map.of("key", "value"), mapX.get("bar")); assertArrayEquals(new Map[] { Map.of("A", "a"), Map.of("B", "b") }, (Map[]) map.get("baz")); @@ -70,11 +76,11 @@ public void testNode() throws IOException { } public void testExceptions() throws IOException { - String badJson = "{\"badField\":\"A\",\"type\":\"a-type\",\"inputs\":{\"foo\":\"bar\"}}"; + String badJson = "{\"badField\":\"A\",\"type\":\"a-type\",\"user_inputs\":{\"foo\":\"bar\"}}"; IOException e = assertThrows(IOException.class, () -> WorkflowNode.parse(TemplateTestJsonUtil.jsonToParser(badJson))); assertEquals("Unable to parse field [badField] in a node object.", e.getMessage()); - String missingJson = "{\"id\":\"A\",\"inputs\":{\"foo\":\"bar\"}}"; + String missingJson = "{\"id\":\"A\",\"user_inputs\":{\"foo\":\"bar\"}}"; e = assertThrows(IOException.class, () -> WorkflowNode.parse(TemplateTestJsonUtil.jsonToParser(missingJson))); assertEquals("An node object requires both an id and type field.", e.getMessage()); } diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowStepValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowStepValidatorTests.java new file mode 100644 index 000000000..646e8f8af --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowStepValidatorTests.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +public class WorkflowStepValidatorTests extends OpenSearchTestCase { + + private String validValidator; + private String invalidValidator; + + @Override + public void setUp() throws Exception { + super.setUp(); + validValidator = "{\"inputs\":[\"input_value\"],\"outputs\":[\"output_value\"]}"; + invalidValidator = "{\"inputs\":[\"input_value\"],\"invalid_field\":[\"output_value\"]}"; + } + + public void testParseWorkflowStepValidator() throws IOException { + XContentParser parser = TemplateTestJsonUtil.jsonToParser(validValidator); + WorkflowStepValidator workflowStepValidator = WorkflowStepValidator.parse(parser); + + assertEquals(1, workflowStepValidator.getInputs().size()); + assertEquals(1, workflowStepValidator.getOutputs().size()); + + assertEquals("input_value", workflowStepValidator.getInputs().get(0)); + assertEquals("output_value", workflowStepValidator.getOutputs().get(0)); + } + + public void testFailedParseWorkflowStepValidator() throws IOException { + XContentParser parser = TemplateTestJsonUtil.jsonToParser(invalidValidator); + IOException ex = expectThrows(IOException.class, () -> WorkflowStepValidator.parse(parser)); + assertEquals("Unable to parse field [invalid_field] in a WorkflowStepValidator object.", ex.getMessage()); + + } + +} diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowTests.java index db070da4b..03b57aaac 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowTests.java @@ -23,8 +23,8 @@ public void setUp() throws Exception { } public void testWorkflow() throws IOException { - WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); - WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of(), Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("A", "foo"), Map.of("baz", "qux")); WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); List nodes = List.of(nodeA, nodeB); List edges = List.of(edgeAB); @@ -35,8 +35,8 @@ public void testWorkflow() throws IOException { assertEquals(List.of(edgeAB), workflow.edges()); String expectedJson = "{\"user_params\":{\"key\":\"value\"}," - + "\"nodes\":[{\"id\":\"A\",\"type\":\"a-type\",\"inputs\":{\"foo\":\"bar\"}}," - + "{\"id\":\"B\",\"type\":\"b-type\",\"inputs\":{\"baz\":\"qux\"}}]," + + "\"nodes\":[{\"id\":\"A\",\"type\":\"a-type\",\"previous_node_inputs\":{},\"user_inputs\":{\"foo\":\"bar\"}}," + + "{\"id\":\"B\",\"type\":\"b-type\",\"previous_node_inputs\":{\"A\":\"foo\"},\"user_inputs\":{\"baz\":\"qux\"}}]," + "\"edges\":[{\"source\":\"A\",\"dest\":\"B\"}]}"; String json = TemplateTestJsonUtil.parseToJson(workflow); assertEquals(expectedJson, json); diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java new file mode 100644 index 000000000..6c474a11e --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class WorkflowValidatorTests extends OpenSearchTestCase { + + private String validWorkflowStepJson; + private String invalidWorkflowStepJson; + + @Override + public void setUp() throws Exception { + super.setUp(); + validWorkflowStepJson = + "{\"workflow_step_1\":{\"inputs\":[\"input_1\",\"input_2\"],\"outputs\":[\"output_1\"]},\"workflow_step_2\":{\"inputs\":[\"input_1\",\"input_2\",\"input_3\"],\"outputs\":[\"output_1\",\"output_2\",\"output_3\"]}}"; + invalidWorkflowStepJson = + "{\"workflow_step_1\":{\"bad_field\":[\"input_1\",\"input_2\"],\"outputs\":[\"output_1\"]},\"workflow_step_2\":{\"inputs\":[\"input_1\",\"input_2\",\"input_3\"],\"outputs\":[\"output_1\",\"output_2\",\"output_3\"]}}"; + } + + public void testParseWorkflowValidator() throws IOException { + + XContentParser parser = TemplateTestJsonUtil.jsonToParser(validWorkflowStepJson); + WorkflowValidator validator = WorkflowValidator.parse(parser); + + assertEquals(2, validator.getWorkflowStepValidators().size()); + assertTrue(validator.getWorkflowStepValidators().keySet().contains("workflow_step_1")); + assertEquals(2, validator.getWorkflowStepValidators().get("workflow_step_1").getInputs().size()); + assertEquals(1, validator.getWorkflowStepValidators().get("workflow_step_1").getOutputs().size()); + assertTrue(validator.getWorkflowStepValidators().keySet().contains("workflow_step_2")); + assertEquals(3, validator.getWorkflowStepValidators().get("workflow_step_2").getInputs().size()); + assertEquals(3, validator.getWorkflowStepValidators().get("workflow_step_2").getOutputs().size()); + } + + public void testFailedParseWorkflowValidator() throws IOException { + XContentParser parser = TemplateTestJsonUtil.jsonToParser(invalidWorkflowStepJson); + IOException ex = expectThrows(IOException.class, () -> WorkflowValidator.parse(parser)); + assertEquals("Unable to parse field [bad_field] in a WorkflowStepValidator object.", ex.getMessage()); + } + + public void testWorkflowStepFactoryHasValidators() throws IOException { + + ClusterService clusterService = mock(ClusterService.class); + ClusterAdminClient clusterAdminClient = mock(ClusterAdminClient.class); + AdminClient adminClient = mock(AdminClient.class); + Client client = mock(Client.class); + when(client.admin()).thenReturn(adminClient); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); + + WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient); + + // Read in workflow-steps.json + WorkflowValidator workflowValidator = WorkflowValidator.parse("mappings/workflow-steps.json"); + + // Get all workflow step validator types + List registeredWorkflowValidatorTypes = new ArrayList(workflowValidator.getWorkflowStepValidators().keySet()); + + // Get all registered workflow step types in the workflow step factory + List registeredWorkflowStepTypes = new ArrayList(workflowStepFactory.getStepMap().keySet()); + + // Check if each registered step has a corresponding validator definition + assertTrue(registeredWorkflowStepTypes.containsAll(registeredWorkflowValidatorTypes)); + assertTrue(registeredWorkflowValidatorTypes.containsAll(registeredWorkflowStepTypes)); + } + +} diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index ba4f0093c..d897c6756 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -44,8 +44,8 @@ public void setUp() throws Exception { Version templateVersion = Version.fromString("1.0.0"); List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); - WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); - WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of(), Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of(), Map.of("baz", "qux")); WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); List nodes = List.of(nodeA, nodeB); List edges = List.of(edgeAB); diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 9720453f4..b6f7bea2d 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -70,8 +70,8 @@ public void setUp() throws Exception { Version templateVersion = Version.fromString("1.0.0"); List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); - WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); - WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of(), Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of(), Map.of("baz", "qux")); WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); List nodes = List.of(nodeA, nodeB); List edges = List.of(edgeAB); diff --git a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java index d48932a57..d3f6fb6fd 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java @@ -73,8 +73,8 @@ public void setUp() throws Exception { Version templateVersion = Version.fromString("1.0.0"); List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); - WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); - WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of(), Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of(), Map.of("baz", "qux")); WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); List nodes = List.of(nodeA, nodeB); List edges = List.of(edgeAB); diff --git a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java index 7f5a3918a..d64249e27 100644 --- a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java @@ -36,8 +36,8 @@ public void setUp() throws Exception { super.setUp(); Version templateVersion = Version.fromString("1.0.0"); List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); - WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); - WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of(), Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of(), Map.of("baz", "qux")); WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); List nodes = List.of(nodeA, nodeB); List edges = List.of(edgeAB); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index ab5dd476a..7a4db70a6 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -69,7 +69,7 @@ public class CreateIndexStepTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); - inputData = new WorkflowData(Map.ofEntries(Map.entry("index-name", "demo"), Map.entry("type", "knn"))); + inputData = new WorkflowData(Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("type", "knn"))); clusterService = mock(ClusterService.class); client = mock(Client.class); adminClient = mock(AdminClient.class); @@ -98,7 +98,7 @@ public void testCreateIndexStep() throws ExecutionException, InterruptedExceptio assertTrue(future.isDone() && !future.isCompletedExceptionally()); - Map outputData = Map.of("index-name", "demo"); + Map outputData = Map.of("index_name", "demo"); assertEquals(outputData, future.get().getContent()); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java index 9dab2a8d7..039b0384f 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java @@ -54,7 +54,7 @@ public void setUp() throws Exception { ); // Set output data to returned pipelineId - outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipelineId", "pipelineId"))); + outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipeline_id", "pipelineId"))); client = mock(Client.class); adminClient = mock(AdminClient.class); @@ -109,7 +109,7 @@ public void testMissingData() throws InterruptedException { // Data with missing input and output fields WorkflowData incorrectData = new WorkflowData( Map.ofEntries( - Map.entry("id", "pipelineId"), + Map.entry("id", "pipeline_id"), Map.entry("description", "some description"), Map.entry("type", "text_embedding"), Map.entry("model_id", "model_id") diff --git a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java index 1e421c58c..0cac95b49 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java @@ -67,6 +67,7 @@ public String getName() { return "test"; } }, + Map.of(), new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar")), List.of(successfulNode), testThreadPool, @@ -103,7 +104,7 @@ public CompletableFuture execute(List data) { public String getName() { return "test"; } - }, WorkflowData.EMPTY, Collections.emptyList(), testThreadPool, TimeValue.timeValueMillis(250)); + }, Map.of(), WorkflowData.EMPTY, Collections.emptyList(), testThreadPool, TimeValue.timeValueMillis(250)); assertEquals("B", nodeB.id()); assertEquals("test", nodeB.workflowStep().getName()); assertEquals(WorkflowData.EMPTY, nodeB.input()); @@ -129,7 +130,7 @@ public CompletableFuture execute(List data) { public String getName() { return "sleepy"; } - }, WorkflowData.EMPTY, Collections.emptyList(), testThreadPool, TimeValue.timeValueMillis(100)); + }, Map.of(), WorkflowData.EMPTY, Collections.emptyList(), testThreadPool, TimeValue.timeValueMillis(100)); assertEquals("Zzz", nodeZ.id()); assertEquals("sleepy", nodeZ.workflowStep().getName()); assertEquals(WorkflowData.EMPTY, nodeZ.input()); @@ -156,7 +157,7 @@ public CompletableFuture execute(List data) { public String getName() { return "test"; } - }, WorkflowData.EMPTY, List.of(successfulNode, failedNode), testThreadPool, TimeValue.timeValueSeconds(15)); + }, Map.of(), WorkflowData.EMPTY, List.of(successfulNode, failedNode), testThreadPool, TimeValue.timeValueSeconds(15)); assertEquals("E", nodeE.id()); assertEquals("test", nodeE.workflowStep().getName()); assertEquals(WorkflowData.EMPTY, nodeE.input()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index 423db8cf0..9f629ff9e 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -16,6 +16,8 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.model.TemplateTestJsonUtil; import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowEdge; +import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; @@ -26,6 +28,7 @@ import java.io.IOException; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -208,4 +211,68 @@ public void testExceptions() throws IOException { assertEquals("Workflow step type [unimplemented_step] is not implemented.", ex.getMessage()); assertEquals(RestStatus.NOT_IMPLEMENTED, ((FlowFrameworkException) ex).getRestStatus()); } + + public void testSuccessfulGraphValidation() throws Exception { + WorkflowNode createConnector = new WorkflowNode( + "workflow_step_1", + CreateConnectorStep.NAME, + Map.of(), + Map.ofEntries( + Map.entry("name", ""), + Map.entry("description", ""), + Map.entry("version", ""), + Map.entry("protocol", ""), + Map.entry("parameters", ""), + Map.entry("credentials", ""), + Map.entry("actions", "") + ) + ); + WorkflowNode registerModel = new WorkflowNode( + "workflow_step_2", + RegisterModelStep.NAME, + Map.ofEntries(Map.entry("workflow_step_1", "connector_id")), + Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) + ); + WorkflowNode deployModel = new WorkflowNode( + "workflow_step_3", + DeployModelStep.NAME, + Map.ofEntries(Map.entry("workflow_step_2", "model_id")), + Map.of() + ); + + WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id()); + WorkflowEdge edge2 = new WorkflowEdge(registerModel.id(), deployModel.id()); + + Workflow workflow = new Workflow(Map.of(), List.of(createConnector, registerModel, deployModel), List.of(edge1, edge2)); + + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow); + workflowProcessSorter.validateGraph(sortedProcessNodes); + } + + public void testFailedGraphValidation() { + + // Create Register Model workflow node with missing connector_id field + WorkflowNode registerModel = new WorkflowNode( + "workflow_step_1", + RegisterModelStep.NAME, + Map.of(), + Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) + ); + WorkflowNode deployModel = new WorkflowNode( + "workflow_step_2", + DeployModelStep.NAME, + Map.ofEntries(Map.entry("workflow_step_1", "model_id")), + Map.of() + ); + WorkflowEdge edge = new WorkflowEdge(registerModel.id(), deployModel.id()); + Workflow workflow = new Workflow(Map.of(), List.of(registerModel, deployModel), List.of(edge)); + + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow); + IllegalArgumentException ex = expectThrows( + IllegalArgumentException.class, + () -> workflowProcessSorter.validateGraph(sortedProcessNodes) + ); + assertEquals("Invalid graph, missing the following required inputs : [connector_id]", ex.getMessage()); + + } }