diff --git a/CHANGELOG.md b/CHANGELOG.md index ad2517357..5377c0296 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.12...2.x) ### Features +- Add HttpHost WorkflowStep ([#530](https://github.com/opensearch-project/flow-framework/pull/530)) + ### Enhancements - Substitute REST path or body parameters in Workflow Steps ([#525](https://github.com/opensearch-project/flow-framework/pull/525)) diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md index 7b76a7dfe..b4149ae6d 100644 --- a/DEVELOPER_GUIDE.md +++ b/DEVELOPER_GUIDE.md @@ -105,7 +105,7 @@ snapshots/ ### Adding Workflow Steps To add functionality to workflows, add new Workflow Steps to the [`org.opensearch.flowframework.workflow`](https://github.com/opensearch-project/flow-framework/tree/main/src/main/java/org/opensearch/flowframework/workflow) package. -1. Implement the [Workflow](https://github.com/opensearch-project/flow-framework/blob/main/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java) interface. See existing steps for examples for input, output, and API execution. +1. Implement the [WorkflowStep](https://github.com/opensearch-project/flow-framework/blob/main/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java) interface. See existing steps for examples for input, output, and API execution. 2. Choose a unique name for the step which is not used by other steps. This will align with the `step_type` field in the templates and should be descriptive of what the step does. 3. Add a constructor and call it from the [WorkflowStepFactory](https://github.com/opensearch-project/flow-framework/blob/main/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java). 4. Add an entry to the [WorkflowStepFactory](https://github.com/opensearch-project/flow-framework/blob/main/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java) enum specifying required inputs, outputs, required plugins, and optionally a different timeout than the default. diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index ef54addff..c1752a018 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -158,6 +158,14 @@ private CommonValue() {} public static final String CREATED_TIME = "created_time"; /** The last updated time field for an agent */ public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; + /** HttpHost */ + public static final String HTTP_HOST_FIELD = "http_host"; + /** Http scheme */ + public static final String SCHEME_FIELD = "scheme"; + /** Http hostname */ + public static final String HOSTNAME_FIELD = "hostname"; + /** Http port */ + public static final String PORT_FIELD = "port"; /* * Constants associated with resource provisioning / state diff --git a/src/main/java/org/opensearch/flowframework/workflow/HttpHostStep.java b/src/main/java/org/opensearch/flowframework/workflow/HttpHostStep.java new file mode 100644 index 000000000..84480b8a8 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/HttpHostStep.java @@ -0,0 +1,115 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.http.HttpHost; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; + +import java.util.Collections; +import java.util.Locale; +import java.util.Map; +import java.util.Set; + +import static org.opensearch.flowframework.common.CommonValue.HOSTNAME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.HTTP_HOST_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PORT_FIELD; +import static org.opensearch.flowframework.common.CommonValue.SCHEME_FIELD; + +/** + * Step to register parameters for an HTTP Connection to a Host + */ +public class HttpHostStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(HttpHostStep.class); + PlainActionFuture hostFuture = PlainActionFuture.newFuture(); + static final String NAME = "http_host"; + + @Override + public PlainActionFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs, + Map params + ) { + Set requiredKeys = Set.of(SCHEME_FIELD, HOSTNAME_FIELD, PORT_FIELD); + // TODO Possibly add credentials fields here + // See ML Commons MLConnectorInput class and its usage + Set optionalKeys = Collections.emptySet(); + + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs, + params + ); + + String scheme = validScheme(inputs.get(SCHEME_FIELD)); + String hostname = validHostName(inputs.get(HOSTNAME_FIELD)); + int port = validPort(inputs.get(PORT_FIELD)); + + HttpHost httpHost = new HttpHost(hostname, port, scheme); + + hostFuture.onResponse( + new WorkflowData( + Map.ofEntries(Map.entry(HTTP_HOST_FIELD, httpHost)), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + + logger.info("Http Host registered successfully {}", httpHost); + + } catch (FlowFrameworkException e) { + hostFuture.onFailure(e); + } + return hostFuture; + } + + private String validScheme(Object o) { + String scheme = o.toString().toLowerCase(Locale.ROOT); + if ("http".equals(scheme) || "https".equals(scheme)) { + return scheme; + } + throw new FlowFrameworkException("http_host scheme must be http or https", RestStatus.BAD_REQUEST); + } + + private String validHostName(Object o) { + // TODO Add validation: + // Prevent use of localhost or private IP address ranges + // See ML Commons MLHttpClientFactory.java methods for examples + // Possibly consider an allowlist of addresses + return o.toString(); + } + + private int validPort(Object o) { + try { + int port = Integer.parseInt(o.toString()); + if ((port & 0xffff0000) != 0) { + throw new FlowFrameworkException("http_host port number must be between 0 and 65535", RestStatus.BAD_REQUEST); + } + return port; + } catch (NumberFormatException e) { + throw new FlowFrameworkException("http_host port must be a number between 0 and 65535", RestStatus.BAD_REQUEST); + } + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index a580c41f7..3fb91d6c6 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -33,6 +33,8 @@ import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION; import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE; import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME; +import static org.opensearch.flowframework.common.CommonValue.HOSTNAME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.HTTP_HOST_FIELD; import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE; import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_STATUS; @@ -40,8 +42,10 @@ import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.OPENSEARCH_ML; import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PORT_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; +import static org.opensearch.flowframework.common.CommonValue.SCHEME_FIELD; import static org.opensearch.flowframework.common.CommonValue.SUCCESS; import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; import static org.opensearch.flowframework.common.CommonValue.TYPE; @@ -100,6 +104,7 @@ public WorkflowStepFactory( stepMap.put(ToolStep.NAME, ToolStep::new); stepMap.put(RegisterAgentStep.NAME, () -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(DeleteAgentStep.NAME, () -> new DeleteAgentStep(mlClient)); + stepMap.put(HttpHostStep.NAME, HttpHostStep::new); } /** @@ -194,7 +199,16 @@ public enum WorkflowSteps { DELETE_AGENT(DeleteAgentStep.NAME, List.of(AGENT_ID), List.of(AGENT_ID), List.of(OPENSEARCH_ML), null), /** Create Tool Step */ - CREATE_TOOL(ToolStep.NAME, List.of(TYPE), List.of(TOOLS_FIELD), List.of(OPENSEARCH_ML), null); + CREATE_TOOL(ToolStep.NAME, List.of(TYPE), List.of(TOOLS_FIELD), List.of(OPENSEARCH_ML), null), + + /** Http Host Step */ + HTTP_HOST( + HttpHostStep.NAME, + List.of(SCHEME_FIELD, HOSTNAME_FIELD, PORT_FIELD), + List.of(HTTP_HOST_FIELD), + Collections.emptyList(), + null + ); private final String workflowStepName; private final List inputs; diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java index 678435707..5ba5924f3 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java @@ -11,15 +11,17 @@ import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.flowframework.workflow.WorkflowStepFactory.WorkflowSteps; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import java.io.IOException; import java.util.ArrayList; -import java.util.HashMap; +import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -37,67 +39,12 @@ public void setUp() throws Exception { } public void testParseWorkflowValidator() throws IOException { - Map workflowStepValidators = new HashMap<>(); - workflowStepValidators.put( - WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.getWorkflowStepName(), - WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.getWorkflowStepValidator() - ); - workflowStepValidators.put( - WorkflowStepFactory.WorkflowSteps.DELETE_MODEL.getWorkflowStepName(), - WorkflowStepFactory.WorkflowSteps.DELETE_MODEL.getWorkflowStepValidator() - ); - workflowStepValidators.put( - WorkflowStepFactory.WorkflowSteps.DEPLOY_MODEL.getWorkflowStepName(), - WorkflowStepFactory.WorkflowSteps.DEPLOY_MODEL.getWorkflowStepValidator() - ); - workflowStepValidators.put( - WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_CUSTOM_MODEL.getWorkflowStepName(), - WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_CUSTOM_MODEL.getWorkflowStepValidator() - ); - workflowStepValidators.put( - WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_PRETRAINED_MODEL.getWorkflowStepName(), - WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_PRETRAINED_MODEL.getWorkflowStepValidator() - ); - workflowStepValidators.put( - WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_SPARSE_ENCODING_MODEL.getWorkflowStepName(), - WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_SPARSE_ENCODING_MODEL.getWorkflowStepValidator() - ); - workflowStepValidators.put( - WorkflowStepFactory.WorkflowSteps.REGISTER_REMOTE_MODEL.getWorkflowStepName(), - WorkflowStepFactory.WorkflowSteps.REGISTER_REMOTE_MODEL.getWorkflowStepValidator() - ); - workflowStepValidators.put( - WorkflowStepFactory.WorkflowSteps.REGISTER_MODEL_GROUP.getWorkflowStepName(), - WorkflowStepFactory.WorkflowSteps.REGISTER_MODEL_GROUP.getWorkflowStepValidator() - ); - workflowStepValidators.put( - WorkflowStepFactory.WorkflowSteps.REGISTER_AGENT.getWorkflowStepName(), - WorkflowStepFactory.WorkflowSteps.REGISTER_AGENT.getWorkflowStepValidator() - ); - workflowStepValidators.put( - WorkflowStepFactory.WorkflowSteps.CREATE_TOOL.getWorkflowStepName(), - WorkflowStepFactory.WorkflowSteps.CREATE_TOOL.getWorkflowStepValidator() - ); - workflowStepValidators.put( - WorkflowStepFactory.WorkflowSteps.UNDEPLOY_MODEL.getWorkflowStepName(), - WorkflowStepFactory.WorkflowSteps.UNDEPLOY_MODEL.getWorkflowStepValidator() - ); - workflowStepValidators.put( - WorkflowStepFactory.WorkflowSteps.DELETE_CONNECTOR.getWorkflowStepName(), - WorkflowStepFactory.WorkflowSteps.DELETE_CONNECTOR.getWorkflowStepValidator() - ); - workflowStepValidators.put( - WorkflowStepFactory.WorkflowSteps.DELETE_AGENT.getWorkflowStepName(), - WorkflowStepFactory.WorkflowSteps.DELETE_AGENT.getWorkflowStepValidator() - ); - workflowStepValidators.put( - WorkflowStepFactory.WorkflowSteps.NOOP.getWorkflowStepName(), - WorkflowStepFactory.WorkflowSteps.NOOP.getWorkflowStepValidator() - ); + Map workflowStepValidators = Arrays.stream(WorkflowSteps.values()) + .collect(Collectors.toMap(WorkflowSteps::getWorkflowStepName, WorkflowSteps::getWorkflowStepValidator)); WorkflowValidator validator = new WorkflowValidator(workflowStepValidators); - assertEquals(14, validator.getWorkflowStepValidators().size()); + assertEquals(15, validator.getWorkflowStepValidators().size()); assertTrue(validator.getWorkflowStepValidators().keySet().contains("create_connector")); assertEquals(7, validator.getWorkflowStepValidators().get("create_connector").getInputs().size()); @@ -155,6 +102,9 @@ public void testParseWorkflowValidator() throws IOException { assertEquals(0, validator.getWorkflowStepValidators().get("noop").getInputs().size()); assertEquals(0, validator.getWorkflowStepValidators().get("noop").getOutputs().size()); + assertTrue(validator.getWorkflowStepValidators().keySet().contains("http_host")); + assertEquals(3, validator.getWorkflowStepValidators().get("http_host").getInputs().size()); + assertEquals(1, validator.getWorkflowStepValidators().get("http_host").getOutputs().size()); } public void testWorkflowStepFactoryHasValidators() throws IOException { diff --git a/src/test/java/org/opensearch/flowframework/workflow/HttpHostStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/HttpHostStepTests.java new file mode 100644 index 000000000..6c80563d6 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/HttpHostStepTests.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.http.HttpHost; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +import static org.opensearch.flowframework.common.CommonValue.HOSTNAME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.HTTP_HOST_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PORT_FIELD; +import static org.opensearch.flowframework.common.CommonValue.SCHEME_FIELD; + +public class HttpHostStepTests extends OpenSearchTestCase { + + public void testHttpHost() throws InterruptedException, ExecutionException { + HttpHostStep httpHostStep = new HttpHostStep(); + assertEquals(HttpHostStep.NAME, httpHostStep.getName()); + + WorkflowData inputData = new WorkflowData( + Map.ofEntries(Map.entry(SCHEME_FIELD, "http"), Map.entry(HOSTNAME_FIELD, "localhost"), Map.entry(PORT_FIELD, 1234)), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = httpHostStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + assertEquals(HttpHost.class, future.get().getContent().get(HTTP_HOST_FIELD).getClass()); + HttpHost host = (HttpHost) future.get().getContent().get(HTTP_HOST_FIELD); + assertEquals("http", host.getSchemeName()); + assertEquals("localhost", host.getHostName()); + assertEquals(1234, host.getPort()); + } + + public void testBadScheme() { + HttpHostStep httpHostStep = new HttpHostStep(); + + WorkflowData badSchemeData = new WorkflowData( + Map.ofEntries(Map.entry(SCHEME_FIELD, "ftp"), Map.entry(HOSTNAME_FIELD, "localhost"), Map.entry(PORT_FIELD, 1234)), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = httpHostStep.execute( + badSchemeData.getNodeId(), + badSchemeData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get()); + assertEquals(FlowFrameworkException.class, ex.getCause().getClass()); + assertEquals("http_host scheme must be http or https", ex.getCause().getMessage()); + } + + public void testBadPort() { + HttpHostStep httpHostStep = new HttpHostStep(); + + WorkflowData badPortData = new WorkflowData( + Map.ofEntries(Map.entry(SCHEME_FIELD, "https"), Map.entry(HOSTNAME_FIELD, "localhost"), Map.entry(PORT_FIELD, 123456)), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = httpHostStep.execute( + badPortData.getNodeId(), + badPortData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get()); + assertEquals(FlowFrameworkException.class, ex.getCause().getClass()); + assertEquals("http_host port number must be between 0 and 65535", ex.getCause().getMessage()); + } + + public void testNoParsePort() { + HttpHostStep httpHostStep = new HttpHostStep(); + + WorkflowData noParsePortData = new WorkflowData( + Map.ofEntries(Map.entry(SCHEME_FIELD, "https"), Map.entry(HOSTNAME_FIELD, "localhost"), Map.entry(PORT_FIELD, "doesn't parse")), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = httpHostStep.execute( + noParsePortData.getNodeId(), + noParsePortData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get()); + assertEquals(FlowFrameworkException.class, ex.getCause().getClass()); + assertEquals("http_host port must be a number between 0 and 65535", ex.getCause().getMessage()); + } +}