Skip to content

Commit

Permalink
Add model interface support for remote and local custom models
Browse files Browse the repository at this point in the history
Signed-off-by: Joshua Palis <[email protected]>
  • Loading branch information
joshpalis committed May 2, 2024
1 parent 40d9efc commit 05b701c
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ private CommonValue() {}
public static final String GUARDRAILS_FIELD = "guardrails";
/** Delay field */
public static final String DELAY_FIELD = "delay";
/** Model Interface Field */
public static final String MODEL_INTERFACE = "interface";

/*
* Constants associated with resource provisioning / state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS;
import static org.opensearch.flowframework.common.CommonValue.GUARDRAILS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.MODEL_INTERFACE;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap;
Expand Down Expand Up @@ -164,7 +165,7 @@ public static WorkflowNode parse(XContentParser parser) throws IOException {
if (GUARDRAILS_FIELD.equals(inputFieldName)) {
userInputs.put(inputFieldName, Guardrails.parse(parser));
break;
} else if (CONFIGURATIONS.equals(inputFieldName)) {
} else if (CONFIGURATIONS.equals(inputFieldName) || MODEL_INTERFACE.equals(inputFieldName)) {
Map<String, Object> configurationsMap = parser.map();
try {
String configurationsString = ParseUtils.parseArbitraryStringToObjectMapToString(configurationsMap);
Expand Down
15 changes: 15 additions & 0 deletions src/main/java/org/opensearch/flowframework/util/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -451,4 +451,19 @@ public static String removingBackslashesAndQuotesInArrayInJsonString(String inpu
matcher.appendTail(result);
return result.toString();
}

public static Map<String, String> convertStringToObjectMapToStringToStringMap(Map<String, Object> stringToObjectMap) throws Exception {
try (Jsonb jsonb = JsonbBuilder.create()) {
Map<String, String> stringToStringMap = new HashMap<>();
for (Map.Entry<String, Object> entry : stringToObjectMap.entrySet()) {
Object value = entry.getValue();
if (value instanceof String) {
stringToStringMap.put(entry.getKey(), (String) value);
} else {
stringToStringMap.put(entry.getKey(), jsonb.toJson(value));
}
}
return stringToStringMap;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.Booleans;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.exception.WorkflowStepException;
Expand All @@ -30,6 +34,7 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder;
import org.opensearch.threadpool.ThreadPool;

import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.Set;
import java.util.stream.Stream;
Expand All @@ -42,6 +47,7 @@
import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME;
import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT;
import static org.opensearch.flowframework.common.CommonValue.MODEL_INTERFACE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
Expand Down Expand Up @@ -116,6 +122,7 @@ public PlainActionFuture<WorkflowData> execute(
String description = (String) inputs.get(DESCRIPTION_FIELD);
String modelGroupId = (String) inputs.get(MODEL_GROUP_ID);
String allConfig = (String) inputs.get(ALL_CONFIG);
String modelInterface = (String) inputs.get(MODEL_INTERFACE);
final Boolean deploy = inputs.containsKey(DEPLOY_FIELD) ? Booleans.parseBoolean(inputs.get(DEPLOY_FIELD).toString()) : null;

// Build register model input
Expand Down Expand Up @@ -149,6 +156,27 @@ public PlainActionFuture<WorkflowData> execute(
if (modelGroupId != null) {
mlInputBuilder.modelGroupId(modelGroupId);
}
if (modelInterface != null) {
try {
// Convert model interface string to map
BytesReference modelInterfaceBytes = new BytesArray(modelInterface.getBytes(StandardCharsets.UTF_8));
Map<String, Object> modelInterfaceAsMap = XContentHelper.convertToMap(
modelInterfaceBytes,
false,
MediaTypeRegistry.JSON
).v2();

// Convert to string to string map
Map<String, String> parameters = ParseUtils.convertStringToObjectMapToStringToStringMap(modelInterfaceAsMap);
mlInputBuilder.modelInterface(parameters);

} catch (Exception ex) {
String errorMessage = "Failed to create model interface";
logger.error(errorMessage, ex);
registerLocalModelFuture.onFailure(new WorkflowStepException(errorMessage, RestStatus.BAD_REQUEST));
}

}
if (deploy != null) {
mlInputBuilder.deployModel(deploy);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME;
import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT;
import static org.opensearch.flowframework.common.CommonValue.MODEL_INTERFACE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.URL;
Expand Down Expand Up @@ -71,7 +72,7 @@ protected Set<String> getRequiredKeys() {

@Override
protected Set<String> getOptionalKeys() {
return Set.of(DESCRIPTION_FIELD, MODEL_GROUP_ID, ALL_CONFIG, DEPLOY_FIELD);
return Set.of(DESCRIPTION_FIELD, MODEL_GROUP_ID, ALL_CONFIG, DEPLOY_FIELD, MODEL_INTERFACE);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.common.Booleans;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.exception.WorkflowStepException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
Expand All @@ -27,12 +31,14 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;

import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.Set;

import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.GUARDRAILS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.MODEL_INTERFACE;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
Expand Down Expand Up @@ -76,7 +82,7 @@ public PlainActionFuture<WorkflowData> execute(
PlainActionFuture<WorkflowData> registerRemoteModelFuture = PlainActionFuture.newFuture();

Set<String> requiredKeys = Set.of(NAME_FIELD, CONNECTOR_ID);
Set<String> optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD, GUARDRAILS_FIELD);
Set<String> optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD, GUARDRAILS_FIELD, MODEL_INTERFACE);

try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
Expand All @@ -93,6 +99,7 @@ public PlainActionFuture<WorkflowData> execute(
String description = (String) inputs.get(DESCRIPTION_FIELD);
String connectorId = (String) inputs.get(CONNECTOR_ID);
Guardrails guardRails = (Guardrails) inputs.get(GUARDRAILS_FIELD);
String modelInterface = (String) inputs.get(MODEL_INTERFACE);
final Boolean deploy = inputs.containsKey(DEPLOY_FIELD) ? Booleans.parseBoolean(inputs.get(DEPLOY_FIELD).toString()) : null;

MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder()
Expand All @@ -112,6 +119,27 @@ public PlainActionFuture<WorkflowData> execute(
if (guardRails != null) {
builder.guardrails(guardRails);
}
if (modelInterface != null) {
try {
// Convert model interface string to map
BytesReference modelInterfaceBytes = new BytesArray(modelInterface.getBytes(StandardCharsets.UTF_8));
Map<String, Object> modelInterfaceAsMap = XContentHelper.convertToMap(
modelInterfaceBytes,
false,
MediaTypeRegistry.JSON
).v2();

// Convert to string to string map
Map<String, String> parameters = ParseUtils.convertStringToObjectMapToStringToStringMap(modelInterfaceAsMap);
builder.modelInterface(parameters);

} catch (Exception ex) {
String errorMessage = "Failed to create model interface";
logger.error(errorMessage, ex);
registerRemoteModelFuture.onFailure(new WorkflowStepException(errorMessage, RestStatus.BAD_REQUEST));
}

}

MLRegisterModelInput mlInput = builder.build();

Expand Down

0 comments on commit 05b701c

Please sign in to comment.