diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 4b6a21809..546f1be73 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -436,13 +436,12 @@ public void updateTemplateInGlobalContext(String documentId, Template template, listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); } }, listener); - } else { String errorMessage = "Failed to get template: " + documentId; logger.error(errorMessage); listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); } - }); + }, listener); } /** @@ -450,13 +449,23 @@ public void updateTemplateInGlobalContext(String documentId, Template template, * * @param documentId document id * @param booleanResultConsumer boolean consumer function + * @param listener action listener + * @param action listener response type */ - public void doesTemplateExists(String documentId, Consumer booleanResultConsumer) { + public void doesTemplateExists(String documentId, Consumer booleanResultConsumer, ActionListener listener) { GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, documentId); - client.get(getRequest, ActionListener.wrap(response -> { booleanResultConsumer.accept(response.isExists()); }, exception -> { - logger.error("Failed to get template " + documentId, exception); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.wrap(response -> { booleanResultConsumer.accept(response.isExists()); }, exception -> { + context.restore(); + logger.error("Failed to get template " + documentId, exception); + booleanResultConsumer.accept(false); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); + })); + } catch (Exception e) { + logger.error("Failed to retrieve template from global context.", e); booleanResultConsumer.accept(false); - })); + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } } /** @@ -469,28 +478,36 @@ public void doesTemplateExists(String documentId, Consumer booleanResul */ public void isWorkflowProvisioned(String documentId, Consumer booleanResultConsumer, ActionListener listener) { GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX, documentId); - client.get(getRequest, ActionListener.wrap(response -> { - if (!response.isExists()) { - booleanResultConsumer.accept(false); - return; - } - try (XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - WorkflowState workflowState = WorkflowState.parse(parser); - if (workflowState.getProvisioningProgress().equals(ProvisioningProgress.NOT_STARTED.name())) { - booleanResultConsumer.accept(true); - } else { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.wrap(response -> { + context.restore(); + if (!response.isExists()) { booleanResultConsumer.accept(false); + return; } - } catch (Exception e) { - String message = "Failed to parse workflow state" + documentId; - logger.error(message, e); - listener.onFailure(new FlowFrameworkException(message, RestStatus.INTERNAL_SERVER_ERROR)); - } - }, exception -> { - logger.error("Failed to get workflow state " + documentId, exception); - booleanResultConsumer.accept(false); - })); + try ( + XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + WorkflowState workflowState = WorkflowState.parse(parser); + if (workflowState.getProvisioningProgress().equals(ProvisioningProgress.NOT_STARTED.name())) { + booleanResultConsumer.accept(true); + } else { + booleanResultConsumer.accept(false); + } + } catch (Exception e) { + String message = "Failed to parse workflow state" + documentId; + logger.error(message, e); + listener.onFailure(new FlowFrameworkException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + }, exception -> { + logger.error("Failed to get workflow state " + documentId, exception); + booleanResultConsumer.accept(false); + })); + } catch (Exception e) { + logger.error("Failed to retrieve workflow state to check provisioning status", e); + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } } /** diff --git a/src/test/java/org/opensearch/flowframework/TestHelpers.java b/src/test/java/org/opensearch/flowframework/TestHelpers.java index 3b3373253..4a0a055c3 100644 --- a/src/test/java/org/opensearch/flowframework/TestHelpers.java +++ b/src/test/java/org/opensearch/flowframework/TestHelpers.java @@ -30,6 +30,9 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowEdge; +import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.util.ParseUtils; import java.io.BufferedReader; @@ -37,6 +40,7 @@ import java.io.InputStream; import java.io.InputStreamReader; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; @@ -153,4 +157,15 @@ public static XContentBuilder builder() throws IOException { public static Map XContentBuilderToMap(XContentBuilder builder) { return XContentHelper.convertToMap(BytesReference.bytes(builder), false, builder.contentType()).v2(); } + + public static Workflow createSampleWorkflow() { + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Collections.emptyMap(), Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Collections.emptyMap(), Map.of("baz", "qux")); + WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); + List nodes = List.of(nodeA, nodeB); + List edges = List.of(edgeAB); + Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges); + return workflow; + } + } diff --git a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java index b02e5eca5..73622843f 100644 --- a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java +++ b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java @@ -8,8 +8,11 @@ */ package org.opensearch.flowframework.indices; +import org.opensearch.Version; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.mapping.put.PutMappingRequest; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.client.AdminClient; @@ -23,25 +26,38 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.TestHelpers; import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.CreateIndexStep; +import org.opensearch.index.get.GetResult; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import java.io.IOException; +import java.time.Instant; +import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -70,6 +86,7 @@ public class FlowFrameworkIndicesHandlerTests extends OpenSearchTestCase { private Map indexMappingUpdated = new HashMap<>(); @Mock IndexMetadata indexMetadata; + private Template template; @Override public void setUp() throws Exception { @@ -91,6 +108,20 @@ public void setUp() throws Exception { when(adminClient.indices()).thenReturn(indicesAdminClient); when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("test cluster")).build()); when(metadata.indices()).thenReturn(Map.of(GLOBAL_CONTEXT_INDEX, indexMetadata)); + + Workflow workflow = TestHelpers.createSampleWorkflow(); + Version templateVersion = Version.fromString("1.0.0"); + List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); + this.template = new Template( + "test", + "description", + "use case", + templateVersion, + compatibilityVersions, + Map.of("workflow", workflow), + Collections.emptyMap(), + TestHelpers.randomUser() + ); } public void testDoesIndexExist() { @@ -192,4 +223,54 @@ public void testInitIndexIfAbsent_IndexNotPresent() { verify(indicesAdminClient, times(1)).create(any(CreateIndexRequest.class), any()); } + + public void testIsWorkflowProvisionedFailedParsing() { + String documentId = randomAlphaOfLength(5); + Consumer function = mock(Consumer.class); + ActionListener listener = mock(ActionListener.class); + WorkflowState workFlowState = new WorkflowState( + documentId, + "test", + "PROVISIONING", + "IN_PROGRESS", + Instant.now(), + Instant.now(), + TestHelpers.randomUser(), + Collections.emptyMap(), + Collections.emptyList() + ); + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + // workFlowState.toXContent(builder, null); + this.template.toXContent(builder, null); + BytesReference workflowBytesRef = BytesReference.bytes(builder); + GetResult getResult = new GetResult(WORKFLOW_STATE_INDEX, documentId, 1, 1, 1, true, workflowBytesRef, null, null); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + flowFrameworkIndicesHandler.isWorkflowProvisioned(documentId, function, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue().getMessage().contains("Failed to parse workflow state")); + } + + public void testDoesTemplateExist() { + String documentId = randomAlphaOfLength(5); + Consumer function = mock(Consumer.class); + ActionListener listener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + this.template.toXContent(builder, null); + BytesReference templateBytesRef = BytesReference.bytes(builder); + GetResult getResult = new GetResult(GLOBAL_CONTEXT_INDEX, documentId, 1, 1, 1, true, templateBytesRef, null, null); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + flowFrameworkIndicesHandler.doesTemplateExists(documentId, function, listener); + verify(function).accept(true); + } } diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java index abff7ab8c..215300cf5 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java @@ -162,10 +162,10 @@ public void testCreateAndProvisionLocalModelWorkflow() throws Exception { assertNotNull(resourcesCreated.get(1).resourceId()); // // Deprovision the workflow to avoid opening circut breaker when running additional tests - // Response deprovisionResponse = deprovisionWorkflow(client(), workflowId); - // - // // wait for deprovision to complete - // Thread.sleep(5000); + Response deprovisionResponse = deprovisionWorkflow(client(), workflowId); + + // wait for deprovision to complete + Thread.sleep(5000); } public void testCreateAndProvisionCyclicalTemplate() throws Exception { @@ -230,10 +230,10 @@ public void testCreateAndProvisionRemoteModelWorkflow() throws Exception { assertNotNull(resourcesCreated.get(2).resourceId()); // Deprovision the workflow to avoid opening circut breaker when running additional tests - // Response deprovisionResponse = deprovisionWorkflow(client(), workflowId); - // - // // wait for deprovision to complete - // Thread.sleep(5000); + Response deprovisionResponse = deprovisionWorkflow(client(), workflowId); + + // wait for deprovision to complete + Thread.sleep(5000); } public void testCreateAndProvisionAgentFrameworkWorkflow() throws Exception { @@ -272,16 +272,16 @@ public void testCreateAndProvisionAgentFrameworkWorkflow() throws Exception { assertNotNull(resourcesCreated.get(0).resourceId()); // Hit Deprovision API - // Response deprovisionResponse = deprovisionWorkflow(client(), workflowId); - // assertBusy( - // () -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); }, - // 60, - // TimeUnit.SECONDS - // ); + Response deprovisionResponse = deprovisionWorkflow(client(), workflowId); + assertBusy( + () -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); }, + 60, + TimeUnit.SECONDS + ); // // // Hit Delete API - // Response deleteResponse = deleteWorkflow(client(), workflowId); - // assertEquals(RestStatus.OK, TestHelpers.restStatus(deleteResponse)); + Response deleteResponse = deleteWorkflow(client(), workflowId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(deleteResponse)); } } diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index ce08cdb8a..e8c8ba4f3 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -113,12 +113,7 @@ public void setUp() throws Exception { Version templateVersion = Version.fromString("1.0.0"); List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); - WorkflowNode nodeA = new WorkflowNode("A", "a-type", Collections.emptyMap(), Map.of("foo", "bar")); - WorkflowNode nodeB = new WorkflowNode("B", "b-type", Collections.emptyMap(), Map.of("baz", "qux")); - WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); - List nodes = List.of(nodeA, nodeB); - List edges = List.of(edgeAB); - Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges); + Workflow workflow = TestHelpers.createSampleWorkflow(); this.template = new Template( "test",