Skip to content

Commit

Permalink
Added model-group step to workflow-steps.json, rebasing with main
Browse files Browse the repository at this point in the history
Signed-off-by: Joshua Palis <[email protected]>
  • Loading branch information
joshpalis committed Oct 30, 2023
2 parents becc510 + bcd53e1 commit cbac423
Show file tree
Hide file tree
Showing 10 changed files with 331 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ private CommonValue() {}
public static final String GLOBAL_CONTEXT_INDEX_MAPPING = "mappings/global-context.json";
/** Global Context index mapping version */
public static final Integer GLOBAL_CONTEXT_INDEX_VERSION = 1;
/** The template field name for template use case */
public static final String USE_CASE_FIELD = "use_case";
/** The template field name for template version */
public static final String TEMPLATE_FIELD = "template";
/** The template field name for template compatibility with OpenSearch versions */
public static final String COMPATIBILITY_FIELD = "compatibility";
/** The template field name for template workflows */
public static final String WORKFLOWS_FIELD = "workflows";

/** The transport action name prefix */
public static final String TRANSPORT_ACION_NAME_PREFIX = "cluster:admin/opensearch/flow_framework/";
Expand Down Expand Up @@ -55,7 +63,7 @@ private CommonValue() {}
/** Model Group Id field */
public static final String MODEL_GROUP_ID = "model_group_id";
/** Description field */
public static final String DESCRIPTION = "description";
public static final String DESCRIPTION_FIELD = "description";
/** Connector Id field */
public static final String CONNECTOR_ID = "connector_id";
/** Model format field */
Expand All @@ -72,4 +80,10 @@ private CommonValue() {}
public static final String CREDENTIALS_FIELD = "credentials";
/** Connector actions field */
public static final String ACTIONS_FIELD = "actions";
/** Backend roles for the model */
public static final String BACKEND_ROLES_FIELD = "backend_roles";
/** Access mode for the model */
public static final String MODEL_ACCESS_MODE = "access_mode";
/** Add all backend roles */
public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles";
}
22 changes: 7 additions & 15 deletions src/main/java/org/opensearch/flowframework/model/Template.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,19 @@
import java.util.Map.Entry;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.COMPATIBILITY_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TEMPLATE_FIELD;
import static org.opensearch.flowframework.common.CommonValue.USE_CASE_FIELD;
import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOWS_FIELD;

/**
* The Template is the central data structure which configures workflows. This object is used to parse JSON communicated via REST API.
*/
public class Template implements ToXContentObject {

/** The template field name for template name */
public static final String NAME_FIELD = "name";
/** The template field name for template description */
public static final String DESCRIPTION_FIELD = "description";
/** The template field name for template use case */
public static final String USE_CASE_FIELD = "use_case";
/** The template field name for template version information */
public static final String VERSION_FIELD = "version";
/** The template field name for template version */
public static final String TEMPLATE_FIELD = "template";
/** The template field name for template compatibility with OpenSearch versions */
public static final String COMPATIBILITY_FIELD = "compatibility";
/** The template field name for template workflows */
public static final String WORKFLOWS_FIELD = "workflows";

private final String name;
private final String description;
private final String useCase; // probably an ENUM actually
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -32,7 +33,7 @@

import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.CREDENTIALS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD;
Expand Down Expand Up @@ -85,7 +86,7 @@ public void onFailure(Exception e) {
String protocol = null;
Map<String, String> parameters = new HashMap<>();
Map<String, String> credentials = new HashMap<>();
List<ConnectorAction> actions = null;
List<ConnectorAction> actions = new ArrayList<>();

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();
Expand All @@ -95,8 +96,8 @@ public void onFailure(Exception e) {
case NAME_FIELD:
name = (String) content.get(NAME_FIELD);
break;
case DESCRIPTION:
description = (String) content.get(DESCRIPTION);
case DESCRIPTION_FIELD:
description = (String) content.get(DESCRIPTION_FIELD);
break;
case VERSION_FIELD:
version = (String) content.get(VERSION_FIELD);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput.MLRegisterModelGroupInputBuilder;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.ADD_ALL_BACKEND_ROLES;
import static org.opensearch.flowframework.common.CommonValue.BACKEND_ROLES_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ACCESS_MODE;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;

/**
* Step to register a model group
*/
public class ModelGroupStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(RegisterModelStep.class);

private MachineLearningNodeClient mlClient;

static final String NAME = "model_group";

/**
* Instantiate this class
* @param mlClient client to instantiate MLClient
*/
public ModelGroupStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) throws IOException {

CompletableFuture<WorkflowData> registerModelGroupFuture = new CompletableFuture<>();

ActionListener<MLRegisterModelGroupResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse) {
logger.info("Model group registration successful");
registerModelGroupFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()),
Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus())
)
)
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to register model group");
registerModelGroupFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
};

