From 78e54308fbaf00af3190695e9edea3f2b3d39ca3 Mon Sep 17 00:00:00 2001 From: Amit Galitzky Date: Tue, 12 Mar 2024 11:40:43 -0700 Subject: [PATCH] Adding create ingest pipeline step (#558) * adding create ingest pipeline step Signed-off-by: Amit Galitzky * adding IT and move configurations parsing to input parsing Signed-off-by: Amit Galitzky * cleaning up comments Signed-off-by: Amit Galitzky --------- Signed-off-by: Amit Galitzky --- CHANGELOG.md | 2 + build.gradle | 2 + .../flowframework/FlowFrameworkPlugin.java | 3 +- .../flowframework/common/CommonValue.java | 4 + .../flowframework/model/WorkflowNode.java | 20 +- .../flowframework/util/ParseUtils.java | 41 +++- .../workflow/CreateConnectorStep.java | 2 + .../workflow/CreateIngestPipelineStep.java | 181 ++++-------------- .../workflow/WorkflowStepFactory.java | 19 +- .../FlowFrameworkRestTestCase.java | 24 +++ .../model/WorkflowValidatorTests.java | 7 +- .../rest/FlowFrameworkRestApiIT.java | 47 +++++ .../rest/RestGetWorkflowStepActionTests.java | 12 +- .../CreateIngestPipelineStepTests.java | 21 +- .../workflow/WorkflowProcessSorterTests.java | 8 +- .../template/ingest-pipeline-template.json | 88 +++++++++ 16 files changed, 313 insertions(+), 168 deletions(-) create mode 100644 src/test/resources/template/ingest-pipeline-template.json diff --git a/CHANGELOG.md b/CHANGELOG.md index c0fa14f60..66a7e67b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.12...2.x) ### Features +- Adding create ingest pipeline step ([#558](https://github.com/opensearch-project/flow-framework/pull/558)) + ### Enhancements - Substitute REST path or body parameters in Workflow Steps ([#525](https://github.com/opensearch-project/flow-framework/pull/525)) - Added an optional workflow_step param to the get workflow steps API ([#538](https://github.com/opensearch-project/flow-framework/pull/538)) diff --git a/build.gradle b/build.gradle index 74b450d5f..a04e5bc1b 100644 --- a/build.gradle +++ b/build.gradle @@ -174,6 +174,8 @@ dependencies { implementation "org.opensearch:common-utils:${common_utils_version}" implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1' implementation 'org.bouncycastle:bcprov-jdk18on:1.77' + implementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}") + implementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}") // ZipArchive dependencies used for integration tests zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 49f3bcce9..4c8486b7e 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -123,7 +123,8 @@ public Collection createComponents( threadPool, mlClient, flowFrameworkIndicesHandler, - flowFrameworkSettings + flowFrameworkSettings, + client ); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter( workflowStepFactory, diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 5acc34b36..bde91b55d 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -162,6 +162,10 @@ private CommonValue() {} public static final String APP_TYPE_FIELD = "app_type"; /** To include field for an agent response */ public static final String INCLUDE_OUTPUT_IN_AGENT_RESPONSE = "include_output_in_agent_response"; + /** Pipeline ID, also corresponds to pipeline name */ + public static final String PIPELINE_ID = "pipeline_id"; + /** Pipeline Configurations */ + public static final String CONFIGURATIONS = "configurations"; /* * Constants associated with resource provisioning / state diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index 55122f9e5..15d52ccd1 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -8,12 +8,15 @@ */ package org.opensearch.flowframework.model; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.flowframework.workflow.WorkflowStep; @@ -28,6 +31,7 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD; import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap; import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap; @@ -60,6 +64,7 @@ public class WorkflowNode implements ToXContentObject { private final String type; // maps to a WorkflowStep private final Map previousNodeInputs; private final Map userInputs; // maps to WorkflowData + private static final Logger logger = LogManager.getLogger(WorkflowNode.class); /** * Create this node with the id and type, and any user input. @@ -151,7 +156,20 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { userInputs.put(inputFieldName, parser.text()); break; case START_OBJECT: - userInputs.put(inputFieldName, parseStringToStringMap(parser)); + if (CONFIGURATIONS.equals(inputFieldName)) { + Map configurationsMap = parser.map(); + try { + String configurationsString = ParseUtils.parseArbitraryStringToObjectMapToString(configurationsMap); + userInputs.put(inputFieldName, configurationsString); + } catch (Exception ex) { + String errorMessage = "Failed to parse configuration map"; + logger.error(errorMessage, ex); + throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); + } + break; + } else { + userInputs.put(inputFieldName, parseStringToStringMap(parser)); + } break; case START_ARRAY: if (PROCESSORS_FIELD.equals(inputFieldName)) { diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index 6192d8e6d..140f0a4af 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -8,6 +8,9 @@ */ package org.opensearch.flowframework.util; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; @@ -53,7 +56,8 @@ public class ParseUtils { private static final Logger logger = LogManager.getLogger(ParseUtils.class); // Matches ${{ foo.bar }} (whitespace optional) with capturing groups 1=foo, 2=bar - private static final Pattern SUBSTITUTION_PATTERN = Pattern.compile("\\$\\{\\{\\s*(.+)\\.(.+?)\\s*\\}\\}"); + // private static final Pattern SUBSTITUTION_PATTERN = Pattern.compile("\\$\\{\\{\\s*(.+)\\.(.+?)\\s*\\}\\}"); + private static final Pattern SUBSTITUTION_PATTERN = Pattern.compile("\\$\\{\\{\\s*([\\w_]+)\\.([\\w_]+)\\s*\\}\\}"); private ParseUtils() {} @@ -341,13 +345,25 @@ public static Map getInputsFromPreviousSteps( private static Object conditionallySubstitute(Object value, Map outputs, Map params) { if (value instanceof String) { Matcher m = SUBSTITUTION_PATTERN.matcher((String) value); - if (m.matches()) { - // Try matching a previous step+value pair - WorkflowData data = outputs.get(m.group(1)); - if (data != null && data.getContent().containsKey(m.group(2))) { - return data.getContent().get(m.group(2)); + StringBuilder result = new StringBuilder(); + while (m.find()) { + // outputs content map contains values for previous node input (e.g: deploy_openai_model.model_id) + // Check first if the substitution is looking for the same key, value pair and if yes + // then replace it with the key value pair in the inputs map + String replacement = m.group(0); + if (outputs.containsKey(m.group(1)) && outputs.get(m.group(1)).getContent().containsKey(m.group(2))) { + // Extract the key for the inputs (e.g., "model_id" from ${{deploy_openai_model.model_id}}) + String key = m.group(2); + if (outputs.get(m.group(1)).getContent().get(key) instanceof String) { + replacement = (String) outputs.get(m.group(1)).getContent().get(key); + // Replace the whole sequence with the value from the map + m.appendReplacement(result, Matcher.quoteReplacement(replacement)); + } } } + m.appendTail(result); + value = result.toString(); + // Replace all params if present for (Entry e : params.entrySet()) { String regex = "\\$\\{\\{\\s*" + Pattern.quote(e.getKey()) + "\\s*\\}\\}"; @@ -356,4 +372,17 @@ private static Object conditionallySubstitute(Object value, Map map) throws JsonProcessingException { + ObjectMapper mapper = new ObjectMapper(); + // Convert the map to a JSON string + String mappedString = mapper.writeValueAsString(map); + return mappedString; + } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index a6b0b6a40..403e26063 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -161,8 +161,10 @@ public void onFailure(Exception e) { credentials = getStringToStringMap(inputs.get(CREDENTIAL_FIELD), CREDENTIAL_FIELD); actions = getConnectorActionList(inputs.get(ACTIONS_FIELD)); } catch (IllegalArgumentException iae) { + logger.error("IllegalArgumentException in connector configuration", iae); throw new FlowFrameworkException("IllegalArgumentException in connector configuration", RestStatus.BAD_REQUEST); } catch (PrivilegedActionException pae) { + logger.error("PrivilegedActionException in connector configuration", pae); throw new FlowFrameworkException("PrivilegedActionException in connector configuration", RestStatus.UNAUTHORIZED); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java index 1ddfa65f6..9d840573c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -15,37 +15,27 @@ import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.util.ParseUtils; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; +import java.nio.charset.StandardCharsets; import java.util.Map; -import java.util.Map.Entry; -import java.util.stream.Stream; - -import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; -import static org.opensearch.flowframework.common.CommonValue.FIELD_MAP; -import static org.opensearch.flowframework.common.CommonValue.ID; -import static org.opensearch.flowframework.common.CommonValue.INPUT_FIELD_NAME; -import static org.opensearch.flowframework.common.CommonValue.OUTPUT_FIELD_NAME; -import static org.opensearch.flowframework.common.CommonValue.PROCESSORS; -import static org.opensearch.flowframework.common.CommonValue.TYPE; +import java.util.Set; + +import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; +import static org.opensearch.flowframework.common.WorkflowResources.PIPELINE_ID; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; /** - * Workflow step to create an ingest pipeline + * Step to create an ingest pipeline */ public class CreateIngestPipelineStep implements WorkflowStep { - private static final Logger logger = LogManager.getLogger(CreateIngestPipelineStep.class); /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ @@ -77,93 +67,44 @@ public PlainActionFuture execute( PlainActionFuture createIngestPipelineFuture = PlainActionFuture.newFuture(); - String pipelineId = null; - String description = null; - String type = null; - String modelId = null; - String inputFieldName = null; - String outputFieldName = null; - BytesReference configuration = null; - - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); - - // Extract required content from workflow data and generate the ingest pipeline configuration - for (WorkflowData workflowData : data) { - - Map content = workflowData.getContent(); - - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case ID: - pipelineId = (String) content.get(ID); - break; - case DESCRIPTION_FIELD: - description = (String) content.get(DESCRIPTION_FIELD); - break; - case TYPE: - type = (String) content.get(TYPE); - break; - case MODEL_ID: - modelId = (String) content.get(MODEL_ID); - break; - case INPUT_FIELD_NAME: - inputFieldName = (String) content.get(INPUT_FIELD_NAME); - break; - case OUTPUT_FIELD_NAME: - outputFieldName = (String) content.get(OUTPUT_FIELD_NAME); - break; - default: - break; - } - } + Set requiredKeys = Set.of(PIPELINE_ID, CONFIGURATIONS); - // Determine if fields have been populated, else iterate over remaining workflow data - if (Stream.of(pipelineId, description, modelId, type, inputFieldName, outputFieldName).allMatch(x -> x != null)) { - try { - configuration = BytesReference.bytes( - buildIngestPipelineRequestContent(description, modelId, type, inputFieldName, outputFieldName) - ); - } catch (IOException e) { - String errorMessage = "Failed to create ingest pipeline configuration for " + currentNodeId; - logger.error(errorMessage, e); - createIngestPipelineFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); - } - break; - } - } + // currently, we are supporting an optional param of model ID into the various processors + Set optionalKeys = Set.of(MODEL_ID); - if (configuration == null) { - // Required workflow data not found - createIngestPipelineFuture.onFailure( - new FlowFrameworkException( - "Failed to create ingest pipeline for " + currentNodeId + ", required inputs not found", - RestStatus.BAD_REQUEST - ) + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs, + params ); - } else { - // Create PutPipelineRequest and execute - PutPipelineRequest putPipelineRequest = new PutPipelineRequest(pipelineId, configuration, XContentType.JSON); - clusterAdminClient.putPipeline(putPipelineRequest, ActionListener.wrap(response -> { - logger.info("Created ingest pipeline : " + putPipelineRequest.getId()); + String pipelineId = (String) inputs.get(PIPELINE_ID); + String configurations = (String) inputs.get(CONFIGURATIONS); + + byte[] byteArr = configurations.getBytes(StandardCharsets.UTF_8); + BytesReference configurationsBytes = new BytesArray(byteArr); + + // Create PutPipelineRequest and execute + PutPipelineRequest putPipelineRequest = new PutPipelineRequest(pipelineId, configurationsBytes, XContentType.JSON); + clusterAdminClient.putPipeline(putPipelineRequest, ActionListener.wrap(acknowledgedResponse -> { + String resourceName = getResourceByWorkflowStep(getName()); try { - String resourceName = getResourceByWorkflowStep(getName()); flowFrameworkIndicesHandler.updateResourceInStateIndex( currentNodeInputs.getWorkflowId(), currentNodeId, getName(), - putPipelineRequest.getId(), + pipelineId, ActionListener.wrap(updateResponse -> { logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); - // PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead + // PutPipelineRequest returns only an AcknowledgeResponse, saving pipelineId instead // TODO: revisit this concept of pipeline_id to be consistent with what makes most sense to end user here createIngestPipelineFuture.onResponse( new WorkflowData( - Map.of(resourceName, putPipelineRequest.getId()), + Map.of(resourceName, pipelineId), currentNodeInputs.getWorkflowId(), currentNodeInputs.getNodeId() ) @@ -174,7 +115,7 @@ public PlainActionFuture execute( + " resource " + getName() + " id " - + putPipelineRequest.getId(); + + pipelineId; logger.error(errorMessage, exception); createIngestPipelineFuture.onFailure( new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)) @@ -187,12 +128,14 @@ public PlainActionFuture execute( logger.error(errorMessage, e); createIngestPipelineFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); } - - }, exception -> { + }, e -> { String errorMessage = "Failed to create ingest pipeline"; - logger.error(errorMessage, exception); - createIngestPipelineFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + logger.error(errorMessage, e); + createIngestPipelineFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); })); + + } catch (FlowFrameworkException e) { + createIngestPipelineFuture.onFailure(e); } return createIngestPipelineFuture; @@ -202,50 +145,4 @@ public PlainActionFuture execute( public String getName() { return NAME; } - - /** - * Temporary, generates the ingest pipeline request content for text_embedding processor from workflow data - * { - * "description" : "", - * "processors" : [ - * { - * "" : { - * "model_id" : "", - * "field_map" : { - * "" : "" - * } - * } - * ] - * } - * - * @param description The description of the ingest pipeline configuration - * @param modelId The ID of the model that will be used in the embedding interface - * @param type The processor type - * @param inputFieldName The field name used to cache text for text embeddings - * @param outputFieldName The field name in which output text is stored - * @throws IOException if the request content fails to be generated - * @return the xcontent builder with the formatted ingest pipeline configuration - */ - private XContentBuilder buildIngestPipelineRequestContent( - String description, - String modelId, - String type, - String inputFieldName, - String outputFieldName - ) throws IOException { - return XContentFactory.jsonBuilder() - .startObject() - .field(DESCRIPTION_FIELD, description) - .startArray(PROCESSORS) - .startObject() - .startObject(type) - .field(MODEL_ID, modelId) - .startObject(FIELD_MAP) - .field(inputFieldName, outputFieldName) - .endObject() - .endObject() - .endObject() - .endArray() - .endObject(); - } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index df7b3aa35..b8b736890 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; @@ -30,6 +31,7 @@ import java.util.function.Supplier; import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION; @@ -42,6 +44,7 @@ import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.OPENSEARCH_ML; import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PIPELINE_ID; import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.SUCCESS; @@ -69,12 +72,14 @@ public class WorkflowStepFactory { * @param mlClient Machine Learning client to perform ml operations * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices * @param flowFrameworkSettings common settings of the plugin + * @param client The OpenSearch Client */ public WorkflowStepFactory( ThreadPool threadPool, MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler, - FlowFrameworkSettings flowFrameworkSettings + FlowFrameworkSettings flowFrameworkSettings, + Client client ) { stepMap.put(NoOpStep.NAME, NoOpStep::new); stepMap.put( @@ -102,6 +107,7 @@ public WorkflowStepFactory( stepMap.put(ToolStep.NAME, ToolStep::new); stepMap.put(RegisterAgentStep.NAME, () -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(DeleteAgentStep.NAME, () -> new DeleteAgentStep(mlClient)); + stepMap.put(CreateIngestPipelineStep.NAME, () -> new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler)); } /** @@ -196,7 +202,16 @@ public enum WorkflowSteps { DELETE_AGENT(DeleteAgentStep.NAME, List.of(AGENT_ID), List.of(AGENT_ID), List.of(OPENSEARCH_ML), null), /** Create Tool Step */ - CREATE_TOOL(ToolStep.NAME, List.of(TYPE), List.of(TOOLS_FIELD), List.of(OPENSEARCH_ML), null); + CREATE_TOOL(ToolStep.NAME, List.of(TYPE), List.of(TOOLS_FIELD), List.of(OPENSEARCH_ML), null), + + /** Create Ingest Pipeline Step */ + CREATE_INGEST_PIPELINE( + CreateIngestPipelineStep.NAME, + List.of(PIPELINE_ID, CONFIGURATIONS), + List.of(PIPELINE_ID), + Collections.emptyList(), + null + ); private final String workflowStepName; private final List inputs; diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java index 3de22e3db..c29fa0798 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java @@ -639,4 +639,28 @@ protected Response deleteUser(String user) throws IOException { List.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) ); } + + protected GetPipelineResponse getPipelines() throws IOException { + Response getPipelinesResponse = TestHelpers.makeRequest( + client(), + "GET", + "_ingest/pipeline", + null, + "", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + + // Parse entity content into SearchResponse + MediaType mediaType = MediaType.fromMediaType(getPipelinesResponse.getEntity().getContentType()); + try ( + XContentParser parser = mediaType.xContent() + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + getPipelinesResponse.getEntity().getContent() + ) + ) { + return GetPipelineResponse.fromXContent(parser); + } + } } diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java index c52e6b9cf..6b1841708 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.model; +import org.opensearch.client.Client; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.workflow.WorkflowStepFactory; @@ -29,6 +30,7 @@ public class WorkflowValidatorTests extends OpenSearchTestCase { private FlowFrameworkSettings flowFrameworkSettings; + private static Client client = mock(Client.class); @Override public void setUp() throws Exception { @@ -44,7 +46,7 @@ public void testParseWorkflowValidator() throws IOException { WorkflowValidator validator = new WorkflowValidator(workflowStepValidators); - assertEquals(14, validator.getWorkflowStepValidators().size()); + assertEquals(15, validator.getWorkflowStepValidators().size()); assertTrue(validator.getWorkflowStepValidators().keySet().contains("create_connector")); assertEquals(7, validator.getWorkflowStepValidators().get("create_connector").getInputs().size()); @@ -113,7 +115,8 @@ public void testWorkflowStepFactoryHasValidators() throws IOException { threadPool, mlClient, flowFrameworkIndicesHandler, - flowFrameworkSettings + flowFrameworkSettings, + client ); WorkflowValidator workflowValidator = workflowStepFactory.getWorkflowValidator(); diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java index edab14abf..e957ce271 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java @@ -11,6 +11,7 @@ import org.apache.http.util.EntityUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.ingest.GetPipelineResponse; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; @@ -344,4 +345,50 @@ public void testTimestamps() throws Exception { response = deleteWorkflow(client(), workflowId); assertEquals(RestStatus.OK.getStatus(), response.getStatusLine().getStatusCode()); } + + public void testCreateAndProvisionIngestPipeline() throws Exception { + + // Using a 3 step template to create a connector, register remote model and deploy model + Template template = TestHelpers.createTemplateFromFile("ingest-pipeline-template.json"); + + // Hit Create Workflow API with original template + Response response = createWorkflow(client(), template); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + + Map responseMap = entityAsMap(response); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); + + // Ensure Ml config index is initialized as creating a connector requires this, then hit Provision API and assert status + if (!indexExistsWithAdminClient(".plugins-ml-config")) { + assertBusy(() -> assertTrue(indexExistsWithAdminClient(".plugins-ml-config")), 40, TimeUnit.SECONDS); + response = provisionWorkflow(client(), workflowId); + } else { + response = provisionWorkflow(client(), workflowId); + } + + assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); + getAndAssertWorkflowStatus(client(), workflowId, State.PROVISIONING, ProvisioningProgress.IN_PROGRESS); + + // Wait until provisioning has completed successfully before attempting to retrieve created resources + List resourcesCreated = getResourcesCreated(client(), workflowId, 30); + + // This template should create 4 resources, connector_id, registered model_id, deployed model_id and pipelineId + assertEquals(4, resourcesCreated.size()); + assertEquals("create_connector", resourcesCreated.get(0).workflowStepName()); + assertNotNull(resourcesCreated.get(0).resourceId()); + assertEquals("register_remote_model", resourcesCreated.get(1).workflowStepName()); + assertNotNull(resourcesCreated.get(1).resourceId()); + assertEquals("deploy_model", resourcesCreated.get(2).workflowStepName()); + assertNotNull(resourcesCreated.get(2).resourceId()); + assertEquals("create_ingest_pipeline", resourcesCreated.get(3).workflowStepName()); + assertNotNull(resourcesCreated.get(3).resourceId()); + String modelId = resourcesCreated.get(2).resourceId(); + + GetPipelineResponse getPipelinesResponse = getPipelines(); + + assertTrue(getPipelinesResponse.pipelines().get(0).toString().contains(modelId)); + + } + } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java index 59df28a42..c7ebbf71e 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.rest; +import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; @@ -57,8 +58,15 @@ public void setUp() throws Exception { ThreadPool threadPool = mock(ThreadPool.class); MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); - - this.workflowStepFactory = new WorkflowStepFactory(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings); + Client client = mock(Client.class); + + this.workflowStepFactory = new WorkflowStepFactory( + threadPool, + mlClient, + flowFrameworkIndicesHandler, + flowFrameworkSettings, + client + ); flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkSettings.class); when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); this.restGetWorkflowStepAction = new RestGetWorkflowStepAction(flowFrameworkFeatureEnabledSetting); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java index f8d7bd68f..f8c9402d0 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java @@ -28,6 +28,7 @@ import org.mockito.ArgumentCaptor; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.opensearch.flowframework.common.WorkflowResources.PIPELINE_ID; @@ -54,15 +55,10 @@ public void setUp() throws Exception { super.setUp(); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + String configurations = + "{“description”:“An neural ingest pipeline”,“processors”:[{“text_embedding”:{“field_map”:{“text”:“analyzed_text”},“model_id”:“sdsadsadasd”}}]}"; inputData = new WorkflowData( - Map.ofEntries( - Map.entry("id", "pipelineId"), - Map.entry("description", "some description"), - Map.entry("type", "text_embedding"), - Map.entry(MODEL_ID, MODEL_ID), - Map.entry("input_field_name", "inputField"), - Map.entry("output_field_name", "outputField") - ), + Map.ofEntries(Map.entry(CONFIGURATIONS, configurations), Map.entry(PIPELINE_ID, "pipelineId")), "test-id", "test-node-id" ); @@ -136,7 +132,7 @@ public void testCreateIngestPipelineStepFailure() throws InterruptedException { } public void testMissingData() throws InterruptedException { - CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler); + CreateIngestPipelineStep CreateIngestPipelineStep = new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler); // Data with missing input and output fields WorkflowData incorrectData = new WorkflowData( @@ -150,7 +146,7 @@ public void testMissingData() throws InterruptedException { "test-node-id" ); - PlainActionFuture future = createIngestPipelineStep.execute( + PlainActionFuture future = CreateIngestPipelineStep.execute( incorrectData.getNodeId(), incorrectData, Collections.emptyMap(), @@ -161,7 +157,10 @@ public void testMissingData() throws InterruptedException { ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get()); assertTrue(exception.getCause() instanceof Exception); - assertEquals("Failed to create ingest pipeline for test-node-id, required inputs not found", exception.getCause().getMessage()); + assertEquals( + "Missing required inputs [configurations, pipeline_id] in workflow [test-id] node [test-node-id]", + exception.getCause().getMessage() + ); } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index 02488a739..df7d1732a 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -118,7 +118,13 @@ public static void setup() throws IOException { FLOW_FRAMEWORK_THREAD_POOL_PREFIX + DEPROVISION_WORKFLOW_THREAD_POOL ) ); - WorkflowStepFactory factory = new WorkflowStepFactory(testThreadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings); + WorkflowStepFactory factory = new WorkflowStepFactory( + testThreadPool, + mlClient, + flowFrameworkIndicesHandler, + flowFrameworkSettings, + client + ); workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool, clusterService, client, flowFrameworkSettings); } diff --git a/src/test/resources/template/ingest-pipeline-template.json b/src/test/resources/template/ingest-pipeline-template.json new file mode 100644 index 000000000..b5ee4d19d --- /dev/null +++ b/src/test/resources/template/ingest-pipeline-template.json @@ -0,0 +1,88 @@ +{ + "name": "Deploy OpenAI Model", + "description": "Deploy a model using a connector to OpenAI", + "use_case": "PROVISION", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.12.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "nodes": [ + { + "id": "create_openai_connector", + "type": "create_connector", + "user_inputs": { + "name": "OpenAI Chat Connector", + "description": "The connector to public OpenAI model service for text embedding model", + "version": "1", + "protocol": "http", + "parameters": { + "endpoint": "api.openai.com", + "model": "gpt-3.5-turbo", + "response_filter": "$.choices[0].message.content" + }, + "credential": { + "openAI_key": "PUT_YOUR_API_KEY_HERE" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://${parameters.endpoint}/v1/chat/completions" + } + ] + } + }, + { + "id": "register_openai_model", + "type": "register_remote_model", + "previous_node_inputs": { + "create_openai_connector": "connector_id" + }, + "user_inputs": { + "name": "openAI-gpt-3.5-turbo" + } + }, + { + "id": "deploy_openai_model", + "type": "deploy_model", + "previous_node_inputs": { + "register_openai_model": "model_id" + } + }, + { + "id": "create_ingest_pipeline", + "type": "create_ingest_pipeline", + "previous_node_inputs": { + "deploy_openai_model": "model_id" + }, + "user_inputs": { + "pipeline_id": "append-1", + "configurations": { + "description": "Pipeline that appends event type", + "processors": [ + { + "append": { + "field": "event_types", + "value": [ + "${{deploy_openai_model.model_id}}" + ] + } + }, + { + "drop": { + "if": "ctx.user_info.contains('password') || ctx.user_info.contains('credit card')" + } + } + ] + } + } + } + ] + } + } +}