diff --git a/build.gradle b/build.gradle index 2f008a1c6..7d72e8dd6 100644 --- a/build.gradle +++ b/build.gradle @@ -175,6 +175,8 @@ dependencies { implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1' implementation 'org.bouncycastle:bcprov-jdk18on:1.77' api "org.apache.httpcomponents.core5:httpcore5:5.2.2" + implementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}") + implementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}") // ZipArchive dependencies used for integration tests zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 49f3bcce9..4c8486b7e 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -123,7 +123,8 @@ public Collection createComponents( threadPool, mlClient, flowFrameworkIndicesHandler, - flowFrameworkSettings + flowFrameworkSettings, + client ); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter( workflowStepFactory, diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index ec88a3778..94916098e 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -170,6 +170,10 @@ private CommonValue() {} public static final String HOSTNAME_FIELD = "hostname"; /** Http port */ public static final String PORT_FIELD = "port"; + /** Pipeline ID, also corresponds to pipeline name */ + public static final String PIPELINE_ID = "pipeline_id"; + /** Pipeline Configurations */ + public static final String CONFIGURATIONS = "configurations"; /* * Constants associated with resource 
provisioning / state diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index 55122f9e5..15d52ccd1 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -8,12 +8,15 @@ */ package org.opensearch.flowframework.model; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.flowframework.workflow.WorkflowStep; @@ -28,6 +31,7 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD; import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap; import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap; @@ -60,6 +64,7 @@ public class WorkflowNode implements ToXContentObject { private final String type; // maps to a WorkflowStep private final Map previousNodeInputs; private final Map userInputs; // maps to WorkflowData + private static final Logger logger = LogManager.getLogger(WorkflowNode.class); /** * Create this node with the id and type, and any user input. 
@@ -151,7 +156,20 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { userInputs.put(inputFieldName, parser.text()); break; case START_OBJECT: - userInputs.put(inputFieldName, parseStringToStringMap(parser)); + if (CONFIGURATIONS.equals(inputFieldName)) { + Map configurationsMap = parser.map(); + try { + String configurationsString = ParseUtils.parseArbitraryStringToObjectMapToString(configurationsMap); + userInputs.put(inputFieldName, configurationsString); + } catch (Exception ex) { + String errorMessage = "Failed to parse configuration map"; + logger.error(errorMessage, ex); + throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); + } + break; + } else { + userInputs.put(inputFieldName, parseStringToStringMap(parser)); + } break; case START_ARRAY: if (PROCESSORS_FIELD.equals(inputFieldName)) { diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index 6192d8e6d..177e16bc5 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -8,6 +8,9 @@ */ package org.opensearch.flowframework.util; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; @@ -356,4 +359,17 @@ private static Object conditionallySubstitute(Object value, Map map) throws JsonProcessingException { + ObjectMapper mapper = new ObjectMapper(); + // Convert the map to a JSON string + String mappedString = mapper.writeValueAsString(map); + return mappedString; + } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index a6b0b6a40..403e26063 100644 --- 
a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -161,8 +161,10 @@ public void onFailure(Exception e) { credentials = getStringToStringMap(inputs.get(CREDENTIAL_FIELD), CREDENTIAL_FIELD); actions = getConnectorActionList(inputs.get(ACTIONS_FIELD)); } catch (IllegalArgumentException iae) { + logger.error("IllegalArgumentException in connector configuration", iae); throw new FlowFrameworkException("IllegalArgumentException in connector configuration", RestStatus.BAD_REQUEST); } catch (PrivilegedActionException pae) { + logger.error("PrivilegedActionException in connector configuration", pae); throw new FlowFrameworkException("PrivilegedActionException in connector configuration", RestStatus.UNAUTHORIZED); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java index 1ddfa65f6..2b7d81919 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -13,39 +13,33 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.action.ingest.PutPipelineRequest; import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.exception.FlowFrameworkException; import 
org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.util.ParseUtils; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; +import java.nio.charset.StandardCharsets; import java.util.Map; -import java.util.Map.Entry; -import java.util.stream.Stream; - -import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; -import static org.opensearch.flowframework.common.CommonValue.FIELD_MAP; -import static org.opensearch.flowframework.common.CommonValue.ID; -import static org.opensearch.flowframework.common.CommonValue.INPUT_FIELD_NAME; -import static org.opensearch.flowframework.common.CommonValue.OUTPUT_FIELD_NAME; -import static org.opensearch.flowframework.common.CommonValue.PROCESSORS; -import static org.opensearch.flowframework.common.CommonValue.TYPE; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; +import static org.opensearch.flowframework.common.WorkflowResources.PIPELINE_ID; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; /** - * Workflow step to create an ingest pipeline + * Step to create an ingest pipeline */ public class CreateIngestPipelineStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(CreateIngestPipelineStep.class); /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ @@ -57,7 +51,7 @@ public class CreateIngestPipelineStep implements WorkflowStep { private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; /** * Instantiates a new CreateIngestPipelineStep * @param client The client to create a pipeline and store workflow data into the global context index * @param
flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ @@ -77,93 +71,24 @@ public PlainActionFuture execute( PlainActionFuture createIngestPipelineFuture = PlainActionFuture.newFuture(); - String pipelineId = null; - String description = null; - String type = null; - String modelId = null; - String inputFieldName = null; - String outputFieldName = null; - BytesReference configuration = null; - - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); - - // Extract required content from workflow data and generate the ingest pipeline configuration - for (WorkflowData workflowData : data) { - - Map content = workflowData.getContent(); - - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case ID: - pipelineId = (String) content.get(ID); - break; - case DESCRIPTION_FIELD: - description = (String) content.get(DESCRIPTION_FIELD); - break; - case TYPE: - type = (String) content.get(TYPE); - break; - case MODEL_ID: - modelId = (String) content.get(MODEL_ID); - break; - case INPUT_FIELD_NAME: - inputFieldName = (String) content.get(INPUT_FIELD_NAME); - break; - case OUTPUT_FIELD_NAME: - outputFieldName = (String) content.get(OUTPUT_FIELD_NAME); - break; - default: - break; - } - } + ActionListener actionListener = new ActionListener<>() { - // Determine if fields have been populated, else iterate over remaining workflow data - if (Stream.of(pipelineId, description, modelId, type, inputFieldName, outputFieldName).allMatch(x -> x != null)) { + @Override + public void onResponse(AcknowledgedResponse acknowledgedResponse) { + String resourceName = getResourceByWorkflowStep(getName()); try { - configuration = BytesReference.bytes( - buildIngestPipelineRequestContent(description, modelId, type, inputFieldName, outputFieldName) - ); - } catch (IOException e) { - 
String errorMessage = "Failed to create ingest pipeline configuration for " + currentNodeId; - logger.error(errorMessage, e); - createIngestPipelineFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); - } - break; - } - } - - if (configuration == null) { - // Required workflow data not found - createIngestPipelineFuture.onFailure( - new FlowFrameworkException( - "Failed to create ingest pipeline for " + currentNodeId + ", required inputs not found", - RestStatus.BAD_REQUEST - ) - ); - } else { - // Create PutPipelineRequest and execute - PutPipelineRequest putPipelineRequest = new PutPipelineRequest(pipelineId, configuration, XContentType.JSON); - clusterAdminClient.putPipeline(putPipelineRequest, ActionListener.wrap(response -> { - logger.info("Created ingest pipeline : " + putPipelineRequest.getId()); - - try { - String resourceName = getResourceByWorkflowStep(getName()); flowFrameworkIndicesHandler.updateResourceInStateIndex( currentNodeInputs.getWorkflowId(), currentNodeId, getName(), - putPipelineRequest.getId(), + currentNodeInputs.getContent().get(PIPELINE_ID).toString(), ActionListener.wrap(updateResponse -> { logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); // PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead // TODO: revisit this concept of pipeline_id to be consistent with what makes most sense to end user here createIngestPipelineFuture.onResponse( new WorkflowData( - Map.of(resourceName, putPipelineRequest.getId()), + Map.of(resourceName, currentNodeInputs.getContent().get(PIPELINE_ID).toString()), currentNodeInputs.getWorkflowId(), currentNodeInputs.getNodeId() ) @@ -174,7 +99,7 @@ public PlainActionFuture execute( + " resource " + getName() + " id " - + putPipelineRequest.getId(); + + currentNodeInputs.getContent().get(PIPELINE_ID).toString(); logger.error(errorMessage, exception); createIngestPipelineFuture.onFailure( new 
FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)) @@ -187,12 +112,75 @@ public PlainActionFuture execute( logger.error(errorMessage, e); createIngestPipelineFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); } + } - }, exception -> { + @Override + public void onFailure(Exception e) { String errorMessage = "Failed to create ingest pipeline"; - logger.error(errorMessage, exception); - createIngestPipelineFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - })); + logger.error(errorMessage, e); + createIngestPipelineFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); + } + + }; + + Set requiredKeys = Set.of(PIPELINE_ID, CONFIGURATIONS); + + // currently, we are supporting an optional param of model ID into the various processors + Set optionalKeys = Set.of(MODEL_ID); + + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs, + params + ); + + String pipelineId = (String) inputs.get(PIPELINE_ID); + String configurations = (String) inputs.get(CONFIGURATIONS); + + // Regex to find patterns like ${{deploy_openai_model.model_id}} + // We currently support one previous node input that fits the pattern of (step.input_to_look_for) + Pattern pattern = Pattern.compile("\\$\\{\\{([\\w_]+)\\.([\\w_]+)\\}\\}"); + Matcher matcher = pattern.matcher(configurations); + + StringBuffer result = new StringBuffer(); + while (matcher.find()) { + // Params map contains params for previous node input (e.g: deploy_openai_model:model_id) + // Check first if the substitution is looking for the same key, value pair and if yes + // then replace it with the key value pair in the inputs map + if (params.containsKey(matcher.group(1)) && params.get(matcher.group(1)).equals(matcher.group(2))) { + // Extract the key for the inputs (e.g., "model_id" from 
${{deploy_openai_model.model_id}}) + String key = matcher.group(2); + if (inputs.containsKey(key)) { + // Replace the whole sequence with the value from the map + matcher.appendReplacement(result, (String) inputs.get(key)); + } + } + } + matcher.appendTail(result); + + if (result == null || pipelineId == null) { + // Required workflow data not found + createIngestPipelineFuture.onFailure( + new FlowFrameworkException( + "Failed to create ingest pipeline for " + currentNodeId + ", required inputs not found", + RestStatus.BAD_REQUEST + ) + ); + return createIngestPipelineFuture; + } + + byte[] byteArr = result.toString().getBytes(StandardCharsets.UTF_8); + BytesReference configurationsBytes = new BytesArray(byteArr); + + // Create PutPipelineRequest and execute + PutPipelineRequest putPipelineRequest = new PutPipelineRequest(pipelineId, configurationsBytes, XContentType.JSON); + clusterAdminClient.putPipeline(putPipelineRequest, actionListener); + + } catch (FlowFrameworkException e) { + createIngestPipelineFuture.onFailure(e); } return createIngestPipelineFuture; @@ -202,50 +190,4 @@ public String getName() { return NAME; } - - /** - * Temporary, generates the ingest pipeline request content for text_embedding processor from workflow data - * { - * "description" : "", - * "processors" : [ - * { - * "" : { - * "model_id" : "", - * "field_map" : { - * "" : "" - * } - * } - * ] - * } - * - * @param description The description of the ingest pipeline configuration - * @param modelId The ID of the model that will be used in the embedding interface - * @param type The processor type - * @param inputFieldName The field name used to cache text for text embeddings - * @param outputFieldName The field name in which output text is stored - * @throws IOException if the request content fails to be generated - * @return the xcontent builder with the formatted ingest pipeline configuration - */ - private XContentBuilder buildIngestPipelineRequestContent( - String description, - String
modelId, - String type, - String inputFieldName, - String outputFieldName - ) throws IOException { - return XContentFactory.jsonBuilder() - .startObject() - .field(DESCRIPTION_FIELD, description) - .startArray(PROCESSORS) - .startObject() - .startObject(type) - .field(MODEL_ID, modelId) - .startObject(FIELD_MAP) - .field(inputFieldName, outputFieldName) - .endObject() - .endObject() - .endObject() - .endArray() - .endObject(); - } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 8c72e8481..65222ed25 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; @@ -30,6 +31,7 @@ import java.util.function.Supplier; import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION; @@ -44,6 +46,7 @@ import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.OPENSEARCH_ML; import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PIPELINE_ID; import static org.opensearch.flowframework.common.CommonValue.PORT_FIELD; import static 
org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; @@ -73,12 +76,14 @@ public class WorkflowStepFactory { * @param mlClient Machine Learning client to perform ml operations * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices * @param flowFrameworkSettings common settings of the plugin + * @param client The OpenSearch Client */ public WorkflowStepFactory( ThreadPool threadPool, MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler, - FlowFrameworkSettings flowFrameworkSettings + FlowFrameworkSettings flowFrameworkSettings, + Client client ) { stepMap.put(NoOpStep.NAME, NoOpStep::new); stepMap.put( @@ -107,6 +112,7 @@ public WorkflowStepFactory( stepMap.put(RegisterAgentStep.NAME, () -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(DeleteAgentStep.NAME, () -> new DeleteAgentStep(mlClient)); stepMap.put(HttpHostStep.NAME, HttpHostStep::new); + stepMap.put(CreateIngestPipelineStep.NAME, () -> new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler)); } /** @@ -210,6 +216,15 @@ public enum WorkflowSteps { List.of(HTTP_HOST_FIELD), Collections.emptyList(), null + ), + + /** Create Ingest Pipeline Step */ + CREATE_INGEST_PIPELINE( + CreateIngestPipelineStep.NAME, + List.of(PIPELINE_ID, CONFIGURATIONS), + List.of(PIPELINE_ID), + Collections.emptyList(), + null ); private final String workflowStepName; diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java index 5ba5924f3..baf4bfc70 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.model; +import 
org.opensearch.client.Client; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.workflow.WorkflowStepFactory; @@ -29,6 +30,7 @@ public class WorkflowValidatorTests extends OpenSearchTestCase { private FlowFrameworkSettings flowFrameworkSettings; + private static Client client = mock(Client.class); @Override public void setUp() throws Exception { @@ -44,7 +46,7 @@ public void testParseWorkflowValidator() throws IOException { WorkflowValidator validator = new WorkflowValidator(workflowStepValidators); - assertEquals(15, validator.getWorkflowStepValidators().size()); + assertEquals(16, validator.getWorkflowStepValidators().size()); assertTrue(validator.getWorkflowStepValidators().keySet().contains("create_connector")); assertEquals(7, validator.getWorkflowStepValidators().get("create_connector").getInputs().size()); @@ -117,7 +119,8 @@ public void testWorkflowStepFactoryHasValidators() throws IOException { threadPool, mlClient, flowFrameworkIndicesHandler, - flowFrameworkSettings + flowFrameworkSettings, + client ); WorkflowValidator workflowValidator = workflowStepFactory.getWorkflowValidator(); diff --git a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java index 59df28a42..c7ebbf71e 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStepActionTests.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.rest; +import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; @@ -57,8 +58,15 @@ public void setUp() throws Exception { ThreadPool threadPool = mock(ThreadPool.class); 
MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); - - this.workflowStepFactory = new WorkflowStepFactory(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings); + Client client = mock(Client.class); + + this.workflowStepFactory = new WorkflowStepFactory( + threadPool, + mlClient, + flowFrameworkIndicesHandler, + flowFrameworkSettings, + client + ); flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkSettings.class); when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); this.restGetWorkflowStepAction = new RestGetWorkflowStepAction(flowFrameworkFeatureEnabledSetting); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java index f8d7bd68f..f8c9402d0 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java @@ -28,6 +28,7 @@ import org.mockito.ArgumentCaptor; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.opensearch.flowframework.common.WorkflowResources.PIPELINE_ID; @@ -54,15 +55,10 @@ public void setUp() throws Exception { super.setUp(); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + String configurations = + "{\"description\":\"A neural ingest pipeline\",\"processors\":[{\"text_embedding\":{\"field_map\":{\"text\":\"analyzed_text\"},\"model_id\":\"sdsadsadasd\"}}]}"; inputData = new WorkflowData( - Map.ofEntries( - Map.entry("id",
"pipelineId"), - Map.entry("description", "some description"), - Map.entry("type", "text_embedding"), - Map.entry(MODEL_ID, MODEL_ID), - Map.entry("input_field_name", "inputField"), - Map.entry("output_field_name", "outputField") - ), + Map.ofEntries(Map.entry(CONFIGURATIONS, configurations), Map.entry(PIPELINE_ID, "pipelineId")), "test-id", "test-node-id" ); @@ -136,7 +132,7 @@ public void testCreateIngestPipelineStepFailure() throws InterruptedException { } public void testMissingData() throws InterruptedException { - CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler); + CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler); // Data with missing input and output fields WorkflowData incorrectData = new WorkflowData( @@ -150,7 +146,7 @@ public void testMissingData() throws InterruptedException { "test-node-id" ); - PlainActionFuture future = createIngestPipelineStep.execute( + PlainActionFuture future = createIngestPipelineStep.execute( incorrectData.getNodeId(), incorrectData, Collections.emptyMap(), @@ -161,7 +157,10 @@ public void testMissingData() throws InterruptedException { ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get()); assertTrue(exception.getCause() instanceof Exception); - assertEquals("Failed to create ingest pipeline for test-node-id, required inputs not found", exception.getCause().getMessage()); + assertEquals( + "Missing required inputs [configurations, pipeline_id] in workflow [test-id] node [test-node-id]", + exception.getCause().getMessage() + ); } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index 02488a739..df7d1732a 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++
b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -118,7 +118,13 @@ public static void setup() throws IOException { FLOW_FRAMEWORK_THREAD_POOL_PREFIX + DEPROVISION_WORKFLOW_THREAD_POOL ) ); - WorkflowStepFactory factory = new WorkflowStepFactory(testThreadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings); + WorkflowStepFactory factory = new WorkflowStepFactory( + testThreadPool, + mlClient, + flowFrameworkIndicesHandler, + flowFrameworkSettings, + client + ); workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool, clusterService, client, flowFrameworkSettings); }