Skip to content

Commit

Permalink
adding IT and move configurations parsing to input parsing
Browse files Browse the repository at this point in the history
Signed-off-by: Amit Galitzky <[email protected]>
  • Loading branch information
amitgalitz committed Mar 12, 2024
1 parent 3dcae38 commit 29f66cd
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 93 deletions.
25 changes: 19 additions & 6 deletions src/main/java/org/opensearch/flowframework/util/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ public class ParseUtils {
private static final Logger logger = LogManager.getLogger(ParseUtils.class);

// Matches ${{ foo.bar }} (whitespace optional) with capturing groups 1=foo, 2=bar
private static final Pattern SUBSTITUTION_PATTERN = Pattern.compile("\\$\\{\\{\\s*(.+)\\.(.+?)\\s*\\}\\}");
// (previous pattern used greedy (.+)/(.+?) groups, which could incorrectly match dots inside names)
private static final Pattern SUBSTITUTION_PATTERN = Pattern.compile("\\$\\{\\{\\s*([\\w_]+)\\.([\\w_]+)\\s*\\}\\}");

private ParseUtils() {}

Expand Down Expand Up @@ -344,13 +345,25 @@ public static Map<String, Object> getInputsFromPreviousSteps(
private static Object conditionallySubstitute(Object value, Map<String, WorkflowData> outputs, Map<String, String> params) {
if (value instanceof String) {
Matcher m = SUBSTITUTION_PATTERN.matcher((String) value);
if (m.matches()) {
// Try matching a previous step+value pair
WorkflowData data = outputs.get(m.group(1));
if (data != null && data.getContent().containsKey(m.group(2))) {
return data.getContent().get(m.group(2));
StringBuilder result = new StringBuilder();
while (m.find()) {
// outputs content map contains values for previous node input (e.g: deploy_openai_model.model_id)
// Check first if the substitution is looking for the same key, value pair and if yes
// then replace it with the key value pair in the inputs map
String replacement = m.group(0);
if (outputs.containsKey(m.group(1)) && outputs.get(m.group(1)).getContent().containsKey(m.group(2))) {
// Extract the key for the inputs (e.g., "model_id" from ${{deploy_openai_model.model_id}})
String key = m.group(2);
if (outputs.get(m.group(1)).getContent().get(key) instanceof String) {
replacement = (String) outputs.get(m.group(1)).getContent().get(key);
// Replace the whole sequence with the value from the map
m.appendReplacement(result, Matcher.quoteReplacement(replacement));
}
}
}
m.appendTail(result);
value = result.toString();

// Replace all params if present
for (Entry<String, String> e : params.entrySet()) {
String regex = "\\$\\{\\{\\s*" + Pattern.quote(e.getKey()) + "\\s*\\}\\}";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,13 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.util.ParseUtils;

import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS;
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;
Expand All @@ -51,7 +48,7 @@ public class CreateIngestPipelineStep implements WorkflowStep {
private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;

/**
* Instantiates a new CreateIngestPipelineStepDraft
* Instantiates a new CreateIngestPipelineStep
* @param client The client to create a pipeline and store workflow data into the global context index
* @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices
*/
Expand All @@ -71,58 +68,6 @@ public PlainActionFuture<WorkflowData> execute(

PlainActionFuture<WorkflowData> createIngestPipelineFuture = PlainActionFuture.newFuture();

ActionListener<AcknowledgedResponse> actionListener = new ActionListener<>() {

@Override
public void onResponse(AcknowledgedResponse acknowledgedResponse) {
String resourceName = getResourceByWorkflowStep(getName());
try {
flowFrameworkIndicesHandler.updateResourceInStateIndex(
currentNodeInputs.getWorkflowId(),
currentNodeId,
getName(),
currentNodeInputs.getContent().get(PIPELINE_ID).toString(),
ActionListener.wrap(updateResponse -> {
logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex());
// PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead
// TODO: revisit this concept of pipeline_id to be consistent with what makes most sense to end user here
createIngestPipelineFuture.onResponse(
new WorkflowData(
Map.of(resourceName, currentNodeInputs.getContent().get(PIPELINE_ID).toString()),
currentNodeInputs.getWorkflowId(),
currentNodeInputs.getNodeId()
)
);
}, exception -> {
String errorMessage = "Failed to update new created "
+ currentNodeId
+ " resource "
+ getName()
+ " id "
+ currentNodeInputs.getContent().get(PIPELINE_ID).toString();
logger.error(errorMessage, exception);
createIngestPipelineFuture.onFailure(
new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))
);
})
);

} catch (Exception e) {
String errorMessage = "Failed to parse and update new created resource";
logger.error(errorMessage, e);
createIngestPipelineFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e)));
}
}

@Override
public void onFailure(Exception e) {
String errorMessage = "Failed to create ingest pipeline";
logger.error(errorMessage, e);
createIngestPipelineFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e)));
}

};