String modelGroupName = null;
String description = null;
List<String> backendRoles = new ArrayList<>();
AccessMode modelAccessMode = null;
Boolean isAddAllBackendRoles = null;

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case NAME_FIELD:
modelGroupName = (String) content.get(NAME_FIELD);
break;
case DESCRIPTION_FIELD:
description = (String) content.get(DESCRIPTION_FIELD);
break;
case BACKEND_ROLES_FIELD:
backendRoles = getBackendRoles(content);
break;
case MODEL_ACCESS_MODE:
modelAccessMode = (AccessMode) content.get(MODEL_ACCESS_MODE);
break;
case ADD_ALL_BACKEND_ROLES:
isAddAllBackendRoles = (Boolean) content.get(ADD_ALL_BACKEND_ROLES);
break;
default:
break;
}
}
}

if (modelGroupName == null) {
registerModelGroupFuture.completeExceptionally(
new FlowFrameworkException("Model group name is not provided", RestStatus.BAD_REQUEST)
);
} else {
MLRegisterModelGroupInputBuilder builder = MLRegisterModelGroupInput.builder();
builder.name(modelGroupName);
if (description != null) {
builder.description(description);
}
if (!backendRoles.isEmpty()) {
builder.backendRoles(backendRoles);
}
if (modelAccessMode != null) {
builder.modelAccessMode(modelAccessMode);
}
if (isAddAllBackendRoles != null) {
builder.isAddAllBackendRoles(isAddAllBackendRoles);
}
MLRegisterModelGroupInput mlInput = builder.build();

mlClient.registerModelGroup(mlInput, actionListener);
}

return registerModelGroupFuture;
}

@Override
public String getName() {
return NAME;
}

@SuppressWarnings("unchecked")
private List<String> getBackendRoles(Map<String, Object> content) {
return (List<String>) content.get(BACKEND_ROLES_FIELD);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME;
import static org.opensearch.flowframework.common.CommonValue.MODEL_CONFIG;
import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT;
Expand Down Expand Up @@ -114,8 +114,8 @@ public void onFailure(Exception e) {
case MODEL_CONFIG:
modelConfig = (MLModelConfig) content.get(MODEL_CONFIG);
break;
case DESCRIPTION:
description = (String) content.get(DESCRIPTION);
case DESCRIPTION_FIELD:
description = (String) content.get(DESCRIPTION_FIELD);
break;
case CONNECTOR_ID:
connectorId = (String) content.get(CONNECTOR_ID);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ private void populateMap(ClusterService clusterService, Client client, MachineLe
stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(mlClient));
stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient));
stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient));
}

/**
Expand Down
9 changes: 9 additions & 0 deletions src/main/resources/mappings/workflow-steps.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,14 @@
"outputs":[
"deploy_model_status"
]
},
"model_group": {
"inputs":[
"name"
],
"outputs":[
"model_group_id",
"model_group_status"
]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.test.OpenSearchTestCase;
Expand All @@ -34,9 +35,6 @@
public class CreateConnectorStepTests extends OpenSearchTestCase {
private WorkflowData inputData = WorkflowData.EMPTY;

@Mock
ActionListener<MLCreateConnectorResponse> registerModelActionListener;

@Mock
MachineLearningNodeClient machineLearningNodeClient;

Expand All @@ -49,6 +47,10 @@ public void setUp() throws Exception {

MockitoAnnotations.openMocks(this);

ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "post";
String url = "foot.test";

inputData = new WorkflowData(
Map.ofEntries(
Map.entry("name", "test"),
Expand All @@ -57,7 +59,20 @@ public void setUp() throws Exception {
Map.entry("protocol", "test"),
Map.entry("params", params),
Map.entry("credentials", credentials),
Map.entry("actions", List.of("actions"))
Map.entry(
"actions",
List.of(
new ConnectorAction(
actionType,
method,
url,
null,
"{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }",
null,
null
)
)
)
)
);

Expand Down
Loading

0 comments on commit cbac423

Please sign in to comment.