diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 546f1be73..fcb3a55de 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -410,7 +410,7 @@ public void updateTemplateInGlobalContext(String documentId, Template template, listener.onFailure(new FlowFrameworkException(exceptionMessage, RestStatus.BAD_REQUEST)); return; } - doesTemplateExists(documentId, templateExists -> { + doesTemplateExist(documentId, templateExists -> { if (templateExists) { isWorkflowProvisioned(documentId, workflowIsProvisioned -> { if (workflowIsProvisioned) { @@ -448,22 +448,20 @@ public void updateTemplateInGlobalContext(String documentId, Template template, * Check if the given template exists in the template index * * @param documentId document id - * @param booleanResultConsumer boolean consumer function + * @param booleanResultConsumer a consumer based on whether the template exist * @param listener action listener * @param action listener response type */ - public void doesTemplateExists(String documentId, Consumer booleanResultConsumer, ActionListener listener) { + public void doesTemplateExist(String documentId, Consumer booleanResultConsumer, ActionListener listener) { GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, documentId); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { client.get(getRequest, ActionListener.wrap(response -> { booleanResultConsumer.accept(response.isExists()); }, exception -> { context.restore(); logger.error("Failed to get template " + documentId, exception); - booleanResultConsumer.accept(false); listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); })); } catch (Exception e) { logger.error("Failed to retrieve template from global context.", e); - booleanResultConsumer.accept(false); listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } } @@ -472,7 +470,7 @@ public void doesTemplateExists(String documentId, Consumer booleanR * Check if the workflow has been provisioned and executes the consumer by passing a boolean * * @param documentId document id - * @param booleanResultConsumer boolean consumer function + * @param booleanResultConsumer boolean consumer function based on if workflow is provisioned or not * @param listener action listener * @param action listener response type */ @@ -490,13 +488,9 @@ public void isWorkflowProvisioned(String documentId, Consumer boole ) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); WorkflowState workflowState = WorkflowState.parse(parser); - if (workflowState.getProvisioningProgress().equals(ProvisioningProgress.NOT_STARTED.name())) { - booleanResultConsumer.accept(true); - } else { - booleanResultConsumer.accept(false); - } + booleanResultConsumer.accept(workflowState.getProvisioningProgress().equals(ProvisioningProgress.NOT_STARTED.name())); } catch (Exception e) { - String message = "Failed to parse workflow state" + documentId; + String message = "Failed to parse workflow state " + documentId; logger.error(message, e); listener.onFailure(new FlowFrameworkException(message, RestStatus.INTERNAL_SERVER_ERROR)); } diff --git a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java index 73622843f..9939ba0cd 100644 --- a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java +++ b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java @@ -155,6 +155,30 @@ public void testFailedUpdateTemplateInGlobalContext() throws IOException { ); } + public void testFailedUpdateTemplateInGlobalContextNotExisting() throws IOException { + Template template = mock(Template.class); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + FlowFrameworkIndex index = FlowFrameworkIndex.GLOBAL_CONTEXT; + ClusterState mockClusterState = mock(ClusterState.class); + Metadata mockMetadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(mockClusterState); + when(mockClusterState.metadata()).thenReturn(mockMetadata); + when(mockMetadata.hasIndex(index.getIndexName())).thenReturn(true); + when(flowFrameworkIndicesHandler.doesIndexExist(GLOBAL_CONTEXT_INDEX)).thenReturn(true); + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(new Exception("Failed to get template")); + return null; + }).when(client).get(any(GetRequest.class), any()); + + flowFrameworkIndicesHandler.updateTemplateInGlobalContext("1", template, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue().getMessage().contains("Failed to get template")); + } + public void testInitIndexIfAbsent_IndexExist() { FlowFrameworkIndex index = FlowFrameworkIndex.GLOBAL_CONTEXT; indexMappingUpdated.put(index.getIndexName(), new AtomicBoolean(false)); @@ -270,7 +294,7 @@ public void testDoesTemplateExist() { responseListener.onResponse(new GetResponse(getResult)); return null; }).when(client).get(any(GetRequest.class), any()); - flowFrameworkIndicesHandler.doesTemplateExists(documentId, function, listener); + flowFrameworkIndicesHandler.doesTemplateExist(documentId, function, listener); verify(function).accept(true); } } diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java index 215300cf5..c097fce6b 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java @@ -69,9 +69,12 @@ public void testSearchWorkflows() throws Exception { } public void testFailedUpdateWorkflow() throws Exception { - // Create a Workflow that has a credential 12345 - Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json"); + Template templateCreation = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json"); + Response responseCreate = createWorkflow(client(), templateCreation); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(responseCreate)); + Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json"); + Thread.sleep(1000); ResponseException exception = expectThrows(ResponseException.class, () -> updateWorkflow(client(), "123", template)); assertTrue(exception.getMessage().contains("Failed to get template: 123")); @@ -161,7 +164,7 @@ public void testCreateAndProvisionLocalModelWorkflow() throws Exception { assertEquals("deploy_model", resourcesCreated.get(1).workflowStepName()); assertNotNull(resourcesCreated.get(1).resourceId()); - // // Deprovision the workflow to avoid opening circut breaker when running additional tests + // Deprovision the workflow to avoid opening circut breaker when running additional tests Response deprovisionResponse = deprovisionWorkflow(client(), workflowId); // wait for deprovision to complete @@ -229,7 +232,7 @@ public void testCreateAndProvisionRemoteModelWorkflow() throws Exception { assertEquals("deploy_model", resourcesCreated.get(2).workflowStepName()); assertNotNull(resourcesCreated.get(2).resourceId()); - // Deprovision the workflow to avoid opening circut breaker when running additional tests + // Deprovision the workflow to avoid opening circuit breaker when running additional tests Response deprovisionResponse = deprovisionWorkflow(client(), workflowId); // wait for deprovision to complete @@ -278,8 +281,7 @@ public void testCreateAndProvisionAgentFrameworkWorkflow() throws Exception { 60, TimeUnit.SECONDS ); - // - // // Hit Delete API + // Hit Delete API Response deleteResponse = deleteWorkflow(client(), workflowId); assertEquals(RestStatus.OK, TestHelpers.restStatus(deleteResponse)); }