Set<String> requiredKeys = Set.of(PIPELINE_ID, CONFIGURATIONS);

// currently, we are supporting an optional param of model ID into the various processors
Expand All @@ -141,39 +86,60 @@ public void onFailure(Exception e) {
String pipelineId = (String) inputs.get(PIPELINE_ID);
String configurations = (String) inputs.get(CONFIGURATIONS);

// Regex to find patterns like ${{deploy_openai_model.model_id}}
// We currently support one previous node input that fits the pattern of (step.input_to_look_for)
Pattern pattern = Pattern.compile("\\$\\{\\{([\\w_]+)\\.([\\w_]+)\\}\\}");
Matcher matcher = pattern.matcher(configurations);

StringBuffer result = new StringBuffer();
while (matcher.find()) {
// Params map contains params for previous node input (e.g: deploy_openai_model:model_id)
// Check first if the substitution is looking for the same key, value pair and if yes
// then replace it with the key value pair in the inputs map
if (params.containsKey(matcher.group(1)) && params.get(matcher.group(1)).equals(matcher.group(2))) {
// Extract the key for the inputs (e.g., "model_id" from ${{deploy_openai_model.model_id}})
String key = matcher.group(2);
if (inputs.containsKey(key)) {
// Replace the whole sequence with the value from the map
matcher.appendReplacement(result, (String) inputs.get(key));
byte[] byteArr = configurations.getBytes(StandardCharsets.UTF_8);
BytesReference configurationsBytes = new BytesArray(byteArr);

ActionListener<AcknowledgedResponse> actionListener = new ActionListener<>() {

@Override
public void onResponse(AcknowledgedResponse acknowledgedResponse) {
String resourceName = getResourceByWorkflowStep(getName());
try {
flowFrameworkIndicesHandler.updateResourceInStateIndex(
currentNodeInputs.getWorkflowId(),
currentNodeId,
getName(),
pipelineId,
ActionListener.wrap(updateResponse -> {
logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex());
// PutPipelineRequest returns only an AcknowledgeResponse, saving pipelineId instead
// TODO: revisit this concept of pipeline_id to be consistent with what makes most sense to end user here
createIngestPipelineFuture.onResponse(
new WorkflowData(
Map.of(resourceName, pipelineId),
currentNodeInputs.getWorkflowId(),
currentNodeInputs.getNodeId()
)
);
}, exception -> {
String errorMessage = "Failed to update new created "
+ currentNodeId
+ " resource "
+ getName()
+ " id "
+ pipelineId;
logger.error(errorMessage, exception);
createIngestPipelineFuture.onFailure(
new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))
);
})
);

} catch (Exception e) {
String errorMessage = "Failed to parse and update new created resource";
logger.error(errorMessage, e);
createIngestPipelineFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e)));
}
}
}
matcher.appendTail(result);

if (result == null || pipelineId == null) {
// Required workflow data not found
createIngestPipelineFuture.onFailure(
new FlowFrameworkException(
"Failed to create ingest pipeline for " + currentNodeId + ", required inputs not found",
RestStatus.BAD_REQUEST
)
);
}

byte[] byteArr = result.toString().getBytes(StandardCharsets.UTF_8);
BytesReference configurationsBytes = new BytesArray(byteArr);

@Override
public void onFailure(Exception e) {
String errorMessage = "Failed to create ingest pipeline";
logger.error(errorMessage, e);
createIngestPipelineFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e)));
}

};

