diff --git a/CHANGELOG.md b/CHANGELOG.md index b8afae345c97b..981f16560c185 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,7 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add support for msearch API to pass search pipeline name - ([#15923](https://github.com/opensearch-project/OpenSearch/pull/15923)) - Add _list/indices API as paginated alternate to _cat/indices ([#14718](https://github.com/opensearch-project/OpenSearch/pull/14718)) - Add success and failure metrics for async shard fetch ([#15976](https://github.com/opensearch-project/OpenSearch/pull/15976)) -- Add new metric REMOTE_STORE to NodeStatNode.javas API response ([#15611](https://github.com/opensearch-project/OpenSearch/pull/15611)) +- Add new metric REMOTE_STORE to NodeStats API response ([#15611](https://github.com/opensearch-project/OpenSearch/pull/15611)) - [S3 Repository] Change default retry mechanism of s3 clients to Standard Mode ([#15978](https://github.com/opensearch-project/OpenSearch/pull/15978)) - Add changes to block calls in cat shards, indices and segments based on dynamic limit settings ([#15986](https://github.com/opensearch-project/OpenSearch/pull/15986)) - New `phone` & `phone-search` analyzer + tokenizer ([#15915](https://github.com/opensearch-project/OpenSearch/pull/15915)) diff --git a/server/src/main/java/org/opensearch/action/admin/cluster/wlm/TransportWlmStatsAction.java b/server/src/main/java/org/opensearch/action/admin/cluster/wlm/TransportWlmStatsAction.java index 55f2d179ec02a..9c2fb3f1689ec 100644 --- a/server/src/main/java/org/opensearch/action/admin/cluster/wlm/TransportWlmStatsAction.java +++ b/server/src/main/java/org/opensearch/action/admin/cluster/wlm/TransportWlmStatsAction.java @@ -70,6 +70,10 @@ protected WlmStats newNodeResponse(StreamInput in) throws IOException { @Override protected WlmStats nodeOperation(WlmStatsRequest wlmStatsRequest) { - return queryGroupService.nodeStats(wlmStatsRequest.getQueryGroupIds(), wlmStatsRequest.isBreach()); + assert transportService.getLocalNode() != null; + return new WlmStats( + transportService.getLocalNode(), + queryGroupService.nodeStats(wlmStatsRequest.getQueryGroupIds(), wlmStatsRequest.isBreach()) + ); } } diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index b969e6dbbb297..584d95b9ff6b5 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -354,12 +354,12 @@ public class Node implements Closeable { ); /** - * controls whether the node is allowed to persist things like metadata to disk - * Note that this does not control whether the node stores actual indices (see - * {@link #NODE_DATA_SETTING}). However, if this is false, {@link #NODE_DATA_SETTING} - * and {@link #NODE_MASTER_SETTING} must also be false. - * - */ + * controls whether the node is allowed to persist things like metadata to disk + * Note that this does not control whether the node stores actual indices (see + * {@link #NODE_DATA_SETTING}). However, if this is false, {@link #NODE_DATA_SETTING} + * and {@link #NODE_MASTER_SETTING} must also be false. + * + */ public static final Setting NODE_LOCAL_STORAGE_SETTING = Setting.boolSetting( "node.local_storage", true, @@ -1037,6 +1037,41 @@ protected Node( final QueryGroupsStateAccessor queryGroupsStateAccessor = new QueryGroupsStateAccessor(); + final QueryGroupService queryGroupService = new QueryGroupService( + new QueryGroupTaskCancellationService( + workloadManagementSettings, + new MaximumResourceTaskSelectionStrategy(), + queryGroupResourceUsageTrackerService, + queryGroupsStateAccessor + ), + clusterService, + threadPool, + workloadManagementSettings, + queryGroupsStateAccessor + ); + taskResourceTrackingService.addTaskCompletionListener(queryGroupService); + + final QueryGroupRequestOperationListener queryGroupRequestOperationListener = new QueryGroupRequestOperationListener( + queryGroupService, + threadPool + ); + + // register all standard SearchRequestOperationsCompositeListenerFactory to the SearchRequestOperationsCompositeListenerFactory + final SearchRequestOperationsCompositeListenerFactory searchRequestOperationsCompositeListenerFactory = + new SearchRequestOperationsCompositeListenerFactory( + Stream.concat( + Stream.of( + searchRequestStats, + searchRequestSlowLog, + searchTaskRequestOperationsListener, + queryGroupRequestOperationListener + ), + pluginComponents.stream() + .filter(p -> p instanceof SearchRequestOperationsListener) + .map(p -> (SearchRequestOperationsListener) p) + ).toArray(SearchRequestOperationsListener[]::new) + ); + ActionModule actionModule = new ActionModule( settings, clusterModule.getIndexNameExpressionResolver(), @@ -1079,11 +1114,9 @@ protected Node( admissionControlService ); - SetOnce queryGroupServiceSetOnce = new SetOnce<>(); - WorkloadManagementTransportInterceptor workloadManagementTransportInterceptor = new WorkloadManagementTransportInterceptor( threadPool, - queryGroupServiceSetOnce // We will need to replace this with actual implementation + queryGroupService ); final Collection secureSettingsFactories = pluginsService.filterPlugins(Plugin.class) @@ -1145,45 +1178,6 @@ protected Node( taskHeaders, tracer ); - - final QueryGroupService queryGroupService = new QueryGroupService( - new QueryGroupTaskCancellationService( - workloadManagementSettings, - new MaximumResourceTaskSelectionStrategy(), - queryGroupResourceUsageTrackerService, - queryGroupsStateAccessor - ), - transportService, - clusterService, - threadPool, - workloadManagementSettings, - queryGroupsStateAccessor - ); - - queryGroupServiceSetOnce.set(queryGroupService); - taskResourceTrackingService.addTaskCompletionListener(queryGroupService); - - final QueryGroupRequestOperationListener queryGroupRequestOperationListener = new QueryGroupRequestOperationListener( - queryGroupService, - threadPool - ); - - // register all standard SearchRequestOperationsCompositeListenerFactory to the SearchRequestOperationsCompositeListenerFactory - final SearchRequestOperationsCompositeListenerFactory searchRequestOperationsCompositeListenerFactory = - new SearchRequestOperationsCompositeListenerFactory( - Stream.concat( - Stream.of( - searchRequestStats, - searchRequestSlowLog, - searchTaskRequestOperationsListener, - queryGroupRequestOperationListener - ), - pluginComponents.stream() - .filter(p -> p instanceof SearchRequestOperationsListener) - .map(p -> (SearchRequestOperationsListener) p) - ).toArray(SearchRequestOperationsListener[]::new) - ); - TopNSearchTasksLogger taskConsumer = new TopNSearchTasksLogger(settings, settingsModule.getClusterSettings()); transportService.getTaskManager().registerTaskResourceConsumer(taskConsumer); this.extensionsManager.initializeServicesAndRestHandler( @@ -1509,6 +1503,7 @@ protected Node( b.bind(SegmentReplicationStatsTracker.class).toInstance(segmentReplicationStatsTracker); b.bind(SearchRequestOperationsCompositeListenerFactory.class).toInstance(searchRequestOperationsCompositeListenerFactory); b.bind(SegmentReplicator.class).toInstance(segmentReplicator); + taskManagerClientOptional.ifPresent(value -> b.bind(TaskManagerClient.class).toInstance(value)); }); injector = modules.createInjector(); diff --git a/server/src/main/java/org/opensearch/wlm/QueryGroupService.java b/server/src/main/java/org/opensearch/wlm/QueryGroupService.java index 16be62085a4c5..30d39686931be 100644 --- a/server/src/main/java/org/opensearch/wlm/QueryGroupService.java +++ b/server/src/main/java/org/opensearch/wlm/QueryGroupService.java @@ -17,7 +17,6 @@ import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.metadata.QueryGroup; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.collect.Tuple; import org.opensearch.common.lifecycle.AbstractLifecycleComponent; import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; import org.opensearch.monitor.jvm.JvmStats; @@ -28,12 +27,10 @@ import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.threadpool.Scheduler; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportService; import org.opensearch.wlm.cancellation.QueryGroupTaskCancellationService; import org.opensearch.wlm.stats.QueryGroupState; import org.opensearch.wlm.stats.QueryGroupStats; import org.opensearch.wlm.stats.QueryGroupStats.QueryGroupStatsHolder; -import org.opensearch.wlm.stats.WlmStats; import java.io.IOException; import java.util.HashMap; @@ -63,11 +60,9 @@ public class QueryGroupService extends AbstractLifecycleComponent private final Set deletedQueryGroups; private final NodeDuressTrackers nodeDuressTrackers; private final QueryGroupsStateAccessor queryGroupsStateAccessor; - private final TransportService transportService; public QueryGroupService( QueryGroupTaskCancellationService taskCancellationService, - TransportService transportService, ClusterService clusterService, ThreadPool threadPool, WorkloadManagementSettings workloadManagementSettings, @@ -76,7 +71,6 @@ public QueryGroupService( this( taskCancellationService, - transportService, clusterService, threadPool, workloadManagementSettings, @@ -105,7 +99,6 @@ public QueryGroupService( public QueryGroupService( QueryGroupTaskCancellationService taskCancellationService, - TransportService transportService, ClusterService clusterService, ThreadPool threadPool, WorkloadManagementSettings workloadManagementSettings, @@ -115,7 +108,6 @@ public QueryGroupService( Set deletedQueryGroups ) { this.taskCancellationService = taskCancellationService; - this.transportService = transportService; this.clusterService = clusterService; this.threadPool = threadPool; this.workloadManagementSettings = workloadManagementSettings; @@ -214,53 +206,48 @@ public void incrementFailuresFor(final String queryGroupId) { /** * @return node level query group stats */ - public WlmStats nodeStats(Set queryGroupIds, Boolean requestedBreached) { + public QueryGroupStats nodeStats(Set queryGroupIds, Boolean requestedBreached) { final Map statsHolderMap = new HashMap<>(); - Map existingGroups = clusterService.state().metadata().queryGroups(); + Map existingStateMap = queryGroupsStateAccessor.getQueryGroupStateMap(); if (!queryGroupIds.contains("_all")) { for (String id : queryGroupIds) { - if (!existingGroups.containsKey(id)) { + if (!existingStateMap.containsKey(id)) { throw new ResourceNotFoundException("QueryGroup with id " + id + " does not exist"); } } } - Map stateMap = queryGroupsStateAccessor.getQueryGroupStateMap(); - if (stateMap != null) { - stateMap.forEach((queryGroupId, currentState) -> { + if (existingStateMap != null) { + existingStateMap.forEach((queryGroupId, currentState) -> { boolean shouldInclude = queryGroupIds.contains("_all") || queryGroupIds.contains(queryGroupId); if (shouldInclude) { - if (requestedBreached == null - || requestedBreached == (resourceLimitBreached(existingGroups.get(queryGroupId), currentState).v1() - .length() != 0)) { + if (requestedBreached == null || requestedBreached == resourceLimitBreached(queryGroupId, currentState)) { statsHolderMap.put(queryGroupId, QueryGroupStatsHolder.from(currentState)); } } }); } - return new WlmStats(transportService.getLocalNode(), new QueryGroupStats(statsHolderMap)); + return new QueryGroupStats(statsHolderMap); } /** * @return if the QueryGroup breaches any resource limit based on the LastRecordedUsage */ - public Tuple resourceLimitBreached(QueryGroup queryGroup, QueryGroupState queryGroupState) { - StringBuilder reason = new StringBuilder(); + public boolean resourceLimitBreached(String id, QueryGroupState currentState) { + QueryGroup queryGroup = clusterService.state().metadata().queryGroups().get(id); + if (queryGroup == null) { + throw new ResourceNotFoundException("QueryGroup with id " + id + " does not exist"); + } + for (ResourceType resourceType : TRACKED_RESOURCES) { if (queryGroup.getResourceLimits().containsKey(resourceType)) { final double threshold = getNormalisedRejectionThreshold(queryGroup.getResourceLimits().get(resourceType), resourceType); - final double lastRecordedUsage = queryGroupState.getResourceState().get(resourceType).getLastRecordedUsage(); + final double lastRecordedUsage = currentState.getResourceState().get(resourceType).getLastRecordedUsage(); if (threshold < lastRecordedUsage) { - reason.append(resourceType) - .append(" limit is breaching for ENFORCED type QueryGroup: (") - .append(threshold) - .append(" < ") - .append(lastRecordedUsage) - .append("). "); - return new Tuple<>(reason, resourceType); + return true; } } } - return new Tuple<>(reason, null); + return false; } /** @@ -287,9 +274,30 @@ public void rejectIfNeeded(String queryGroupId) { return; optionalQueryGroup.ifPresent(queryGroup -> { - Tuple reason = resourceLimitBreached(queryGroup, queryGroupState); - if (reason.v1().length() != 0) { - queryGroupState.getResourceState().get(reason.v2()).rejections.inc(); + boolean reject = false; + final StringBuilder reason = new StringBuilder(); + for (ResourceType resourceType : TRACKED_RESOURCES) { + if (queryGroup.getResourceLimits().containsKey(resourceType)) { + final double threshold = getNormalisedRejectionThreshold( + queryGroup.getResourceLimits().get(resourceType), + resourceType + ); + final double lastRecordedUsage = queryGroupState.getResourceState().get(resourceType).getLastRecordedUsage(); + if (threshold < lastRecordedUsage) { + reject = true; + reason.append(resourceType) + .append(" limit is breaching for ENFORCED type QueryGroup: (") + .append(threshold) + .append(" < ") + .append(lastRecordedUsage) + .append("). "); + queryGroupState.getResourceState().get(resourceType).rejections.inc(); + // should not double count even if both the resource limits are breaching + break; + } + } + } + if (reject) { queryGroupState.totalRejections.inc(); throw new OpenSearchRejectedExecutionException( "QueryGroup " + queryGroupId + " is already contended. " + reason.toString() diff --git a/server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java b/server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java index 1f08d8722e349..d382b4c729a38 100644 --- a/server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java +++ b/server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java @@ -8,7 +8,6 @@ package org.opensearch.wlm; -import org.opensearch.common.SetOnce; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportChannel; @@ -21,11 +20,11 @@ */ public class WorkloadManagementTransportInterceptor implements TransportInterceptor { private final ThreadPool threadPool; - private final SetOnce queryGroupService; + private final QueryGroupService queryGroupService; - public WorkloadManagementTransportInterceptor(final ThreadPool threadPool, final SetOnce queryGroupServiceSetOnce) { + public WorkloadManagementTransportInterceptor(final ThreadPool threadPool, final QueryGroupService queryGroupService) { this.threadPool = threadPool; - this.queryGroupService = queryGroupServiceSetOnce; + this.queryGroupService = queryGroupService; } @Override @@ -46,13 +45,9 @@ public static class RequestHandler implements Transp private final ThreadPool threadPool; TransportRequestHandler actualHandler; - private final SetOnce queryGroupService; + private final QueryGroupService queryGroupService; - public RequestHandler( - ThreadPool threadPool, - TransportRequestHandler actualHandler, - SetOnce queryGroupService - ) { + public RequestHandler(ThreadPool threadPool, TransportRequestHandler actualHandler, QueryGroupService queryGroupService) { this.threadPool = threadPool; this.actualHandler = actualHandler; this.queryGroupService = queryGroupService; @@ -63,8 +58,7 @@ public void messageReceived(T request, TransportChannel channel, Task task) thro if (isSearchWorkloadRequest(task)) { ((QueryGroupTask) task).setQueryGroupId(threadPool.getThreadContext()); final String queryGroupId = ((QueryGroupTask) (task)).getQueryGroupId(); - assert queryGroupService.get() != null; - queryGroupService.get().rejectIfNeeded(queryGroupId); + queryGroupService.rejectIfNeeded(queryGroupId); } actualHandler.messageReceived(request, channel, task); } diff --git a/server/src/test/java/org/opensearch/wlm/QueryGroupServiceTests.java b/server/src/test/java/org/opensearch/wlm/QueryGroupServiceTests.java index 04d9ef7916626..c5cf0dac4f807 100644 --- a/server/src/test/java/org/opensearch/wlm/QueryGroupServiceTests.java +++ b/server/src/test/java/org/opensearch/wlm/QueryGroupServiceTests.java @@ -22,7 +22,6 @@ import org.opensearch.threadpool.Scheduler; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportService; import org.opensearch.wlm.cancellation.QueryGroupTaskCancellationService; import org.opensearch.wlm.cancellation.TaskSelectionStrategy; import org.opensearch.wlm.stats.QueryGroupState; @@ -50,7 +49,6 @@ public class QueryGroupServiceTests extends OpenSearchTestCase { private QueryGroupService queryGroupService; - private TransportService mockTransportService; private QueryGroupTaskCancellationService mockCancellationService; private ClusterService mockClusterService; private ThreadPool mockThreadPool; @@ -63,7 +61,6 @@ public class QueryGroupServiceTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); mockClusterService = Mockito.mock(ClusterService.class); - mockTransportService = Mockito.mock(TransportService.class); mockThreadPool = Mockito.mock(ThreadPool.class); mockScheduledFuture = Mockito.mock(Scheduler.Cancellable.class); mockWorkloadManagementSettings = Mockito.mock(WorkloadManagementSettings.class); @@ -74,7 +71,6 @@ public void setUp() throws Exception { queryGroupService = new QueryGroupService( mockCancellationService, - mockTransportService, mockClusterService, mockThreadPool, mockWorkloadManagementSettings, @@ -191,7 +187,6 @@ public void testRejectIfNeeded_whenQueryGroupIdIsNullOrDefaultOne() { queryGroupService = new QueryGroupService( mockCancellationService, - mockTransportService, mockClusterService, mockThreadPool, mockWorkloadManagementSettings, @@ -230,7 +225,6 @@ public void testRejectIfNeeded_whenQueryGroupIsSoftMode() { queryGroupService = new QueryGroupService( mockCancellationService, - mockTransportService, mockClusterService, mockThreadPool, mockWorkloadManagementSettings, @@ -267,7 +261,6 @@ public void testRejectIfNeeded_whenQueryGroupIsEnforcedMode_andNotBreaching() { queryGroupService = new QueryGroupService( mockCancellationService, - mockTransportService, mockClusterService, mockThreadPool, mockWorkloadManagementSettings, @@ -315,7 +308,6 @@ public void testRejectIfNeeded_whenQueryGroupIsEnforcedMode_andBreaching() { queryGroupService = new QueryGroupService( mockCancellationService, - mockTransportService, mockClusterService, mockThreadPool, mockWorkloadManagementSettings, @@ -360,7 +352,6 @@ public void testRejectIfNeeded_whenFeatureIsNotEnabled() { queryGroupService = new QueryGroupService( mockCancellationService, - mockTransportService, mockClusterService, mockThreadPool, mockWorkloadManagementSettings, @@ -384,7 +375,6 @@ public void testOnTaskCompleted() { mockQueryGroupsStateAccessor = new QueryGroupsStateAccessor(mockQueryGroupStateMap); queryGroupService = new QueryGroupService( mockCancellationService, - mockTransportService, mockClusterService, mockThreadPool, mockWorkloadManagementSettings, @@ -431,7 +421,6 @@ public void testShouldSBPHandle() { mockQueryGroupsStateAccessor = new QueryGroupsStateAccessor(mockQueryGroupStateMap); queryGroupService = new QueryGroupService( mockCancellationService, - mockTransportService, mockClusterService, mockThreadPool, mockWorkloadManagementSettings, diff --git a/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportInterceptorTests.java b/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportInterceptorTests.java index 3e5a5c72aae76..d4cd7b79455a3 100644 --- a/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportInterceptorTests.java +++ b/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportInterceptorTests.java @@ -8,25 +8,56 @@ package org.opensearch.wlm; -import org.opensearch.common.SetOnce; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportRequest; import org.opensearch.transport.TransportRequestHandler; import org.opensearch.wlm.WorkloadManagementTransportInterceptor.RequestHandler; +import org.opensearch.wlm.cancellation.QueryGroupTaskCancellationService; + +import java.util.Collections; import static org.opensearch.threadpool.ThreadPool.Names.SAME; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class WorkloadManagementTransportInterceptorTests extends OpenSearchTestCase { + private QueryGroupTaskCancellationService mockTaskCancellationService; + private ClusterService mockClusterService; + private ThreadPool mockThreadPool; + private WorkloadManagementSettings mockWorkloadManagementSettings; private ThreadPool threadPool; private WorkloadManagementTransportInterceptor sut; + private QueryGroupsStateAccessor stateAccessor; public void setUp() throws Exception { super.setUp(); + mockTaskCancellationService = mock(QueryGroupTaskCancellationService.class); + mockClusterService = mock(ClusterService.class); + mockThreadPool = mock(ThreadPool.class); + mockWorkloadManagementSettings = mock(WorkloadManagementSettings.class); threadPool = new TestThreadPool(getTestName()); - sut = new WorkloadManagementTransportInterceptor(threadPool, new SetOnce<>(mock(QueryGroupService.class))); + stateAccessor = new QueryGroupsStateAccessor(); + + ClusterState state = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(mockClusterService.state()).thenReturn(state); + when(state.metadata()).thenReturn(metadata); + when(metadata.queryGroups()).thenReturn(Collections.emptyMap()); + sut = new WorkloadManagementTransportInterceptor( + threadPool, + new QueryGroupService( + mockTaskCancellationService, + mockClusterService, + mockThreadPool, + mockWorkloadManagementSettings, + stateAccessor + ) + ); } public void tearDown() throws Exception { diff --git a/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportRequestHandlerTests.java b/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportRequestHandlerTests.java index 7fa6e4e9f067b..59818ad3dbbd2 100644 --- a/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportRequestHandlerTests.java +++ b/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportRequestHandlerTests.java @@ -9,7 +9,6 @@ package org.opensearch.wlm; import org.opensearch.action.index.IndexRequest; -import org.opensearch.common.SetOnce; import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; import org.opensearch.search.internal.ShardSearchRequest; import org.opensearch.tasks.Task; @@ -31,7 +30,7 @@ public class WorkloadManagementTransportRequestHandlerTests extends OpenSearchTestCase { private WorkloadManagementTransportInterceptor.RequestHandler sut; private ThreadPool threadPool; - private SetOnce queryGroupService; + private QueryGroupService queryGroupService; private TestTransportRequestHandler actualHandler; @@ -39,7 +38,7 @@ public void setUp() throws Exception { super.setUp(); threadPool = new TestThreadPool(getTestName()); actualHandler = new TestTransportRequestHandler<>(); - queryGroupService = new SetOnce<>(mock(QueryGroupService.class)); + queryGroupService = mock(QueryGroupService.class); sut = new WorkloadManagementTransportInterceptor.RequestHandler<>(threadPool, actualHandler, queryGroupService); } @@ -52,7 +51,7 @@ public void tearDown() throws Exception { public void testMessageReceivedForSearchWorkload_nonRejectionCase() throws Exception { ShardSearchRequest request = mock(ShardSearchRequest.class); QueryGroupTask spyTask = getSpyTask(); - doNothing().when(queryGroupService.get()).rejectIfNeeded(anyString()); + doNothing().when(queryGroupService).rejectIfNeeded(anyString()); sut.messageReceived(request, mock(TransportChannel.class), spyTask); assertTrue(sut.isSearchWorkloadRequest(spyTask)); } @@ -60,7 +59,7 @@ public void testMessageReceivedForSearchWorkload_nonRejectionCase() throws Excep public void testMessageReceivedForSearchWorkload_RejectionCase() throws Exception { ShardSearchRequest request = mock(ShardSearchRequest.class); QueryGroupTask spyTask = getSpyTask(); - doThrow(OpenSearchRejectedExecutionException.class).when(queryGroupService.get()).rejectIfNeeded(anyString()); + doThrow(OpenSearchRejectedExecutionException.class).when(queryGroupService).rejectIfNeeded(anyString()); assertThrows(OpenSearchRejectedExecutionException.class, () -> sut.messageReceived(request, mock(TransportChannel.class), spyTask)); } diff --git a/server/src/test/java/org/opensearch/wlm/listeners/QueryGroupRequestOperationListenerTests.java b/server/src/test/java/org/opensearch/wlm/listeners/QueryGroupRequestOperationListenerTests.java index 580ef2e5e5999..4d61e89b281f7 100644 --- a/server/src/test/java/org/opensearch/wlm/listeners/QueryGroupRequestOperationListenerTests.java +++ b/server/src/test/java/org/opensearch/wlm/listeners/QueryGroupRequestOperationListenerTests.java @@ -10,14 +10,12 @@ import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.Metadata; -import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportService; import org.opensearch.wlm.QueryGroupService; import org.opensearch.wlm.QueryGroupTask; import org.opensearch.wlm.QueryGroupsStateAccessor; @@ -26,7 +24,6 @@ import org.opensearch.wlm.cancellation.QueryGroupTaskCancellationService; import org.opensearch.wlm.stats.QueryGroupState; import org.opensearch.wlm.stats.QueryGroupStats; -import org.opensearch.wlm.stats.WlmStats; import java.io.IOException; import java.util.ArrayList; @@ -46,10 +43,8 @@ public class QueryGroupRequestOperationListenerTests extends OpenSearchTestCase ThreadPool testThreadPool; QueryGroupService queryGroupService; private QueryGroupTaskCancellationService taskCancellationService; - private TransportService mockTransportService; private ClusterService mockClusterService; private WorkloadManagementSettings mockWorkloadManagementSettings; - private DiscoveryNode mockDiscoveryNode; Map queryGroupStateMap; String testQueryGroupId; QueryGroupRequestOperationListener sut; @@ -57,11 +52,8 @@ public class QueryGroupRequestOperationListenerTests extends OpenSearchTestCase public void setUp() throws Exception { super.setUp(); taskCancellationService = mock(QueryGroupTaskCancellationService.class); - mockTransportService = mock(TransportService.class); mockClusterService = mock(ClusterService.class); mockWorkloadManagementSettings = mock(WorkloadManagementSettings.class); - mockDiscoveryNode = mock(DiscoveryNode.class); - when(mockTransportService.getLocalNode()).thenReturn(mockDiscoveryNode); queryGroupStateMap = new HashMap<>(); testQueryGroupId = "safjgagnakg-3r3fads"; testThreadPool = new TestThreadPool("RejectionTestThreadPool"); @@ -128,7 +120,7 @@ public void testValidQueryGroupRequestFailure() throws IOException { ) ); - assertSuccess(testQueryGroupId, queryGroupStateMap, new WlmStats(mockDiscoveryNode, expectedStats), testQueryGroupId); + assertSuccess(testQueryGroupId, queryGroupStateMap, expectedStats, testQueryGroupId); } public void testMultiThreadedValidQueryGroupRequestFailures() { @@ -138,7 +130,6 @@ public void testMultiThreadedValidQueryGroupRequestFailures() { setupMockedQueryGroupsFromClusterState(); queryGroupService = new QueryGroupService( taskCancellationService, - mockTransportService, mockClusterService, testThreadPool, mockWorkloadManagementSettings, @@ -171,9 +162,9 @@ public void testMultiThreadedValidQueryGroupRequestFailures() { HashSet set = new HashSet<>(); set.add("_all"); - WlmStats actualStats = queryGroupService.nodeStats(set, null); + QueryGroupStats actualStats = queryGroupService.nodeStats(set, null); - QueryGroupStats queryGroupStats = new QueryGroupStats( + QueryGroupStats expectedStats = new QueryGroupStats( Map.of( testQueryGroupId, new QueryGroupStats.QueryGroupStatsHolder( @@ -206,7 +197,6 @@ public void testMultiThreadedValidQueryGroupRequestFailures() { ) ); - WlmStats expectedStats = new WlmStats(mockDiscoveryNode, queryGroupStats); assertEquals(expectedStats, actualStats); } @@ -244,14 +234,14 @@ public void testInvalidQueryGroupFailure() throws IOException { ) ); - assertSuccess(testQueryGroupId, queryGroupStateMap, new WlmStats(mockDiscoveryNode, expectedStats), "dummy-invalid-qg-id"); + assertSuccess(testQueryGroupId, queryGroupStateMap, expectedStats, "dummy-invalid-qg-id"); } private void assertSuccess( String testQueryGroupId, Map queryGroupStateMap, - WlmStats expectedStats, + QueryGroupStats expectedStats, String threadContextQG_Id ) { QueryGroupsStateAccessor stateAccessor = new QueryGroupsStateAccessor(queryGroupStateMap); @@ -263,7 +253,6 @@ private void assertSuccess( queryGroupService = new QueryGroupService( taskCancellationService, - mockTransportService, mockClusterService, testThreadPool, mockWorkloadManagementSettings, @@ -277,7 +266,7 @@ private void assertSuccess( HashSet set = new HashSet<>(); set.add("_all"); - WlmStats actualStats = queryGroupService.nodeStats(set, null); + QueryGroupStats actualStats = queryGroupService.nodeStats(set, null); assertEquals(expectedStats, actualStats); } diff --git a/server/src/test/java/org/opensearch/wlm/stats/WlmStatsTests.java b/server/src/test/java/org/opensearch/wlm/stats/WlmStatsTests.java index 604db8212e2aa..75b8f570b7975 100644 --- a/server/src/test/java/org/opensearch/wlm/stats/WlmStatsTests.java +++ b/server/src/test/java/org/opensearch/wlm/stats/WlmStatsTests.java @@ -27,7 +27,7 @@ import static java.util.Collections.emptyMap; import static org.mockito.Mockito.mock; -class WlmStatsTests extends AbstractWireSerializingTestCase { +public class WlmStatsTests extends AbstractWireSerializingTestCase { public void testToXContent() throws IOException { final Map stats = new HashMap<>();