From 4a12730a5449800da32def4f67e2e4e9c9a807c7 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 27 Mar 2024 19:12:56 -0700 Subject: [PATCH] Added new Guardrail field for remote model (#622) * Added new field guarddail for remote model Signed-off-by: Owais Kazi * Fixed parsing Signed-off-by: Owais Kazi * Deserialize Signed-off-by: Owais Kazi * fixing guardrails Signed-off-by: Joshua Palis * Added break Signed-off-by: Owais Kazi --------- Signed-off-by: Owais Kazi Signed-off-by: Joshua Palis Co-authored-by: Joshua Palis --- .../opensearch/flowframework/common/CommonValue.java | 2 ++ .../opensearch/flowframework/model/WorkflowNode.java | 12 ++++++++++-- .../workflow/RegisterRemoteModelStep.java | 10 +++++++++- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index d3960d90b..8df5613d4 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -168,6 +168,8 @@ private CommonValue() {} public static final String PIPELINE_ID = "pipeline_id"; /** Pipeline Configurations */ public static final String CONFIGURATIONS = "configurations"; + /** Guardrails field */ + public static final String GUARDRAILS_FIELD = "guardrails"; /* * Constants associated with resource provisioning / state diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index 15d52ccd1..899167ac8 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -20,6 +20,7 @@ import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.flowframework.workflow.WorkflowStep; +import org.opensearch.ml.common.model.Guardrails; import java.io.IOException; import java.util.ArrayList; @@ -32,6 +33,7 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; +import static org.opensearch.flowframework.common.CommonValue.GUARDRAILS_FIELD; import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD; import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap; import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap; @@ -95,6 +97,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws xContentBuilder.field(e.getKey()); if (e.getValue() instanceof String || e.getValue() instanceof Number || e.getValue() instanceof Boolean) { xContentBuilder.value(e.getValue()); + } else if (GUARDRAILS_FIELD.equals(e.getKey())) { + Guardrails g = (Guardrails) e.getValue(); + xContentBuilder.value(g); } else if (e.getValue() instanceof Map) { buildStringToStringMap(xContentBuilder, (Map) e.getValue()); } else if (e.getValue() instanceof Object[]) { @@ -156,13 +161,16 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { userInputs.put(inputFieldName, parser.text()); break; case START_OBJECT: - if (CONFIGURATIONS.equals(inputFieldName)) { + if (GUARDRAILS_FIELD.equals(inputFieldName)) { + userInputs.put(inputFieldName, Guardrails.parse(parser)); + break; + } else if (CONFIGURATIONS.equals(inputFieldName)) { Map configurationsMap = parser.map(); try { String configurationsString = ParseUtils.parseArbitraryStringToObjectMapToString(configurationsMap); userInputs.put(inputFieldName, configurationsString); } catch (Exception ex) { - String errorMessage = "Failed to parse configuration map"; + String errorMessage = "Failed to parse" + inputFieldName + "map"; logger.error(errorMessage, ex); throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index c32a7f0bd..cc3800284 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -20,6 +20,7 @@ import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.model.Guardrails; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; @@ -29,6 +30,7 @@ import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.GUARDRAILS_FIELD; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; @@ -71,7 +73,7 @@ public PlainActionFuture execute( PlainActionFuture registerRemoteModelFuture = PlainActionFuture.newFuture(); Set requiredKeys = Set.of(NAME_FIELD, CONNECTOR_ID); - Set optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD); + Set optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD, GUARDRAILS_FIELD); try { Map inputs = ParseUtils.getInputsFromPreviousSteps( @@ -87,6 +89,7 @@ public PlainActionFuture execute( String modelGroupId = (String) inputs.get(MODEL_GROUP_ID); String description = (String) inputs.get(DESCRIPTION_FIELD); String connectorId = (String) inputs.get(CONNECTOR_ID); + Guardrails guardRails = (Guardrails) inputs.get(GUARDRAILS_FIELD); final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD); MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder() @@ -103,6 +106,11 @@ public PlainActionFuture execute( if (deploy != null) { builder.deployModel(deploy); } + + if (guardRails != null) { + builder.guardrails(guardRails); + } + MLRegisterModelInput mlInput = builder.build(); mlClient.register(mlInput, new ActionListener() {