// Create PutPipelineRequest and execute
PutPipelineRequest putPipelineRequest = new PutPipelineRequest(pipelineId, configurationsBytes, XContentType.JSON);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.hc.core5.reactor.ssl.TlsDetails;
import org.apache.hc.core5.ssl.SSLContextBuilder;
import org.apache.hc.core5.util.Timeout;
import org.opensearch.action.ingest.GetPipelineResponse;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
Expand All @@ -44,6 +45,7 @@
import org.opensearch.flowframework.model.State;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.WorkflowState;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
import org.opensearch.test.rest.OpenSearchRestTestCase;
import org.junit.After;
import org.junit.Before;
Expand Down Expand Up @@ -648,4 +650,28 @@ protected Response deleteUser(String user) throws IOException {
List.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
);
}

/**
 * Retrieves all ingest pipelines in the cluster via the GET _ingest/pipeline REST API.
 *
 * @return the parsed {@link GetPipelineResponse} containing every pipeline on the cluster
 * @throws IOException if the request fails or the response entity cannot be parsed
 */
protected GetPipelineResponse getPipelines() throws IOException {
    Response getPipelinesResponse = TestHelpers.makeRequest(
        client(),
        "GET",
        "_ingest/pipeline",
        null,
        "",
        List.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
    );

    // Parse entity content into a GetPipelineResponse
    MediaType mediaType = MediaType.fromMediaType(getPipelinesResponse.getEntity().getContentType());
    try (
        XContentParser parser = mediaType.xContent()
            .createParser(
                NamedXContentRegistry.EMPTY,
                DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
                getPipelinesResponse.getEntity().getContent()
            )
    ) {
        return GetPipelineResponse.fromXContent(parser);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ingest.GetPipelineResponse;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
Expand Down Expand Up @@ -344,4 +345,50 @@ public void testTimestamps() throws Exception {
response = deleteWorkflow(client(), workflowId);
assertEquals(RestStatus.OK.getStatus(), response.getStatusLine().getStatusCode());
}

/**
 * Integration test for a 4 step template that creates a connector, registers a remote model,
 * deploys the model, and creates an ingest pipeline whose configuration substitutes the
 * deployed model id via the {@code ${{step.output}}} pattern.
 */
public void testCreateAndProvisionIngestPipeline() throws Exception {

    // 4 step template: create_connector, register_remote_model, deploy_model, create_ingest_pipeline
    Template template = TestHelpers.createTemplateFromFile("ingest-pipeline-template.json");

    // Hit Create Workflow API with original template
    Response response = createWorkflow(client(), template);
    assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response));

    Map<String, Object> responseMap = entityAsMap(response);
    String workflowId = (String) responseMap.get(WORKFLOW_ID);
    getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED);

    // Creating a connector requires the ML config index; wait for it to exist before provisioning
    if (!indexExistsWithAdminClient(".plugins-ml-config")) {
        assertBusy(() -> assertTrue(indexExistsWithAdminClient(".plugins-ml-config")), 40, TimeUnit.SECONDS);
    }
    response = provisionWorkflow(client(), workflowId);

    assertEquals(RestStatus.OK, TestHelpers.restStatus(response));
    getAndAssertWorkflowStatus(client(), workflowId, State.PROVISIONING, ProvisioningProgress.IN_PROGRESS);

    // Wait until provisioning has completed successfully before attempting to retrieve created resources
    List<ResourceCreated> resourcesCreated = getResourcesCreated(client(), workflowId, 30);

    // This template should create 4 resources: connector_id, registered model_id, deployed model_id and pipeline_id
    assertEquals(4, resourcesCreated.size());
    assertEquals("create_connector", resourcesCreated.get(0).workflowStepName());
    assertNotNull(resourcesCreated.get(0).resourceId());
    assertEquals("register_remote_model", resourcesCreated.get(1).workflowStepName());
    assertNotNull(resourcesCreated.get(1).resourceId());
    assertEquals("deploy_model", resourcesCreated.get(2).workflowStepName());
    assertNotNull(resourcesCreated.get(2).resourceId());
    assertEquals("create_ingest_pipeline", resourcesCreated.get(3).workflowStepName());
    assertNotNull(resourcesCreated.get(3).resourceId());

    // The created pipeline's configuration should contain the substituted deployed model id
    String modelId = resourcesCreated.get(2).resourceId();
    GetPipelineResponse getPipelinesResponse = getPipelines();
    assertTrue(getPipelinesResponse.pipelines().get(0).toString().contains(modelId));
}

}
Loading

0 comments on commit 29f66cd

Please sign in to comment.