Skip to content

Commit

Permalink
Revert "[Backport 2.13] Added new Guardrail field for remote model (o…
Browse files Browse the repository at this point in the history
…pensearch-project#624)"

This reverts commit be5410b.
  • Loading branch information
amitgalitz committed Apr 4, 2024
1 parent 478bb7e commit fe0f9b4
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,6 @@ private CommonValue() {}
public static final String PIPELINE_ID = "pipeline_id";
/** Pipeline Configurations */
public static final String CONFIGURATIONS = "configurations";
/** Guardrails field */
public static final String GUARDRAILS_FIELD = "guardrails";

/*
* Constants associated with resource provisioning / state
Expand Down
12 changes: 2 additions & 10 deletions src/main/java/org/opensearch/flowframework/model/WorkflowNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.flowframework.workflow.WorkflowStep;
import org.opensearch.ml.common.model.Guardrails;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -33,7 +32,6 @@
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS;
import static org.opensearch.flowframework.common.CommonValue.GUARDRAILS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap;
Expand Down Expand Up @@ -97,9 +95,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
xContentBuilder.field(e.getKey());
if (e.getValue() instanceof String || e.getValue() instanceof Number || e.getValue() instanceof Boolean) {
xContentBuilder.value(e.getValue());
} else if (GUARDRAILS_FIELD.equals(e.getKey())) {
Guardrails g = (Guardrails) e.getValue();
xContentBuilder.value(g);
} else if (e.getValue() instanceof Map<?, ?>) {
buildStringToStringMap(xContentBuilder, (Map<?, ?>) e.getValue());
} else if (e.getValue() instanceof Object[]) {
Expand Down Expand Up @@ -161,16 +156,13 @@ public static WorkflowNode parse(XContentParser parser) throws IOException {
userInputs.put(inputFieldName, parser.text());
break;
case START_OBJECT:
if (GUARDRAILS_FIELD.equals(inputFieldName)) {
userInputs.put(inputFieldName, Guardrails.parse(parser));
break;
} else if (CONFIGURATIONS.equals(inputFieldName)) {
if (CONFIGURATIONS.equals(inputFieldName)) {
Map<String, Object> configurationsMap = parser.map();
try {
String configurationsString = ParseUtils.parseArbitraryStringToObjectMapToString(configurationsMap);
userInputs.put(inputFieldName, configurationsString);
} catch (Exception ex) {
String errorMessage = "Failed to parse" + inputFieldName + "map";
String errorMessage = "Failed to parse configuration map";
logger.error(errorMessage, ex);
throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.Guardrails;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
Expand All @@ -30,7 +29,6 @@

import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.GUARDRAILS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
Expand Down Expand Up @@ -73,7 +71,7 @@ public PlainActionFuture<WorkflowData> execute(
PlainActionFuture<WorkflowData> registerRemoteModelFuture = PlainActionFuture.newFuture();

Set<String> requiredKeys = Set.of(NAME_FIELD, CONNECTOR_ID);
Set<String> optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD, GUARDRAILS_FIELD);
Set<String> optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD);

try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
Expand All @@ -89,7 +87,6 @@ public PlainActionFuture<WorkflowData> execute(
String modelGroupId = (String) inputs.get(MODEL_GROUP_ID);
String description = (String) inputs.get(DESCRIPTION_FIELD);
String connectorId = (String) inputs.get(CONNECTOR_ID);
Guardrails guardRails = (Guardrails) inputs.get(GUARDRAILS_FIELD);
final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD);

MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder()
Expand All @@ -106,11 +103,6 @@ public PlainActionFuture<WorkflowData> execute(
if (deploy != null) {
builder.deployModel(deploy);
}

if (guardRails != null) {
builder.guardrails(guardRails);
}

MLRegisterModelInput mlInput = builder.build();

mlClient.register(mlInput, new ActionListener<MLRegisterModelResponse>() {
Expand Down

0 comments on commit fe0f9b4

Please sign in to comment.