diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 5be7482e0..ef54addff 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -188,4 +188,6 @@ private CommonValue() {} public static final String RESOURCE_TYPE = "resource_type"; /** The field name for the resource id */ public static final String RESOURCE_ID = "resource_id"; + /** The field name for the opensearch-ml plugin */ + public static final String OPENSEARCH_ML = "opensearch-ml"; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java index 9e019eb47..2a5536638 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java @@ -31,7 +31,6 @@ /** * Step to register a tool for an agent */ -@SuppressWarnings("unchecked") public class ToolStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(ToolStep.class); @@ -107,6 +106,7 @@ private Map getToolsParametersMap( Map previousNodeInputs, Map outputs ) { + @SuppressWarnings("unchecked") Map parametersMap = (Map) parameters; Optional previousNodeModel = previousNodeInputs.entrySet() .stream() diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 7ccba94af..7f54b17cc 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -13,6 +13,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -22,16 +23,36 @@ import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.threadpool.ThreadPool; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Supplier; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import static java.util.concurrent.TimeUnit.SECONDS; +import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION; +import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE; +import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME; +import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE; +import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; +import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_STATUS; +import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE; +import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.OPENSEARCH_ML; +import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; +import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; +import static org.opensearch.flowframework.common.CommonValue.SUCCESS; +import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.TYPE; +import static org.opensearch.flowframework.common.CommonValue.URL; +import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; +import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID; +import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; +import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID; +import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; /** * Generates instances implementing {@link WorkflowStep}. @@ -40,6 +61,10 @@ public class WorkflowStepFactory { private final Map> stepMap = new HashMap<>(); private static final Logger logger = LogManager.getLogger(WorkflowStepFactory.class); + private static ThreadPool threadPool; + private static MachineLearningNodeClient mlClient; + private static FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private static FlowFrameworkSettings flowFrameworkSettings; /** * Instantiate this class. @@ -59,32 +84,14 @@ public WorkflowStepFactory( FlowFrameworkIndicesHandler flowFrameworkIndicesHandler, FlowFrameworkSettings flowFrameworkSettings ) { - stepMap.put(NoOpStep.NAME, NoOpStep::new); - stepMap.put( - RegisterLocalCustomModelStep.NAME, - () -> new RegisterLocalCustomModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) - ); - stepMap.put( - RegisterLocalSparseEncodingModelStep.NAME, - () -> new RegisterLocalSparseEncodingModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) - ); - stepMap.put( - RegisterLocalPretrainedModelStep.NAME, - () -> new RegisterLocalPretrainedModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) - ); - stepMap.put(RegisterRemoteModelStep.NAME, () -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)); - stepMap.put(DeleteModelStep.NAME, () -> new DeleteModelStep(mlClient)); - stepMap.put( - DeployModelStep.NAME, - () -> new DeployModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) - ); - stepMap.put(UndeployModelStep.NAME, () -> new UndeployModelStep(mlClient)); - stepMap.put(CreateConnectorStep.NAME, () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); - stepMap.put(DeleteConnectorStep.NAME, () -> new DeleteConnectorStep(mlClient)); - stepMap.put(RegisterModelGroupStep.NAME, () -> new RegisterModelGroupStep(mlClient, flowFrameworkIndicesHandler)); - stepMap.put(ToolStep.NAME, ToolStep::new); - stepMap.put(RegisterAgentStep.NAME, () -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler)); - stepMap.put(DeleteAgentStep.NAME, () -> new DeleteAgentStep(mlClient)); + this.threadPool = threadPool; + this.mlClient = mlClient; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; + this.flowFrameworkSettings = flowFrameworkSettings; + // Initialize the WorkflowSteps enum inside the constructor + for (WorkflowSteps workflowStep : WorkflowSteps.values()) { + stepMap.put(workflowStep.getWorkflowStepName(), workflowStep.step()); + } } /** @@ -93,135 +100,219 @@ public WorkflowStepFactory( public enum WorkflowSteps { - NOOP("noop", Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), null), + /** Noop Step */ + NOOP("noop", Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), null, NoOpStep::new), + /** Create Connector Step */ CREATE_CONNECTOR( - "create_connector", - Arrays.asList("name", "description", "version", "protocol", "parameters", "credential", "actions"), - Arrays.asList("connector_id"), - Arrays.asList("opensearch-ml"), - new TimeValue(60, SECONDS) + CreateConnectorStep.NAME, + List.of(NAME_FIELD, DESCRIPTION_FIELD, VERSION_FIELD, PROTOCOL_FIELD, PARAMETERS_FIELD, CREDENTIAL_FIELD, ACTIONS_FIELD), + List.of(CONNECTOR_ID), + List.of(OPENSEARCH_ML), + TimeValue.timeValueSeconds(60), + () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler) ), + /** Register Local Custom Model Step */ REGISTER_LOCAL_CUSTOM_MODEL( - "register_local_custom_model", - Arrays.asList( - "name", - "version", - "model_format", - "function_name", - "model_content_hash_value", - "url", - "model_type", - "embedding_dimension", - "framework_type" + RegisterLocalCustomModelStep.NAME, + List.of( + NAME_FIELD, + VERSION_FIELD, + MODEL_FORMAT, + FUNCTION_NAME, + MODEL_CONTENT_HASH_VALUE, + URL, + MODEL_TYPE, + EMBEDDING_DIMENSION, + FRAMEWORK_TYPE ), - Arrays.asList("model_id", "register_model_status"), - Arrays.asList("opensearch-ml"), - new TimeValue(60, SECONDS) + List.of(MODEL_ID, REGISTER_MODEL_STATUS), + List.of(OPENSEARCH_ML), + TimeValue.timeValueSeconds(60), + () -> new RegisterLocalCustomModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) ), + /** Register Local Sparse Encoding Model Step */ REGISTER_LOCAL_SPARSE_ENCODING_MODEL( - "register_local_sparse_encoding_model", - Arrays.asList("name", "version", "model_format", "function_name", "model_content_hash_value", "url"), - Arrays.asList("model_id", "register_model_status"), - Arrays.asList("opensearch-ml"), - new TimeValue(60, SECONDS) + RegisterLocalSparseEncodingModelStep.NAME, + List.of(NAME_FIELD, VERSION_FIELD, MODEL_FORMAT, FUNCTION_NAME, MODEL_CONTENT_HASH_VALUE, URL), + List.of(MODEL_ID, REGISTER_MODEL_STATUS), + List.of(OPENSEARCH_ML), + TimeValue.timeValueSeconds(60), + () -> new RegisterLocalSparseEncodingModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) ), + + /** Register Local Pretrained Model Step */ REGISTER_LOCAL_PRETRAINED_MODEL( - "register_local_pretrained_model", - Arrays.asList("name", "version", "model_format"), - Arrays.asList("model_id", "register_model_status"), - Arrays.asList("opensearch-ml"), - new TimeValue(60, SECONDS) + RegisterLocalPretrainedModelStep.NAME, + List.of(NAME_FIELD, VERSION_FIELD, MODEL_FORMAT), + List.of(MODEL_ID, REGISTER_MODEL_STATUS), + List.of(OPENSEARCH_ML), + TimeValue.timeValueSeconds(60), + () -> new RegisterLocalPretrainedModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) ), + /** Register Remote Model Step */ REGISTER_REMOTE_MODEL( - "register_remote_model", - Arrays.asList("name", "connector_id"), - Arrays.asList("model_id", "register_model_status"), - Arrays.asList("opensearch-ml"), - null + RegisterRemoteModelStep.NAME, + List.of(NAME_FIELD, CONNECTOR_ID), + List.of(MODEL_ID, REGISTER_MODEL_STATUS), + List.of(OPENSEARCH_ML), + null, + () -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler) ), + /** Register Model Group Step */ REGISTER_MODEL_GROUP( - "register_model_group", - Arrays.asList("name"), - Arrays.asList("model_group_id", "model_group_status"), - Arrays.asList("opensearch-ml"), - null + RegisterModelGroupStep.NAME, + List.of(NAME_FIELD), + List.of(MODEL_GROUP_ID, MODEL_GROUP_STATUS), + List.of(OPENSEARCH_ML), + null, + () -> new RegisterModelGroupStep(mlClient, flowFrameworkIndicesHandler) ), + /** Deploy Model Step */ DEPLOY_MODEL( - "deploy_model", - Arrays.asList("model_id"), - Arrays.asList("deploy_model_status"), - Arrays.asList("opensearch-ml"), - new TimeValue(15, SECONDS) + DeployModelStep.NAME, + List.of(MODEL_ID), + List.of(MODEL_ID), + List.of(OPENSEARCH_ML), + TimeValue.timeValueSeconds(15), + () -> new DeployModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) ), - UNDEPLOY_MODEL("undeploy_model", Arrays.asList("model_id"), Arrays.asList("success"), Arrays.asList("opensearch-ml"), null), + /** Undeploy Model Step */ + UNDEPLOY_MODEL( + UndeployModelStep.NAME, + List.of(MODEL_ID), + List.of(SUCCESS), + List.of(OPENSEARCH_ML), + null, + () -> new UndeployModelStep(mlClient) + ), - DELETE_MODEL("delete_model", Arrays.asList("model_id"), Arrays.asList("model_id"), Arrays.asList("opensearch-ml"), null), + /** Delete Model Step */ + DELETE_MODEL( + DeleteModelStep.NAME, + List.of(MODEL_ID), + List.of(MODEL_ID), + List.of(OPENSEARCH_ML), + null, + () -> new DeleteModelStep(mlClient) + ), + /** Delete Connector Step */ DELETE_CONNECTOR( - "delete_connector", - Arrays.asList("connector_id"), - Arrays.asList("connector_id"), - Arrays.asList("opensearch-ml"), - null + DeleteConnectorStep.NAME, + List.of(CONNECTOR_ID), + List.of(CONNECTOR_ID), + List.of(OPENSEARCH_ML), + null, + () -> new DeleteConnectorStep(mlClient) ), - REGISTER_AGENT("register_agent", Arrays.asList("name", "type"), Arrays.asList("agent_id"), Arrays.asList("opensearch-ml"), null), + /** Register Agent Step */ + REGISTER_AGENT( + RegisterAgentStep.NAME, + List.of(NAME_FIELD, TYPE), + List.of(AGENT_ID), + List.of(OPENSEARCH_ML), + null, + () -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler) + ), - DELETE_AGENT("delete_agent", Arrays.asList("agent_id"), Arrays.asList("agent_id"), Arrays.asList("opensearch-ml"), null), + /** Delete Agent Step */ + DELETE_AGENT( + DeleteAgentStep.NAME, + List.of(AGENT_ID), + List.of(AGENT_ID), + List.of(OPENSEARCH_ML), + null, + () -> new DeleteAgentStep(mlClient) + ), - CREATE_TOOL("create_tool", Arrays.asList("type"), Arrays.asList("tools"), Arrays.asList("opensearch-ml"), null); + /** Create Tool Step */ + CREATE_TOOL(ToolStep.NAME, List.of(TYPE), List.of(TOOLS_FIELD), List.of(OPENSEARCH_ML), null, ToolStep::new); - private final String workflowStep; + private final String workflowStepName; private final List inputs; private final List outputs; private final List requiredPlugins; private final TimeValue timeout; - - private static final List allWorkflowSteps = Stream.of(values()) - .map(WorkflowSteps::getWorkflowStep) - .collect(Collectors.toList()); - - WorkflowSteps(String workflowStep, List inputs, List outputs, List requiredPlugins, TimeValue timeout) { - this.workflowStep = workflowStep; + private final Supplier workflowStep; + + WorkflowSteps( + String workflowStepName, + List inputs, + List outputs, + List requiredPlugins, + TimeValue timeout, + Supplier workflowStep + ) { + this.workflowStepName = workflowStepName; this.inputs = List.copyOf(inputs); this.outputs = List.copyOf(outputs); this.requiredPlugins = requiredPlugins; this.timeout = timeout; + this.workflowStep = workflowStep; } /** * Returns the workflowStep for the given enum Constant * @return the workflowStep of this data. */ - public String getWorkflowStep() { - return workflowStep; + public String getWorkflowStepName() { + return workflowStepName; } - public List getInputs() { + /** + * Get the required inputs + * @return the inputs + */ + public List inputs() { return inputs; } - public List getOutputs() { + /** + * Get the required outputs + * @return the outputs + */ + public List outputs() { return outputs; } - public List getRequiredPlugins() { + /** + * Get the required plugins + * @return the required plugins + */ + public List requiredPlugins() { return requiredPlugins; } - public TimeValue getTimeout() { + /** + * Get the timeout + * @return the timeout + */ + public TimeValue timeout() { return timeout; } + /** + * Get the step + * @return the step + */ + public Supplier step() { + return workflowStep; + } + + /** + * Get the workflow step validator object + * @return the WorkflowStepValidator + */ public WorkflowStepValidator getWorkflowStepValidator() { - return new WorkflowStepValidator(workflowStep, inputs, outputs, requiredPlugins, timeout); + return new WorkflowStepValidator(workflowStepName, inputs, outputs, requiredPlugins, timeout); }; /** @@ -231,15 +322,16 @@ public WorkflowStepValidator getWorkflowStepValidator() { * @throws FlowFrameworkException if workflow step doesn't exist in enum */ public static TimeValue getTimeoutByWorkflowType(String workflowStep) throws FlowFrameworkException { - if (workflowStep != null && !workflowStep.isEmpty()) { + if (!Strings.isNullOrEmpty(workflowStep)) { for (WorkflowSteps mapping : values()) { - if (workflowStep.equals(mapping.getWorkflowStep())) { - return mapping.getTimeout(); + if (workflowStep.equals(mapping.getWorkflowStepName())) { + return mapping.timeout(); } } } - logger.error("Unable to find workflow timeout for step: {}", workflowStep); - throw new FlowFrameworkException("Unable to find workflow timeout for step: " + workflowStep, RestStatus.BAD_REQUEST); + String errorMessage = "Unable to find workflow timeout for step: " + workflowStep; + logger.error(errorMessage); + throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); } /** @@ -249,15 +341,16 @@ public static TimeValue getTimeoutByWorkflowType(String workflowStep) throws Flo * @throws FlowFrameworkException if workflow step doesn't exist in enum */ public static List getRequiredPluginsByWorkflowType(String workflowStep) throws FlowFrameworkException { - if (workflowStep != null && !workflowStep.isEmpty()) { + if (!Strings.isNullOrEmpty(workflowStep)) { for (WorkflowSteps mapping : values()) { - if (workflowStep.equals(mapping.getWorkflowStep())) { - return mapping.getRequiredPlugins(); + if (workflowStep.equals(mapping.getWorkflowStepName())) { + return mapping.requiredPlugins(); } } } - logger.error("Unable to find workflow required plugins for step: {}", workflowStep); - throw new FlowFrameworkException("Unable to find workflow required plugins for step: " + workflowStep, RestStatus.BAD_REQUEST); + String errorMessage = "Unable to find workflow required plugins for step: " + workflowStep; + logger.error(errorMessage); + throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); } /** @@ -267,15 +360,16 @@ public static List getRequiredPluginsByWorkflowType(String workflowStep) * @throws FlowFrameworkException if workflow step doesn't exist in enum */ public static List getOutputByWorkflowType(String workflowStep) throws FlowFrameworkException { - if (workflowStep != null && !workflowStep.isEmpty()) { + if (!Strings.isNullOrEmpty(workflowStep)) { for (WorkflowSteps mapping : values()) { - if (workflowStep.equals(mapping.getWorkflowStep())) { - return mapping.getOutputs(); + if (workflowStep.equals(mapping.getWorkflowStepName())) { + return mapping.outputs(); } } } - logger.error("Unable to find workflow output for step: {}", workflowStep); - throw new FlowFrameworkException("Unable to find workflow output for step: " + workflowStep, RestStatus.BAD_REQUEST); + String errorMessage = "Unable to find workflow output for step " + workflowStep; + logger.error(errorMessage); + throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); } /** @@ -285,15 +379,16 @@ public static List getOutputByWorkflowType(String workflowStep) throws F * @throws FlowFrameworkException if workflow step doesn't exist in enum */ public static List getInputByWorkflowType(String workflowStep) throws FlowFrameworkException { - if (workflowStep != null && !workflowStep.isEmpty()) { + if (!Strings.isNullOrEmpty(workflowStep)) { for (WorkflowSteps mapping : values()) { - if (workflowStep.equals(mapping.getWorkflowStep())) { - return mapping.getInputs(); + if (workflowStep.equals(mapping.getWorkflowStepName())) { + return mapping.inputs(); } } } - logger.error("Unable to find workflow input for step: {}", workflowStep); - throw new FlowFrameworkException("Unable to find workflow input for step: " + workflowStep, RestStatus.BAD_REQUEST); + String errorMessage = "Unable to find workflow input for step: " + workflowStep; + logger.error(errorMessage); + throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); } } @@ -306,7 +401,7 @@ public WorkflowValidator getWorkflowValidator() { Map workflowStepValidators = new HashMap<>(); for (WorkflowSteps mapping : WorkflowSteps.values()) { - workflowStepValidators.put(mapping.getWorkflowStep(), mapping.getWorkflowStepValidator()); + workflowStepValidators.put(mapping.getWorkflowStepName(), mapping.getWorkflowStepValidator()); } return new WorkflowValidator(workflowStepValidators); diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowStepValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowStepValidatorTests.java index d02cee91b..f155e1f90 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowStepValidatorTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowStepValidatorTests.java @@ -17,29 +17,24 @@ public class WorkflowStepValidatorTests extends OpenSearchTestCase { - private String validValidator; - private String invalidValidator; - @Override public void setUp() throws Exception { super.setUp(); - validValidator = "{\"inputs\":[\"input_value\"],\"outputs\":[\"output_value\"]}"; - invalidValidator = "{\"inputs\":[\"input_value\"],\"invalid_field\":[\"output_value\"]}"; } public void testParseWorkflowStepValidator() throws IOException { Map workflowStepValidators = new HashMap<>(); workflowStepValidators.put( - WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.getWorkflowStep(), + WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.getWorkflowStepName(), WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.getWorkflowStepValidator() ); - assertEquals(7, WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.getInputs().size()); - assertEquals(1, WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.getOutputs().size()); + assertEquals(7, WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.inputs().size()); + assertEquals(1, WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.outputs().size()); - assertEquals("name", WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.getInputs().get(0)); - assertEquals("connector_id", WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.getOutputs().get(0)); + assertEquals("name", WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.inputs().get(0)); + assertEquals("connector_id", WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.outputs().get(0)); } } diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java index 6dc0e02d5..73fb3186c 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java @@ -40,17 +40,11 @@ public class WorkflowValidatorTests extends OpenSearchTestCase { - private String validWorkflowStepJson; - private String invalidWorkflowStepJson; private FlowFrameworkSettings flowFrameworkSettings; @Override public void setUp() throws Exception { super.setUp(); - validWorkflowStepJson = - "{\"workflow_step_1\":{\"inputs\":[\"input_1\",\"input_2\"],\"outputs\":[\"output_1\"]},\"workflow_step_2\":{\"inputs\":[\"input_1\",\"input_2\",\"input_3\"],\"outputs\":[\"output_1\",\"output_2\",\"output_3\"]}}"; - invalidWorkflowStepJson = - "{\"workflow_step_1\":{\"bad_field\":[\"input_1\",\"input_2\"],\"outputs\":[\"output_1\"]},\"workflow_step_2\":{\"inputs\":[\"input_1\",\"input_2\",\"input_3\"],\"outputs\":[\"output_1\",\"output_2\",\"output_3\"]}}"; flowFrameworkSettings = mock(FlowFrameworkSettings.class); when(flowFrameworkSettings.isFlowFrameworkEnabled()).thenReturn(true); @@ -59,23 +53,122 @@ public void setUp() throws Exception { public void testParseWorkflowValidator() throws IOException { Map workflowStepValidators = new HashMap<>(); workflowStepValidators.put( - WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.getWorkflowStep(), + WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.getWorkflowStepName(), WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.getWorkflowStepValidator() ); workflowStepValidators.put( - WorkflowStepFactory.WorkflowSteps.DELETE_MODEL.getWorkflowStep(), + WorkflowStepFactory.WorkflowSteps.DELETE_MODEL.getWorkflowStepName(), WorkflowStepFactory.WorkflowSteps.DELETE_MODEL.getWorkflowStepValidator() ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.DEPLOY_MODEL.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.DEPLOY_MODEL.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_CUSTOM_MODEL.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_CUSTOM_MODEL.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_PRETRAINED_MODEL.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_PRETRAINED_MODEL.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_SPARSE_ENCODING_MODEL.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_SPARSE_ENCODING_MODEL.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.REGISTER_REMOTE_MODEL.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.REGISTER_REMOTE_MODEL.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.REGISTER_MODEL_GROUP.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.REGISTER_MODEL_GROUP.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.REGISTER_AGENT.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.REGISTER_AGENT.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.CREATE_TOOL.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.CREATE_TOOL.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.UNDEPLOY_MODEL.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.UNDEPLOY_MODEL.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.DELETE_CONNECTOR.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.DELETE_CONNECTOR.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.DELETE_AGENT.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.DELETE_AGENT.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.NOOP.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.NOOP.getWorkflowStepValidator() + ); WorkflowValidator validator = new WorkflowValidator(workflowStepValidators); - assertEquals(2, validator.getWorkflowStepValidators().size()); + assertEquals(14, validator.getWorkflowStepValidators().size()); + assertTrue(validator.getWorkflowStepValidators().keySet().contains("create_connector")); assertEquals(7, validator.getWorkflowStepValidators().get("create_connector").getInputs().size()); assertEquals(1, validator.getWorkflowStepValidators().get("create_connector").getOutputs().size()); + assertTrue(validator.getWorkflowStepValidators().keySet().contains("delete_model")); assertEquals(1, validator.getWorkflowStepValidators().get("delete_model").getInputs().size()); assertEquals(1, validator.getWorkflowStepValidators().get("delete_model").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("deploy_model")); + assertEquals(1, validator.getWorkflowStepValidators().get("deploy_model").getInputs().size()); + assertEquals(1, validator.getWorkflowStepValidators().get("deploy_model").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("register_remote_model")); + assertEquals(2, validator.getWorkflowStepValidators().get("register_remote_model").getInputs().size()); + assertEquals(2, validator.getWorkflowStepValidators().get("register_remote_model").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("register_model_group")); + assertEquals(1, validator.getWorkflowStepValidators().get("register_model_group").getInputs().size()); + assertEquals(2, validator.getWorkflowStepValidators().get("register_model_group").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("register_local_custom_model")); + assertEquals(9, validator.getWorkflowStepValidators().get("register_local_custom_model").getInputs().size()); + assertEquals(2, validator.getWorkflowStepValidators().get("register_local_custom_model").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("register_local_sparse_encoding_model")); + assertEquals(6, validator.getWorkflowStepValidators().get("register_local_sparse_encoding_model").getInputs().size()); + assertEquals(2, validator.getWorkflowStepValidators().get("register_local_sparse_encoding_model").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("register_local_pretrained_model")); + assertEquals(3, validator.getWorkflowStepValidators().get("register_local_pretrained_model").getInputs().size()); + assertEquals(2, validator.getWorkflowStepValidators().get("register_local_pretrained_model").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("undeploy_model")); + assertEquals(1, validator.getWorkflowStepValidators().get("undeploy_model").getInputs().size()); + assertEquals(1, validator.getWorkflowStepValidators().get("undeploy_model").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("delete_connector")); + assertEquals(1, validator.getWorkflowStepValidators().get("delete_connector").getInputs().size()); + assertEquals(1, validator.getWorkflowStepValidators().get("delete_connector").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("register_agent")); + assertEquals(2, validator.getWorkflowStepValidators().get("register_agent").getInputs().size()); + assertEquals(1, validator.getWorkflowStepValidators().get("register_agent").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("delete_agent")); + assertEquals(1, validator.getWorkflowStepValidators().get("delete_agent").getInputs().size()); + assertEquals(1, validator.getWorkflowStepValidators().get("delete_agent").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("create_tool")); + assertEquals(1, validator.getWorkflowStepValidators().get("create_tool").getInputs().size()); + assertEquals(1, validator.getWorkflowStepValidators().get("create_tool").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("noop")); + assertEquals(0, validator.getWorkflowStepValidators().get("noop").getInputs().size()); + assertEquals(0, validator.getWorkflowStepValidators().get("noop").getOutputs().size()); + } public void testWorkflowStepFactoryHasValidators() throws IOException {