Skip to content

Commit

Permalink
Combined register local model and get ml task (#220)
Browse files Browse the repository at this point in the history
* Combined register local model and get ml task

Signed-off-by: Joshua Palis <[email protected]>

* Handling task failure, returning error to user, added related test

Signed-off-by: Joshua Palis <[email protected]>

* Removing unnecessary MLConfig and spy

Signed-off-by: Joshua Palis <[email protected]>

---------

Signed-off-by: Joshua Palis <[email protected]>
  • Loading branch information
joshpalis authored Nov 30, 2023
1 parent 7c06865 commit e2581a1
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 395 deletions.
149 changes: 0 additions & 149 deletions src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java

This file was deleted.

65 changes: 0 additions & 65 deletions src/main/java/org/opensearch/flowframework/workflow/GetTask.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.FutureUtils;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
Expand All @@ -38,17 +42,17 @@
import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT;
import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.CommonValue.TASK_ID;
import static org.opensearch.flowframework.common.CommonValue.URL;
import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD;

/**
* Step to register a local model
*/
public class RegisterLocalModelStep implements WorkflowStep {
public class RegisterLocalModelStep extends AbstractRetryableWorkflowStep {

private static final Logger logger = LogManager.getLogger(RegisterLocalModelStep.class);

Expand All @@ -58,9 +62,12 @@ public class RegisterLocalModelStep implements WorkflowStep {

/**
* Instantiate this class
* @param settings The OpenSearch settings
* @param clusterService The cluster service
* @param mlClient client to instantiate MLClient
*/
public RegisterLocalModelStep(MachineLearningNodeClient mlClient) {
public RegisterLocalModelStep(Settings settings, ClusterService clusterService, MachineLearningNodeClient mlClient) {
super(settings, clusterService);
this.mlClient = mlClient;
}

Expand All @@ -74,20 +81,21 @@ public CompletableFuture<WorkflowData> execute(

CompletableFuture<WorkflowData> registerLocalModelFuture = new CompletableFuture<>();

// TODO: Recreating the list to get this compiling
// Need to refactor the below iteration to pull directly from the maps
List<WorkflowData> data = new ArrayList<>();
data.add(currentNodeInputs);
data.addAll(outputs.values());

ActionListener<MLRegisterModelResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {
logger.info("Local Model registration task creation successful");
registerLocalModelFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(TASK_ID, mlRegisterModelResponse.getTaskId()),
Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus())
),
currentNodeInputs.getWorkflowId(),
currentNodeInputs.getNodeId()
)
);

String taskId = mlRegisterModelResponse.getTaskId();

// Attempt to retrieve the model ID
retryableGetMlTask(currentNodeInputs.getWorkflowId(), currentNodeId, registerLocalModelFuture, taskId, 0);
}

@Override
Expand All @@ -109,12 +117,6 @@ public void onFailure(Exception e) {
String allConfig = null;
String url = null;

// TODO: Recreating the list to get this compiling
// Need to refactor the below iteration to pull directly from the maps
List<WorkflowData> data = new ArrayList<>();
data.add(currentNodeInputs);
data.addAll(outputs.values());

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

Expand Down Expand Up @@ -211,4 +213,63 @@ public void onFailure(Exception e) {
public String getName() {
return NAME;
}

/**
* Retryable get ml task
* @param workflowId the workflow id
* @param nodeId the workflow node id
* @param getMLTaskFuture the workflow step future
* @param taskId the ml task id
* @param retries the current number of request retries
*/
void retryableGetMlTask(
String workflowId,
String nodeId,
CompletableFuture<WorkflowData> registerLocalModelFuture,
String taskId,
int retries
) {
mlClient.getTask(taskId, ActionListener.wrap(response -> {
MLTaskState currentState = response.getState();
if (currentState != MLTaskState.COMPLETED) {
if (Stream.of(MLTaskState.FAILED, MLTaskState.COMPLETED_WITH_ERROR).anyMatch(x -> x == currentState)) {
// Model registration failed or completed with errors
String errorMessage = "Local model registration failed with error : " + response.getError();
logger.error(errorMessage);
registerLocalModelFuture.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST));
} else {
// Task still in progress, attempt retry
throw new IllegalStateException("Local model registration is not yet completed");
}
} else {
logger.info("Local model registeration successful");
registerLocalModelFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(MODEL_ID, response.getModelId()),
Map.entry(REGISTER_MODEL_STATUS, response.getState().name())
),
workflowId,
nodeId
)
);
}
}, exception -> {
if (retries < maxRetry) {
// Sleep thread prior to retrying request
try {
Thread.sleep(5000);
} catch (Exception e) {
FutureUtils.cancel(registerLocalModelFuture);
}
final int retryAdd = retries + 1;
retryableGetMlTask(workflowId, nodeId, registerLocalModelFuture, taskId, retryAdd);
} else {
logger.error("Failed to retrieve local model registration task, maximum retries exceeded");
registerLocalModelFuture.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
);
}
}));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,11 @@ private void populateMap(
stepMap.put(NoOpStep.NAME, new NoOpStep());
stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client));
stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client));
stepMap.put(RegisterLocalModelStep.NAME, new RegisterLocalModelStep(mlClient));
stepMap.put(RegisterLocalModelStep.NAME, new RegisterLocalModelStep(settings, clusterService, mlClient));
stepMap.put(RegisterRemoteModelStep.NAME, new RegisterRemoteModelStep(mlClient));
stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient));
stepMap.put(GetMLTaskStep.NAME, new GetMLTaskStep(settings, clusterService, mlClient));
}

/**
Expand Down
Loading

0 comments on commit e2581a1

Please sign in to comment.