diff --git a/.github/workflows/test_security.yml b/.github/workflows/test_security.yml index c18b2a11a..fafcec0fa 100644 --- a/.github/workflows/test_security.yml +++ b/.github/workflows/test_security.yml @@ -1,11 +1,11 @@ name: Security test workflow for Flow Framework on: push: - branches: - - "*" + branches-ignore: + - 'whitesource-remediate/**' + - 'backport/**' pull_request: - branches: - - "*" + types: [opened, synchronize, reopened] jobs: Get-CI-Image-Tag: @@ -30,9 +30,9 @@ jobs: steps: - name: Checkout Flow Framework - uses: actions/checkout@v1 + uses: actions/checkout@v3 - name: Setup Java ${{ matrix.java }} - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: distribution: 'temurin' java-version: ${{ matrix.java }} diff --git a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java index b2a5e5028..e026053e1 100644 --- a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java @@ -46,6 +46,7 @@ import static org.opensearch.flowframework.common.CommonValue.PROVISION_START_TIME_FIELD; import static org.opensearch.flowframework.common.CommonValue.RESOURCES_CREATED_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.WorkflowResources.getDeprovisionStepByWorkflowStep; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; @@ -101,7 +102,8 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener executeDeprovisionSequence(workflowId, response.getWorkflowState().resourcesCreated(), listener)); }, exception -> { String message = "Failed to get workflow state for workflow " + workflowId; logger.error(message, exception); diff --git a/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java index 6259394f2..7841ed193 100644 --- a/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.transport; +import org.opensearch.action.LatchedActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; @@ -19,12 +20,9 @@ import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.ResourceCreated; import org.opensearch.flowframework.model.WorkflowState; -import org.opensearch.flowframework.workflow.CreateConnectorStep; import org.opensearch.flowframework.workflow.DeleteConnectorStep; -import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.flowframework.workflow.WorkflowStepFactory; -import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.FixedExecutorBuilder; @@ -33,10 +31,9 @@ import org.opensearch.transport.TransportService; import org.junit.AfterClass; -import java.io.IOException; -import java.util.Collections; import java.util.List; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import org.mockito.ArgumentCaptor; @@ -50,6 +47,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -77,7 +75,15 @@ public class DeprovisionWorkflowTransportActionTests extends OpenSearchTestCase public void setUp() throws Exception { super.setUp(); this.client = mock(Client.class); + ThreadPool clientThreadPool = spy(threadPool); + when(client.threadPool()).thenReturn(clientThreadPool); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + when(clientThreadPool.getThreadContext()).thenReturn(threadContext); + this.workflowStepFactory = mock(WorkflowStepFactory.class); + this.deleteConnectorStep = mock(DeleteConnectorStep.class); + when(this.workflowStepFactory.createStep("delete_connector")).thenReturn(deleteConnectorStep); + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); flowFrameworkSettings = mock(FlowFrameworkSettings.class); when(flowFrameworkSettings.getRequestTimeout()).thenReturn(TimeValue.timeValueSeconds(10)); @@ -85,28 +91,12 @@ public void setUp() throws Exception { this.deprovisionWorkflowTransportAction = new DeprovisionWorkflowTransportAction( mock(TransportService.class), mock(ActionFilters.class), - threadPool, + clientThreadPool, client, workflowStepFactory, flowFrameworkIndicesHandler, flowFrameworkSettings ); - - MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); - ProcessNode processNode = mock(ProcessNode.class); - when(processNode.id()).thenReturn("step_1"); - when(processNode.workflowStep()).thenReturn(new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); - when(processNode.previousNodeInputs()).thenReturn(Collections.emptyMap()); - when(processNode.input()).thenReturn(WorkflowData.EMPTY); - when(processNode.nodeTimeout()).thenReturn(TimeValue.timeValueSeconds(5)); - this.deleteConnectorStep = mock(DeleteConnectorStep.class); - when(this.workflowStepFactory.createStep("delete_connector")).thenReturn(deleteConnectorStep); - - ThreadPool clientThreadPool = mock(ThreadPool.class); - ThreadContext threadContext = new ThreadContext(Settings.EMPTY); - - when(client.threadPool()).thenReturn(clientThreadPool); - when(clientThreadPool.getThreadContext()).thenReturn(threadContext); } @AfterClass @@ -114,10 +104,12 @@ public static void cleanup() { ThreadPool.terminate(threadPool, 500, TimeUnit.MILLISECONDS); } - public void testDeprovisionWorkflow() throws IOException { + public void testDeprovisionWorkflow() throws Exception { String workflowId = "1"; + + CountDownLatch latch = new CountDownLatch(1); @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); + ActionListener listener = spy(new LatchedActionListener(mock(ActionListener.class), latch)); WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); doAnswer(invocation -> { @@ -137,14 +129,17 @@ public void testDeprovisionWorkflow() throws IOException { deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); + latch.await(5, TimeUnit.SECONDS); verify(listener, times(1)).onResponse(responseCaptor.capture()); assertEquals(workflowId, responseCaptor.getValue().getWorkflowId()); } - public void testFailToDeprovision() throws IOException { + public void testFailToDeprovision() throws Exception { String workflowId = "1"; + + CountDownLatch latch = new CountDownLatch(1); @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); + ActionListener listener = spy(new LatchedActionListener(mock(ActionListener.class), latch)); WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); doAnswer(invocation -> { @@ -164,6 +159,7 @@ public void testFailToDeprovision() throws IOException { deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + latch.await(5, TimeUnit.SECONDS); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals("Failed to deprovision some resources: [model_id modelId].", exceptionCaptor.getValue().getMessage()); }