From b3d58745a8c299f3937729c9fde9a0ee8c67e707 Mon Sep 17 00:00:00 2001 From: Andriy Redko Date: Thu, 29 Aug 2024 10:10:46 -0400 Subject: [PATCH 1/7] Fix ResourceType API annotations (#15497) Signed-off-by: Andriy Redko --- server/src/main/java/org/opensearch/wlm/ResourceType.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/opensearch/wlm/ResourceType.java b/server/src/main/java/org/opensearch/wlm/ResourceType.java index adf384995c91d..c3f48f5f793ce 100644 --- a/server/src/main/java/org/opensearch/wlm/ResourceType.java +++ b/server/src/main/java/org/opensearch/wlm/ResourceType.java @@ -18,8 +18,10 @@ /** * Enum to hold the resource type + * + * @opensearch.api */ -@PublicApi(since = "2.x") +@PublicApi(since = "2.17.0") public enum ResourceType { CPU("cpu", task -> task.getTotalResourceUtilization(ResourceStats.CPU), true), MEMORY("memory", task -> task.getTotalResourceUtilization(ResourceStats.MEMORY), true); From 30ed15dea695932082d8ac2cdd661c1a669410dc Mon Sep 17 00:00:00 2001 From: Sachin Kale Date: Thu, 29 Aug 2024 20:28:43 +0530 Subject: [PATCH 2/7] Change RemoteSegmentStoreDirectory init at given timestamp to ignore pinned timestamp setting (#15457) Signed-off-by: Sachin Kale Co-authored-by: Sachin Kale --- .../index/remote/RemoteStoreUtils.java | 33 +++++++++- .../store/RemoteSegmentStoreDirectory.java | 3 +- .../main/java/org/opensearch/node/Node.java | 3 +- .../RemoteSegmentStoreDirectoryTests.java | 56 +++++++++++++++++ ...toreDirectoryWithPinnedTimestampTests.java | 62 ------------------- 5 files changed, 90 insertions(+), 67 deletions(-) diff --git a/server/src/main/java/org/opensearch/index/remote/RemoteStoreUtils.java b/server/src/main/java/org/opensearch/index/remote/RemoteStoreUtils.java index b2bc8a0294a49..871e2eb3ce47f 100644 --- a/server/src/main/java/org/opensearch/index/remote/RemoteStoreUtils.java +++ b/server/src/main/java/org/opensearch/index/remote/RemoteStoreUtils.java @@ -391,15 +391,24 @@ public static boolean isSwitchToStrictCompatibilityMode(ClusterUpdateSettingsReq * @param pinnedTimestampSet A set of timestamps representing pinned points in time. * @param getTimestampFunction A function that extracts the timestamp from a metadata file name. * @param prefixFunction A function that extracts a tuple of prefix information from a metadata file name. + * @param ignorePinnedTimestampEnabledSetting A flag to ignore pinned timestamp enabled setting * @return A set of metadata file names that are implicitly locked based on the pinned timestamps. 
*/ public static Set<String> getPinnedTimestampLockedFiles( List<String> metadataFiles, Set<Long> pinnedTimestampSet, Function<String, Long> getTimestampFunction, Function<String, Tuple<String, String>> prefixFunction, + boolean ignorePinnedTimestampEnabledSetting ) { - return getPinnedTimestampLockedFiles(metadataFiles, pinnedTimestampSet, new HashMap<>(), getTimestampFunction, prefixFunction); + return getPinnedTimestampLockedFiles( + metadataFiles, + pinnedTimestampSet, + new HashMap<>(), + getTimestampFunction, + prefixFunction, + ignorePinnedTimestampEnabledSetting + ); } /** @@ -431,10 +440,28 @@ public static Set<String> getPinnedTimestampLockedFiles( Map<Long, String> metadataFilePinnedTimestampMap, Function<String, Long> getTimestampFunction, Function<String, Tuple<String, String>> prefixFunction + ) { + return getPinnedTimestampLockedFiles( + metadataFiles, + pinnedTimestampSet, + metadataFilePinnedTimestampMap, + getTimestampFunction, + prefixFunction, + false + ); + } + + private static Set<String> getPinnedTimestampLockedFiles( + List<String> metadataFiles, + Set<Long> pinnedTimestampSet, + Map<Long, String> metadataFilePinnedTimestampMap, + Function<String, Long> getTimestampFunction, + Function<String, Tuple<String, String>> prefixFunction, + boolean ignorePinnedTimestampEnabledSetting ) { Set<String> implicitLockedFiles = new HashSet<>(); - if (RemoteStoreSettings.isPinnedTimestampsEnabled() == false) { + if (ignorePinnedTimestampEnabledSetting == false && RemoteStoreSettings.isPinnedTimestampsEnabled() == false) { return implicitLockedFiles; }
diff --git a/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java b/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java index 26871429e41d6..53b43bbfb3bba 100644 --- a/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java +++ b/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java @@ -196,7 +196,8 @@ public RemoteSegmentMetadata initializeToSpecificTimestamp(long timestamp) throws IOException { metadataFiles, Set.of(timestamp), MetadataFilenameUtils::getTimestamp, - MetadataFilenameUtils::getNodeIdByPrimaryTermAndGen + MetadataFilenameUtils::getNodeIdByPrimaryTermAndGen, + true ); if (lockedMetadataFiles.isEmpty()) { return null; }
diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 388e00bedab0c..9c7dfe8850b85 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -307,6 +307,7 @@ import static org.opensearch.env.NodeEnvironment.collectFileCacheDataPath; import static org.opensearch.index.ShardIndexingPressureSettings.SHARD_INDEXING_PRESSURE_ENABLED_ATTRIBUTE_KEY; import static org.opensearch.indices.RemoteStoreSettings.CLUSTER_REMOTE_STORE_PINNED_TIMESTAMP_ENABLED; +import static org.opensearch.node.remotestore.RemoteStoreNodeAttribute.isRemoteDataAttributePresent; import static org.opensearch.node.remotestore.RemoteStoreNodeAttribute.isRemoteStoreAttributePresent; import static org.opensearch.node.remotestore.RemoteStoreNodeAttribute.isRemoteStoreClusterStateEnabled; @@ -814,7 +815,7 @@ protected Node( remoteClusterStateCleanupManager = null; } final RemoteStorePinnedTimestampService remoteStorePinnedTimestampService; - if (isRemoteStoreAttributePresent(settings) && CLUSTER_REMOTE_STORE_PINNED_TIMESTAMP_ENABLED.get(settings)) { + if (isRemoteDataAttributePresent(settings) && CLUSTER_REMOTE_STORE_PINNED_TIMESTAMP_ENABLED.get(settings)) { remoteStorePinnedTimestampService = new RemoteStorePinnedTimestampService( repositoriesServiceReference::get, settings,
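As a minimal sketch of the call pattern this new overload enables (the local variables here are hypothetical; the method references are the same ones used in initializeToSpecificTimestamp above), passing true as the final argument makes the lookup proceed even when the pinned-timestamp setting is disabled:

    // Hypothetical caller: find metadata files implicitly locked at a given timestamp,
    // bypassing the RemoteStoreSettings.isPinnedTimestampsEnabled() check.
    Set<String> lockedFiles = RemoteStoreUtils.getPinnedTimestampLockedFiles(
        metadataFiles,                                        // List<String> of metadata file names
        Set.of(timestamp),                                    // the timestamp of interest
        MetadataFilenameUtils::getTimestamp,                  // timestamp extracted from a file name
        MetadataFilenameUtils::getNodeIdByPrimaryTermAndGen,  // prefix tuple extracted from a file name
        true                                                  // ignore the pinned-timestamp enabled setting
    );

diff --git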
a/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java b/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java index 574c5bf620474..336d4bafd4b66 100644 --- a/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java +++ b/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java @@ -1141,6 +1141,62 @@ public void testMetadataFileNameOrder() { assertEquals(14, count); } + public void testInitializeToSpecificTimestampNoMetadataFiles() throws IOException { + when( + remoteMetadataDirectory.listFilesByPrefixInLexicographicOrder( + RemoteSegmentStoreDirectory.MetadataFilenameUtils.METADATA_PREFIX, + Integer.MAX_VALUE + ) + ).thenReturn(new ArrayList<>()); + assertNull(remoteSegmentStoreDirectory.initializeToSpecificTimestamp(1234L)); + } + + public void testInitializeToSpecificTimestampNoMdMatchingTimestamp() throws IOException { + String metadataPrefix = "metadata__1__2__3__4__5__"; + List metadataFiles = new ArrayList<>(); + metadataFiles.add(metadataPrefix + RemoteStoreUtils.invertLong(2000)); + metadataFiles.add(metadataPrefix + RemoteStoreUtils.invertLong(3000)); + metadataFiles.add(metadataPrefix + RemoteStoreUtils.invertLong(4000)); + + when( + remoteMetadataDirectory.listFilesByPrefixInLexicographicOrder( + RemoteSegmentStoreDirectory.MetadataFilenameUtils.METADATA_PREFIX, + Integer.MAX_VALUE + ) + ).thenReturn(metadataFiles); + assertNull(remoteSegmentStoreDirectory.initializeToSpecificTimestamp(1234L)); + } + + public void testInitializeToSpecificTimestampMatchingMdFile() throws IOException { + String metadataPrefix = "metadata__1__2__3__4__5__"; + List metadataFiles = new ArrayList<>(); + metadataFiles.add(metadataPrefix + RemoteStoreUtils.invertLong(1000)); + metadataFiles.add(metadataPrefix + RemoteStoreUtils.invertLong(2000)); + metadataFiles.add(metadataPrefix + RemoteStoreUtils.invertLong(3000)); + + Map metadata = new HashMap<>(); + metadata.put("_0.cfe", "_0.cfe::_0.cfe__" + UUIDs.base64UUID() + "::1234::512::" + Version.LATEST.major); + metadata.put("_0.cfs", "_0.cfs::_0.cfs__" + UUIDs.base64UUID() + "::2345::1024::" + Version.LATEST.major); + + when( + remoteMetadataDirectory.listFilesByPrefixInLexicographicOrder( + RemoteSegmentStoreDirectory.MetadataFilenameUtils.METADATA_PREFIX, + Integer.MAX_VALUE + ) + ).thenReturn(metadataFiles); + when(remoteMetadataDirectory.getBlobStream(metadataPrefix + RemoteStoreUtils.invertLong(1000))).thenReturn( + createMetadataFileBytes(metadata, indexShard.getLatestReplicationCheckpoint(), segmentInfos) + ); + + RemoteSegmentMetadata remoteSegmentMetadata = remoteSegmentStoreDirectory.initializeToSpecificTimestamp(1234L); + assertNotNull(remoteSegmentMetadata); + Map uploadedSegments = remoteSegmentStoreDirectory + .getSegmentsUploadedToRemoteStore(); + assertEquals(2, uploadedSegments.size()); + assertTrue(uploadedSegments.containsKey("_0.cfe")); + assertTrue(uploadedSegments.containsKey("_0.cfs")); + } + private static class WrapperIndexOutput extends IndexOutput { public IndexOutput indexOutput; diff --git a/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryWithPinnedTimestampTests.java b/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryWithPinnedTimestampTests.java index b4f93d706bb1e..107d59aa97549 100644 --- a/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryWithPinnedTimestampTests.java +++ 
b/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryWithPinnedTimestampTests.java @@ -8,8 +8,6 @@ package org.opensearch.index.store; -import org.apache.lucene.util.Version; -import org.opensearch.common.UUIDs; import org.opensearch.common.blobstore.BlobMetadata; import org.opensearch.common.blobstore.BlobPath; import org.opensearch.common.blobstore.support.PlainBlobMetadata; @@ -18,8 +16,6 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.gateway.remote.model.RemotePinnedTimestamps; import org.opensearch.gateway.remote.model.RemoteStorePinnedTimestampsBlobStore; -import org.opensearch.index.remote.RemoteStoreUtils; -import org.opensearch.index.store.remote.metadata.RemoteSegmentMetadata; import org.opensearch.index.translog.transfer.BlobStoreTransferService; import org.opensearch.indices.RemoteStoreSettings; import org.opensearch.node.Node; @@ -31,7 +27,6 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Supplier; @@ -40,7 +35,6 @@ import org.mockito.Mockito; import static org.opensearch.indices.RemoteStoreSettings.CLUSTER_REMOTE_STORE_PINNED_TIMESTAMP_ENABLED; -import static org.opensearch.test.RemoteStoreTestUtils.createMetadataFileBytes; import static org.hamcrest.CoreMatchers.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.anyInt; @@ -143,62 +137,6 @@ private void metadataWithOlderTimestamp() { ); } - public void testInitializeToSpecificTimestampNoMetadataFiles() throws IOException { - when( - remoteMetadataDirectory.listFilesByPrefixInLexicographicOrder( - RemoteSegmentStoreDirectory.MetadataFilenameUtils.METADATA_PREFIX, - Integer.MAX_VALUE - ) - ).thenReturn(new ArrayList<>()); - assertNull(remoteSegmentStoreDirectory.initializeToSpecificTimestamp(1234L)); - } - - public void testInitializeToSpecificTimestampNoMdMatchingTimestamp() throws IOException { - String metadataPrefix = "metadata__1__2__3__4__5__"; - List metadataFiles = new ArrayList<>(); - metadataFiles.add(metadataPrefix + RemoteStoreUtils.invertLong(2000)); - metadataFiles.add(metadataPrefix + RemoteStoreUtils.invertLong(3000)); - metadataFiles.add(metadataPrefix + RemoteStoreUtils.invertLong(4000)); - - when( - remoteMetadataDirectory.listFilesByPrefixInLexicographicOrder( - RemoteSegmentStoreDirectory.MetadataFilenameUtils.METADATA_PREFIX, - Integer.MAX_VALUE - ) - ).thenReturn(metadataFiles); - assertNull(remoteSegmentStoreDirectory.initializeToSpecificTimestamp(1234L)); - } - - public void testInitializeToSpecificTimestampMatchingMdFile() throws IOException { - String metadataPrefix = "metadata__1__2__3__4__5__"; - List metadataFiles = new ArrayList<>(); - metadataFiles.add(metadataPrefix + RemoteStoreUtils.invertLong(1000)); - metadataFiles.add(metadataPrefix + RemoteStoreUtils.invertLong(2000)); - metadataFiles.add(metadataPrefix + RemoteStoreUtils.invertLong(3000)); - - Map metadata = new HashMap<>(); - metadata.put("_0.cfe", "_0.cfe::_0.cfe__" + UUIDs.base64UUID() + "::1234::512::" + Version.LATEST.major); - metadata.put("_0.cfs", "_0.cfs::_0.cfs__" + UUIDs.base64UUID() + "::2345::1024::" + Version.LATEST.major); - - when( - remoteMetadataDirectory.listFilesByPrefixInLexicographicOrder( - RemoteSegmentStoreDirectory.MetadataFilenameUtils.METADATA_PREFIX, - Integer.MAX_VALUE - ) - ).thenReturn(metadataFiles); - when(remoteMetadataDirectory.getBlobStream(metadataPrefix + RemoteStoreUtils.invertLong(1000))).thenReturn( - 
createMetadataFileBytes(metadata, indexShard.getLatestReplicationCheckpoint(), segmentInfos) - ); - - RemoteSegmentMetadata remoteSegmentMetadata = remoteSegmentStoreDirectory.initializeToSpecificTimestamp(1234L); - assertNotNull(remoteSegmentMetadata); - Map<String, RemoteSegmentStoreDirectory.UploadedSegmentMetadata> uploadedSegments = remoteSegmentStoreDirectory - .getSegmentsUploadedToRemoteStore(); - assertEquals(2, uploadedSegments.size()); - assertTrue(uploadedSegments.containsKey("_0.cfe")); - assertTrue(uploadedSegments.containsKey("_0.cfs")); - } - public void testDeleteStaleCommitsNoPinnedTimestampMdFilesLatest() throws Exception { metadataFilename = RemoteSegmentStoreDirectory.MetadataFilenameUtils.getMetadataFilename( 12,
From e982a16667bb2c7fa7e6d3e0618f3bb0c070d377 Mon Sep 17 00:00:00 2001 From: Rishab Nahata Date: Thu, 29 Aug 2024 20:31:15 +0530 Subject: [PATCH 3/7] Make balanced shards allocator timebound (#15239) * Make balanced shards allocator time bound to prioritise critical operations waiting in the pending task queue Signed-off-by: Rishab Nahata --- CHANGELOG.md | 1 + .../cluster/routing/RoutingNodes.java | 4 +- .../allocator/BalancedShardsAllocator.java | 46 +- .../allocator/LocalShardsBalancer.java | 46 +- .../common/settings/ClusterSettings.java | 1 + ...TimeBoundBalancedShardsAllocatorTests.java | 479 ++++++++++++++++++ .../decider/DiskThresholdDeciderTests.java | 12 +- .../cluster/OpenSearchAllocationTestCase.java | 11 + 8 files changed, 591 insertions(+), 9 deletions(-) create mode 100644 server/src/test/java/org/opensearch/cluster/routing/allocation/allocator/TimeBoundBalancedShardsAllocatorTests.java
diff --git a/CHANGELOG.md b/CHANGELOG.md index b7e4548100df3..f8b695205e789 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - [Workload Management] QueryGroup resource tracking framework changes ([#13897](https://github.com/opensearch-project/OpenSearch/pull/13897)) - Support filtering on a large list encoded by bitmap ([#14774](https://github.com/opensearch-project/OpenSearch/pull/14774)) - Add slice execution listeners to SearchOperationListener interface ([#15153](https://github.com/opensearch-project/OpenSearch/pull/15153)) +- Make balanced shards allocator timebound ([#15239](https://github.com/opensearch-project/OpenSearch/pull/15239)) - Add allowlist setting for ingest-geoip and ingest-useragent ([#15325](https://github.com/opensearch-project/OpenSearch/pull/15325)) - Adding access to noSubMatches and noOverlappingMatches in Hyphenation ([#13895](https://github.com/opensearch-project/OpenSearch/pull/13895)) - Add support for index level max slice count setting for concurrent segment search ([#15336](https://github.com/opensearch-project/OpenSearch/pull/15336))
diff --git a/server/src/main/java/org/opensearch/cluster/routing/RoutingNodes.java b/server/src/main/java/org/opensearch/cluster/routing/RoutingNodes.java index ab455f52c4195..b5e74821d41e7 100644 --- a/server/src/main/java/org/opensearch/cluster/routing/RoutingNodes.java +++ b/server/src/main/java/org/opensearch/cluster/routing/RoutingNodes.java @@ -1439,7 +1439,9 @@ public void remove() { */ public Iterator<ShardRouting> nodeInterleavedShardIterator(ShardMovementStrategy shardMovementStrategy) { final Queue<Iterator<ShardRouting>> queue = new ArrayDeque<>(); - for (Map.Entry<String, RoutingNode> entry : nodesToShards.entrySet()) { + List<Map.Entry<String, RoutingNode>> nodesToShardsEntrySet = new ArrayList<>(nodesToShards.entrySet()); + Randomness.shuffle(nodesToShardsEntrySet); + for 
(Map.Entry<String, RoutingNode> entry : nodesToShardsEntrySet) { queue.add(entry.getValue().copyShards().iterator()); } if (shardMovementStrategy == ShardMovementStrategy.PRIMARY_FIRST) {
diff --git a/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java b/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java index 212583d1fb14f..a5193ca602f04 100644 --- a/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java +++ b/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java @@ -54,6 +54,7 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Setting.Property; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import java.util.HashMap; import java.util.HashSet; @@ -87,6 +88,7 @@ public class BalancedShardsAllocator implements ShardsAllocator { private static final Logger logger = LogManager.getLogger(BalancedShardsAllocator.class); + public static final TimeValue MIN_ALLOCATOR_TIMEOUT = TimeValue.timeValueSeconds(20); public static final Setting<Float> INDEX_BALANCE_FACTOR_SETTING = Setting.floatSetting( "cluster.routing.allocation.balance.index", @@ -169,6 +171,23 @@ public class BalancedShardsAllocator implements ShardsAllocator { Property.NodeScope ); + public static final Setting<TimeValue> ALLOCATOR_TIMEOUT_SETTING = Setting.timeSetting( + "cluster.routing.allocation.balanced_shards_allocator.allocator_timeout", + TimeValue.MINUS_ONE, + TimeValue.MINUS_ONE, + timeValue -> { + if (timeValue.compareTo(MIN_ALLOCATOR_TIMEOUT) < 0 && timeValue.compareTo(TimeValue.MINUS_ONE) != 0) { + throw new IllegalArgumentException( + "Setting [" + + "cluster.routing.allocation.balanced_shards_allocator.allocator_timeout" + + "] should be more than 20s or -1ms to disable timeout" + ); + } + }, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + private volatile boolean movePrimaryFirst; private volatile ShardMovementStrategy shardMovementStrategy; @@ -181,6 +200,8 @@ public class BalancedShardsAllocator implements ShardsAllocator { private volatile float threshold; private volatile boolean ignoreThrottleInRestore; + private volatile TimeValue allocatorTimeout; + private long startTime; public BalancedShardsAllocator(Settings settings) { this(settings, new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)); } @@ -197,6 +218,7 @@ public BalancedShardsAllocator(Settings settings, ClusterSettings clusterSetting setPreferPrimaryShardBalance(PREFER_PRIMARY_SHARD_BALANCE.get(settings)); setPreferPrimaryShardRebalance(PREFER_PRIMARY_SHARD_REBALANCE.get(settings)); setShardMovementStrategy(SHARD_MOVEMENT_STRATEGY_SETTING.get(settings)); + setAllocatorTimeout(ALLOCATOR_TIMEOUT_SETTING.get(settings)); clusterSettings.addSettingsUpdateConsumer(PREFER_PRIMARY_SHARD_BALANCE, this::setPreferPrimaryShardBalance); clusterSettings.addSettingsUpdateConsumer(SHARD_MOVE_PRIMARY_FIRST_SETTING, this::setMovePrimaryFirst); clusterSettings.addSettingsUpdateConsumer(SHARD_MOVEMENT_STRATEGY_SETTING, this::setShardMovementStrategy); @@ -206,6 +228,7 @@ public BalancedShardsAllocator(Settings settings, ClusterSettings clusterSetting clusterSettings.addSettingsUpdateConsumer(PREFER_PRIMARY_SHARD_REBALANCE, this::setPreferPrimaryShardRebalance); clusterSettings.addSettingsUpdateConsumer(THRESHOLD_SETTING, this::setThreshold);
clusterSettings.addSettingsUpdateConsumer(IGNORE_THROTTLE_FOR_REMOTE_RESTORE, this::setIgnoreThrottleInRestore); + clusterSettings.addSettingsUpdateConsumer(ALLOCATOR_TIMEOUT_SETTING, this::setAllocatorTimeout); } /** @@ -284,6 +307,20 @@ private void setThreshold(float threshold) { this.threshold = threshold; } + private void setAllocatorTimeout(TimeValue allocatorTimeout) { + this.allocatorTimeout = allocatorTimeout; + } + + protected boolean allocatorTimedOut() { + if (allocatorTimeout.equals(TimeValue.MINUS_ONE)) { + if (logger.isTraceEnabled()) { + logger.trace("Allocator timeout is disabled. Will not short circuit allocator tasks"); + } + return false; + } + return System.nanoTime() - this.startTime > allocatorTimeout.nanos(); + } + @Override public void allocate(RoutingAllocation allocation) { if (allocation.routingNodes().size() == 0) { @@ -298,8 +335,10 @@ public void allocate(RoutingAllocation allocation) { threshold, preferPrimaryShardBalance, preferPrimaryShardRebalance, - ignoreThrottleInRestore + ignoreThrottleInRestore, + this::allocatorTimedOut ); + this.startTime = System.nanoTime(); localShardsBalancer.allocateUnassigned(); localShardsBalancer.moveShards(); localShardsBalancer.balance(); @@ -321,7 +360,8 @@ public ShardAllocationDecision decideShardAllocation(final ShardRouting shard, f threshold, preferPrimaryShardBalance, preferPrimaryShardRebalance, - ignoreThrottleInRestore + ignoreThrottleInRestore, + () -> false // no timeout check is needed when only computing a ShardAllocationDecision ); AllocateUnassignedDecision allocateUnassignedDecision = AllocateUnassignedDecision.NOT_TAKEN; MoveDecision moveDecision = MoveDecision.NOT_TAKEN; @@ -585,7 +625,7 @@ public Balancer( float threshold, boolean preferPrimaryBalance ) { - super(logger, allocation, shardMovementStrategy, weight, threshold, preferPrimaryBalance, false, false); + super(logger, allocation, shardMovementStrategy, weight, threshold, preferPrimaryBalance, false, false, () -> false); } }
diff --git a/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/LocalShardsBalancer.java b/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/LocalShardsBalancer.java index 7e4ae58548c55..adb8ee2cf7e85 100644 --- a/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/LocalShardsBalancer.java +++ b/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/LocalShardsBalancer.java @@ -41,6 +41,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Supplier; import java.util.stream.Stream; import java.util.stream.StreamSupport; @@ -71,6 +72,7 @@ public class LocalShardsBalancer extends ShardsBalancer { private final float avgPrimaryShardsPerNode; private final BalancedShardsAllocator.NodeSorter sorter; private final Set<String> inEligibleTargetNode; + private final Supplier<Boolean> timedOutFunc; private int totalShardCount = 0; public LocalShardsBalancer( @@ -81,7 +83,8 @@ public LocalShardsBalancer( float threshold, boolean preferPrimaryBalance, boolean preferPrimaryRebalance, - boolean ignoreThrottleInRestore + boolean ignoreThrottleInRestore, + Supplier<Boolean> timedOutFunc ) { this.logger = logger; this.allocation = allocation; @@ -99,6 +102,7 @@ public LocalShardsBalancer( this.preferPrimaryRebalance = preferPrimaryRebalance; this.shardMovementStrategy = shardMovementStrategy; this.ignoreThrottleInRestore = ignoreThrottleInRestore; + this.timedOutFunc = timedOutFunc; } /** @@ -344,6 +348,14 @@ private 
void balanceByWeights() { final BalancedShardsAllocator.ModelNode[] modelNodes = sorter.modelNodes; final float[] weights = sorter.weights; for (String index : buildWeightOrderedIndices()) { + // Terminate if the time allocated to the balanced shards allocator has elapsed + if (timedOutFunc != null && timedOutFunc.get()) { + logger.info( + "Cannot balance any shard in the cluster as time allocated to balanced shards allocator has elapsed" + + ". Skipping indices iteration" + ); + return; + } IndexMetadata indexMetadata = metadata.index(index); // find nodes that have a shard of this index or where shards of this index are allowed to be allocated to, @@ -368,6 +380,14 @@ private void balanceByWeights() { int lowIdx = 0; int highIdx = relevantNodes - 1; while (true) { + // break if the time allocated to the balanced shards allocator has elapsed + if (timedOutFunc != null && timedOutFunc.get()) { + logger.info( + "Cannot balance any shard in the cluster as time allocated to balanced shards allocator has elapsed" + + ". Skipping relevant nodes iteration" + ); + return; + } final BalancedShardsAllocator.ModelNode minNode = modelNodes[lowIdx]; final BalancedShardsAllocator.ModelNode maxNode = modelNodes[highIdx]; advance_range: if (maxNode.numShards(index) > 0) { @@ -572,6 +592,15 @@ void moveShards() { return; } + + // Terminate if the time allocated to the balanced shards allocator has elapsed + if (timedOutFunc != null && timedOutFunc.get()) { + logger.info( + "Cannot move any shard in the cluster as time allocated to balanced shards allocator has elapsed" + + ". Skipping shard iteration" + ); + return; + } + ShardRouting shardRouting = it.next(); if (RoutingPool.REMOTE_CAPABLE.equals(RoutingPool.getShardPool(shardRouting, allocation))) { @@ -799,8 +828,23 @@ void allocateUnassigned() { int secondaryLength = 0; int primaryLength = primary.length; ArrayUtil.timSort(primary, comparator); + if (logger.isTraceEnabled()) { + logger.trace("Starting allocation of [{}] unassigned shards", primaryLength); + } do { for (int i = 0; i < primaryLength; i++) { + if (timedOutFunc != null && timedOutFunc.get()) { + // TODO - maybe check whether the wait-for-active-shards logic can be allowed to bypass this condition + logger.info( + "Ignoring [{}] unassigned shards for allocation as time allocated to balanced shards allocator has elapsed", + (primaryLength - i) + ); + while (i < primaryLength) { + unassigned.ignoreShard(primary[i], UnassignedInfo.AllocationStatus.NO_ATTEMPT, allocation.changes()); + i++; + } + return; + } ShardRouting shard = primary[i]; final AllocateUnassignedDecision allocationDecision = decideAllocateUnassigned(shard); final String assignedNodeId = allocationDecision.getTargetNode() != null
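To make the timeout contract concrete, here is a small illustrative check mirroring the validator registered on ALLOCATOR_TIMEOUT_SETTING above; the assertions are a sketch for exposition, not part of the patch:

    // Values of at least 20 seconds are accepted for the dynamic timeout setting.
    Settings thirtySeconds = Settings.builder()
        .put("cluster.routing.allocation.balanced_shards_allocator.allocator_timeout", "30s")
        .build();
    assert BalancedShardsAllocator.ALLOCATOR_TIMEOUT_SETTING.get(thirtySeconds).getSeconds() == 30;

    // -1 disables the timeout entirely; anything else below 20s is rejected by the
    // validator with "should be more than 20s or -1ms to disable timeout".
    Settings disabled = Settings.builder()
        .put("cluster.routing.allocation.balanced_shards_allocator.allocator_timeout", "-1")
        .build();
    assert BalancedShardsAllocator.ALLOCATOR_TIMEOUT_SETTING.get(disabled).equals(TimeValue.MINUS_ONE);

diff --git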
a/server/src/test/java/org/opensearch/cluster/routing/allocation/allocator/TimeBoundBalancedShardsAllocatorTests.java b/server/src/test/java/org/opensearch/cluster/routing/allocation/allocator/TimeBoundBalancedShardsAllocatorTests.java new file mode 100644 index 0000000000000..a10c305686638 --- /dev/null +++ b/server/src/test/java/org/opensearch/cluster/routing/allocation/allocator/TimeBoundBalancedShardsAllocatorTests.java @@ -0,0 +1,479 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.cluster.routing.allocation.allocator; + +import org.opensearch.Version; +import org.opensearch.cluster.ClusterInfo; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.OpenSearchAllocationTestCase; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.routing.RoutingNodes; +import org.opensearch.cluster.routing.RoutingTable; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.cluster.routing.ShardRoutingState; +import org.opensearch.cluster.routing.allocation.RoutingAllocation; +import org.opensearch.cluster.routing.allocation.decider.AllocationDecider; +import org.opensearch.cluster.routing.allocation.decider.AllocationDeciders; +import org.opensearch.cluster.routing.allocation.decider.Decision; +import org.opensearch.cluster.routing.allocation.decider.SameShardAllocationDecider; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; + +import static org.opensearch.cluster.routing.ShardRoutingState.INITIALIZING; +import static org.opensearch.cluster.routing.ShardRoutingState.STARTED; +import static org.opensearch.cluster.routing.allocation.allocator.BalancedShardsAllocator.ALLOCATOR_TIMEOUT_SETTING; + +public class TimeBoundBalancedShardsAllocatorTests extends OpenSearchAllocationTestCase { + + private final DiscoveryNode node1 = newNode("node1", "node1", Collections.singletonMap("zone", "1a")); + private final DiscoveryNode node2 = newNode("node2", "node2", Collections.singletonMap("zone", "1b")); + private final DiscoveryNode node3 = newNode("node3", "node3", Collections.singletonMap("zone", "1c")); + + public void testAllUnassignedShardsAllocatedWhenNoTimeOut() { + int numberOfIndices = 2; + int numberOfShards = 5; + int numberOfReplicas = 1; + int totalPrimaryCount = numberOfIndices * numberOfShards; + int totalShardCount = numberOfIndices * (numberOfShards * (numberOfReplicas + 1)); + Settings.Builder settings = Settings.builder(); + // passing total shard count for timed out latch such that no shard times out + BalancedShardsAllocator allocator = new TestBalancedShardsAllocator(settings.build(), new CountDownLatch(totalShardCount)); + Metadata metadata = buildMetadata(Metadata.builder(), numberOfIndices, numberOfShards, numberOfReplicas); + RoutingTable routingTable = buildRoutingTable(metadata); + ClusterState state = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY)) + .metadata(metadata) + 
.routingTable(routingTable) + .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)) + .build(); + RoutingAllocation allocation = new RoutingAllocation( + yesAllocationDeciders(), + new RoutingNodes(state, false), + state, + ClusterInfo.EMPTY, + null, + System.nanoTime() + ); + allocator.allocate(allocation); + List initializingShards = allocation.routingNodes().shardsWithState(ShardRoutingState.INITIALIZING); + int node1Recoveries = allocation.routingNodes().getInitialPrimariesIncomingRecoveries(node1.getId()); + int node2Recoveries = allocation.routingNodes().getInitialPrimariesIncomingRecoveries(node2.getId()); + int node3Recoveries = allocation.routingNodes().getInitialPrimariesIncomingRecoveries(node3.getId()); + assertEquals(totalShardCount, initializingShards.size()); + assertEquals(0, allocation.routingNodes().unassigned().ignored().size()); + assertEquals(totalPrimaryCount, node1Recoveries + node2Recoveries + node3Recoveries); + } + + public void testAllUnassignedShardsIgnoredWhenTimedOut() { + int numberOfIndices = 2; + int numberOfShards = 5; + int numberOfReplicas = 1; + int totalShardCount = numberOfIndices * (numberOfShards * (numberOfReplicas + 1)); + Settings.Builder settings = Settings.builder(); + // passing 0 for timed out latch such that all shard times out + BalancedShardsAllocator allocator = new TestBalancedShardsAllocator(settings.build(), new CountDownLatch(0)); + Metadata metadata = buildMetadata(Metadata.builder(), numberOfIndices, numberOfShards, numberOfReplicas); + RoutingTable routingTable = buildRoutingTable(metadata); + ClusterState state = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY)) + .metadata(metadata) + .routingTable(routingTable) + .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)) + .build(); + RoutingAllocation allocation = new RoutingAllocation( + yesAllocationDeciders(), + new RoutingNodes(state, false), + state, + ClusterInfo.EMPTY, + null, + System.nanoTime() + ); + allocator.allocate(allocation); + List initializingShards = allocation.routingNodes().shardsWithState(ShardRoutingState.INITIALIZING); + int node1Recoveries = allocation.routingNodes().getInitialPrimariesIncomingRecoveries(node1.getId()); + int node2Recoveries = allocation.routingNodes().getInitialPrimariesIncomingRecoveries(node2.getId()); + int node3Recoveries = allocation.routingNodes().getInitialPrimariesIncomingRecoveries(node3.getId()); + assertEquals(0, initializingShards.size()); + assertEquals(totalShardCount, allocation.routingNodes().unassigned().ignored().size()); + assertEquals(0, node1Recoveries + node2Recoveries + node3Recoveries); + } + + public void testAllocatePartialPrimaryShardsUntilTimedOut() { + int numberOfIndices = 2; + int numberOfShards = 5; + int numberOfReplicas = 1; + int totalShardCount = numberOfIndices * (numberOfShards * (numberOfReplicas + 1)); + Settings.Builder settings = Settings.builder(); + int shardsToAllocate = randomIntBetween(1, numberOfShards * numberOfIndices); + // passing shards to allocate for timed out latch such that only few primary shards are allocated in this reroute round + BalancedShardsAllocator allocator = new TestBalancedShardsAllocator(settings.build(), new CountDownLatch(shardsToAllocate)); + Metadata metadata = buildMetadata(Metadata.builder(), numberOfIndices, numberOfShards, numberOfReplicas); + RoutingTable routingTable = buildRoutingTable(metadata); + ClusterState state = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY)) + 
.metadata(metadata) + .routingTable(routingTable) + .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)) + .build(); + RoutingAllocation allocation = new RoutingAllocation( + yesAllocationDeciders(), + new RoutingNodes(state, false), + state, + ClusterInfo.EMPTY, + null, + System.nanoTime() + ); + allocator.allocate(allocation); + List initializingShards = allocation.routingNodes().shardsWithState(ShardRoutingState.INITIALIZING); + int node1Recoveries = allocation.routingNodes().getInitialPrimariesIncomingRecoveries(node1.getId()); + int node2Recoveries = allocation.routingNodes().getInitialPrimariesIncomingRecoveries(node2.getId()); + int node3Recoveries = allocation.routingNodes().getInitialPrimariesIncomingRecoveries(node3.getId()); + assertEquals(shardsToAllocate, initializingShards.size()); + assertEquals(totalShardCount - shardsToAllocate, allocation.routingNodes().unassigned().ignored().size()); + assertEquals(shardsToAllocate, node1Recoveries + node2Recoveries + node3Recoveries); + } + + public void testAllocateAllPrimaryShardsAndPartialReplicaShardsUntilTimedOut() { + int numberOfIndices = 2; + int numberOfShards = 5; + int numberOfReplicas = 1; + int totalShardCount = numberOfIndices * (numberOfShards * (numberOfReplicas + 1)); + Settings.Builder settings = Settings.builder(); + int shardsToAllocate = randomIntBetween(numberOfShards * numberOfIndices, totalShardCount); + // passing shards to allocate for timed out latch such that all primary shards and few replica shards are allocated in this reroute + // round + BalancedShardsAllocator allocator = new TestBalancedShardsAllocator(settings.build(), new CountDownLatch(shardsToAllocate)); + Metadata metadata = buildMetadata(Metadata.builder(), numberOfIndices, numberOfShards, numberOfReplicas); + RoutingTable routingTable = buildRoutingTable(metadata); + ClusterState state = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY)) + .metadata(metadata) + .routingTable(routingTable) + .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)) + .build(); + RoutingAllocation allocation = new RoutingAllocation( + yesAllocationDeciders(), + new RoutingNodes(state, false), + state, + ClusterInfo.EMPTY, + null, + System.nanoTime() + ); + allocator.allocate(allocation); + List initializingShards = allocation.routingNodes().shardsWithState(ShardRoutingState.INITIALIZING); + int node1Recoveries = allocation.routingNodes().getInitialPrimariesIncomingRecoveries(node1.getId()); + int node2Recoveries = allocation.routingNodes().getInitialPrimariesIncomingRecoveries(node2.getId()); + int node3Recoveries = allocation.routingNodes().getInitialPrimariesIncomingRecoveries(node3.getId()); + assertEquals(shardsToAllocate, initializingShards.size()); + assertEquals(totalShardCount - shardsToAllocate, allocation.routingNodes().unassigned().ignored().size()); + assertEquals(numberOfShards * numberOfIndices, node1Recoveries + node2Recoveries + node3Recoveries); + } + + public void testAllShardsMoveWhenExcludedAndTimeoutNotBreached() { + int numberOfIndices = 3; + int numberOfShards = 5; + int numberOfReplicas = 1; + int totalShardCount = numberOfIndices * (numberOfShards * (numberOfReplicas + 1)); + Metadata metadata = buildMetadata(Metadata.builder(), numberOfIndices, numberOfShards, numberOfReplicas); + RoutingTable routingTable = buildRoutingTable(metadata); + ClusterState state = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY)) + .metadata(metadata) + 
.routingTable(routingTable) + .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)) + .build(); + MockAllocationService allocationService = createAllocationService(); + state = applyStartedShardsUntilNoChange(state, allocationService); + // check all shards allocated + assertEquals(0, state.getRoutingNodes().shardsWithState(INITIALIZING).size()); + assertEquals(totalShardCount, state.getRoutingNodes().shardsWithState(STARTED).size()); + int node1ShardCount = state.getRoutingNodes().node("node1").size(); + Settings settings = Settings.builder().put("cluster.routing.allocation.exclude.zone", "1a").build(); + int shardsToMove = 10 + 1000; // such that the timeout is never breached + BalancedShardsAllocator allocator = new TestBalancedShardsAllocator(settings, new CountDownLatch(shardsToMove)); + RoutingAllocation allocation = new RoutingAllocation( + allocationDecidersForExcludeAPI(settings), + new RoutingNodes(state, false), + state, + ClusterInfo.EMPTY, + null, + System.nanoTime() + ); + allocator.allocate(allocation); + List<ShardRouting> relocatingShards = allocation.routingNodes().shardsWithState(ShardRoutingState.RELOCATING); + assertEquals(node1ShardCount, relocatingShards.size()); + } + + public void testNoShardsMoveWhenExcludedAndTimeoutBreached() { + int numberOfIndices = 3; + int numberOfShards = 5; + int numberOfReplicas = 1; + int totalShardCount = numberOfIndices * (numberOfShards * (numberOfReplicas + 1)); + Metadata metadata = buildMetadata(Metadata.builder(), numberOfIndices, numberOfShards, numberOfReplicas); + RoutingTable routingTable = buildRoutingTable(metadata); + ClusterState state = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY)) + .metadata(metadata) + .routingTable(routingTable) + .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)) + .build(); + MockAllocationService allocationService = createAllocationService(); + state = applyStartedShardsUntilNoChange(state, allocationService); + // check all shards allocated + assertEquals(0, state.getRoutingNodes().shardsWithState(INITIALIZING).size()); + assertEquals(totalShardCount, state.getRoutingNodes().shardsWithState(STARTED).size()); + Settings settings = Settings.builder().put("cluster.routing.allocation.exclude.zone", "1a").build(); + int shardsToMove = 0; // such that the timeout is breached immediately + BalancedShardsAllocator allocator = new TestBalancedShardsAllocator(settings, new CountDownLatch(shardsToMove)); + RoutingAllocation allocation = new RoutingAllocation( + allocationDecidersForExcludeAPI(settings), + new RoutingNodes(state, false), + state, + ClusterInfo.EMPTY, + null, + System.nanoTime() + ); + allocator.allocate(allocation); + List<ShardRouting> relocatingShards = allocation.routingNodes().shardsWithState(ShardRoutingState.RELOCATING); + assertEquals(0, relocatingShards.size()); + } + + public void testPartialShardsMoveWhenExcludedAndTimeoutBreached() { + int numberOfIndices = 3; + int numberOfShards = 5; + int numberOfReplicas = 1; + int totalShardCount = numberOfIndices * (numberOfShards * (numberOfReplicas + 1)); + Metadata metadata = buildMetadata(Metadata.builder(), numberOfIndices, numberOfShards, numberOfReplicas); + RoutingTable routingTable = buildRoutingTable(metadata); + ClusterState state = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY)) + .metadata(metadata) + .routingTable(routingTable) + .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)) + .build(); + MockAllocationService allocationService = 
createAllocationService(); + state = applyStartedShardsUntilNoChange(state, allocationService); + // check all shards allocated + assertEquals(0, state.getRoutingNodes().shardsWithState(INITIALIZING).size()); + assertEquals(totalShardCount, state.getRoutingNodes().shardsWithState(STARTED).size()); + Settings settings = Settings.builder().put("cluster.routing.allocation.exclude.zone", "1a").build(); + // since for moves, it creates an iterator over shards which interleaves between nodes; hence, + // for shardsToMove=6, two shards each from node1, node2 and node3 attempt to move, while only + // the shards from node1 can actually move. Hence, the total number of moves executed is 2 (6/3). + int shardsToMove = 6; // such that the timeout is breached partway through the move iteration + BalancedShardsAllocator allocator = new TestBalancedShardsAllocator(settings, new CountDownLatch(shardsToMove)); + RoutingAllocation allocation = new RoutingAllocation( + allocationDecidersForExcludeAPI(settings), + new RoutingNodes(state, false), + state, + ClusterInfo.EMPTY, + null, + System.nanoTime() + ); + allocator.allocate(allocation); + List<ShardRouting> relocatingShards = allocation.routingNodes().shardsWithState(ShardRoutingState.RELOCATING); + assertEquals(shardsToMove / 3, relocatingShards.size()); + } + + public void testClusterRebalancedWhenNotTimedOut() { + int numberOfIndices = 1; + int numberOfShards = 15; + int numberOfReplicas = 1; + int totalShardCount = numberOfIndices * (numberOfShards * (numberOfReplicas + 1)); + Metadata metadata = buildMetadata(Metadata.builder(), numberOfIndices, numberOfShards, numberOfReplicas); + RoutingTable routingTable = buildRoutingTable(metadata); + ClusterState state = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY)) + .metadata(metadata) + .routingTable(routingTable) + .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)) + .build(); + MockAllocationService allocationService = createAllocationService( + Settings.builder().put("cluster.routing.allocation.exclude.zone", "1a").build() + ); // such that no shards are allocated to node1 + state = applyStartedShardsUntilNoChange(state, allocationService); + int node1ShardCount = state.getRoutingNodes().node("node1").size(); + // check all shards allocated + assertEquals(0, state.getRoutingNodes().shardsWithState(INITIALIZING).size()); + assertEquals(totalShardCount, state.getRoutingNodes().shardsWithState(STARTED).size()); + assertEquals(0, node1ShardCount); + Settings newSettings = Settings.builder().put("cluster.routing.allocation.exclude.zone", "").build(); + int shardsToMove = 1000; // such that the timeout is never breached + BalancedShardsAllocator allocator = new TestBalancedShardsAllocator(newSettings, new CountDownLatch(shardsToMove)); + RoutingAllocation allocation = new RoutingAllocation( + allocationDecidersForExcludeAPI(newSettings), + new RoutingNodes(state, false), + state, + ClusterInfo.EMPTY, + null, + System.nanoTime() + ); + allocator.allocate(allocation); + List<ShardRouting> relocatingShards = allocation.routingNodes().shardsWithState(ShardRoutingState.RELOCATING); + assertEquals(totalShardCount / 3, relocatingShards.size()); + } + + public void testClusterNotRebalancedWhenTimedOut() { + int numberOfIndices = 1; + int numberOfShards = 15; + int numberOfReplicas = 1; + int totalShardCount = numberOfIndices * (numberOfShards * (numberOfReplicas + 1)); + Metadata metadata = buildMetadata(Metadata.builder(), numberOfIndices, numberOfShards, numberOfReplicas); + RoutingTable routingTable = 
buildRoutingTable(metadata); + ClusterState state = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY)) + .metadata(metadata) + .routingTable(routingTable) + .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)) + .build(); + MockAllocationService allocationService = createAllocationService( + Settings.builder().put("cluster.routing.allocation.exclude.zone", "1a").build() + ); // such that no shards are allocated to node1 + state = applyStartedShardsUntilNoChange(state, allocationService); + int node1ShardCount = state.getRoutingNodes().node("node1").size(); + // check all shards allocated + assertEquals(0, state.getRoutingNodes().shardsWithState(INITIALIZING).size()); + assertEquals(totalShardCount, state.getRoutingNodes().shardsWithState(STARTED).size()); + assertEquals(0, node1ShardCount); + Settings newSettings = Settings.builder().put("cluster.routing.allocation.exclude.zone", "").build(); + int shardsToMove = 0; // such that it never balances anything + BalancedShardsAllocator allocator = new TestBalancedShardsAllocator(newSettings, new CountDownLatch(shardsToMove)); + RoutingAllocation allocation = new RoutingAllocation( + allocationDecidersForExcludeAPI(newSettings), + new RoutingNodes(state, false), + state, + ClusterInfo.EMPTY, + null, + System.nanoTime() + ); + allocator.allocate(allocation); + List relocatingShards = allocation.routingNodes().shardsWithState(ShardRoutingState.RELOCATING); + assertEquals(0, relocatingShards.size()); + } + + public void testClusterPartialRebalancedWhenTimedOut() { + int numberOfIndices = 1; + int numberOfShards = 15; + int numberOfReplicas = 1; + int totalShardCount = numberOfIndices * (numberOfShards * (numberOfReplicas + 1)); + Metadata metadata = buildMetadata(Metadata.builder(), numberOfIndices, numberOfShards, numberOfReplicas); + RoutingTable routingTable = buildRoutingTable(metadata); + ClusterState state = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY)) + .metadata(metadata) + .routingTable(routingTable) + .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)) + .build(); + MockAllocationService allocationService = createAllocationService( + Settings.builder().put("cluster.routing.allocation.exclude.zone", "1a").build() + ); // such that no shards are allocated to node1 + state = applyStartedShardsUntilNoChange(state, allocationService); + int node1ShardCount = state.getRoutingNodes().node("node1").size(); + // check all shards allocated + assertEquals(0, state.getRoutingNodes().shardsWithState(INITIALIZING).size()); + assertEquals(totalShardCount, state.getRoutingNodes().shardsWithState(STARTED).size()); + assertEquals(0, node1ShardCount); + Settings newSettings = Settings.builder().put("cluster.routing.allocation.exclude.zone", "").build(); + + // making custom set of allocation deciders such that it never attempts to move shards but always attempts to rebalance + List allocationDeciders = Arrays.asList(new AllocationDecider() { + @Override + public Decision canMoveAnyShard(RoutingAllocation allocation) { + return Decision.NO; + } + }, new AllocationDecider() { + @Override + public Decision canRebalance(ShardRouting shardRouting, RoutingAllocation allocation) { + return Decision.YES; + } + }, new SameShardAllocationDecider(newSettings, new ClusterSettings(newSettings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS))); + int shardsToMove = 3; // such that it only partially balances few shards + // adding +1 as during rebalance we do per index timeout check 
and then per node check + BalancedShardsAllocator allocator = new TestBalancedShardsAllocator(newSettings, new CountDownLatch(shardsToMove + 1)); + RoutingAllocation allocation = new RoutingAllocation( + new AllocationDeciders(allocationDeciders), + new RoutingNodes(state, false), + state, + ClusterInfo.EMPTY, + null, + System.nanoTime() + ); + allocator.allocate(allocation); + List relocatingShards = allocation.routingNodes().shardsWithState(ShardRoutingState.RELOCATING); + assertEquals(3, relocatingShards.size()); + } + + public void testAllocatorNeverTimedOutIfValueIsMinusOne() { + Settings build = Settings.builder().put("cluster.routing.allocation.balanced_shards_allocator.allocator_timeout", "-1").build(); + BalancedShardsAllocator allocator = new BalancedShardsAllocator(build); + assertFalse(allocator.allocatorTimedOut()); + } + + public void testAllocatorTimeout() { + String settingKey = "cluster.routing.allocation.balanced_shards_allocator.allocator_timeout"; + // Valid setting with timeout = 20s + Settings build = Settings.builder().put(settingKey, "20s").build(); + assertEquals(20, ALLOCATOR_TIMEOUT_SETTING.get(build).getSeconds()); + + // Valid setting with timeout > 20s + build = Settings.builder().put(settingKey, "30000ms").build(); + assertEquals(30, ALLOCATOR_TIMEOUT_SETTING.get(build).getSeconds()); + + // Invalid setting with timeout < 20s + Settings lessThan20sSetting = Settings.builder().put(settingKey, "10s").build(); + IllegalArgumentException iae = expectThrows( + IllegalArgumentException.class, + () -> ALLOCATOR_TIMEOUT_SETTING.get(lessThan20sSetting) + ); + assertEquals("Setting [" + settingKey + "] should be more than 20s or -1ms to disable timeout", iae.getMessage()); + + // Valid setting with timeout = -1 + build = Settings.builder().put(settingKey, "-1").build(); + assertEquals(-1, ALLOCATOR_TIMEOUT_SETTING.get(build).getMillis()); + } + + private RoutingTable buildRoutingTable(Metadata metadata) { + RoutingTable.Builder routingTableBuilder = RoutingTable.builder(); + for (Map.Entry entry : metadata.getIndices().entrySet()) { + routingTableBuilder.addAsNew(entry.getValue()); + } + return routingTableBuilder.build(); + } + + private Metadata buildMetadata(Metadata.Builder mb, int numberOfIndices, int numberOfShards, int numberOfReplicas) { + for (int i = 0; i < numberOfIndices; i++) { + mb.put( + IndexMetadata.builder("test_" + i) + .settings(settings(Version.CURRENT)) + .numberOfShards(numberOfShards) + .numberOfReplicas(numberOfReplicas) + ); + } + + return mb.build(); + } + + static class TestBalancedShardsAllocator extends BalancedShardsAllocator { + private final CountDownLatch timedOutLatch; + + public TestBalancedShardsAllocator(Settings settings, CountDownLatch timedOutLatch) { + super(settings); + this.timedOutLatch = timedOutLatch; + } + + @Override + protected boolean allocatorTimedOut() { + if (timedOutLatch.getCount() == 0) { + return true; + } + timedOutLatch.countDown(); + return false; + } + } +} diff --git a/server/src/test/java/org/opensearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java b/server/src/test/java/org/opensearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java index 2e24640fe858d..94e91c3f7c3c1 100644 --- a/server/src/test/java/org/opensearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java +++ b/server/src/test/java/org/opensearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java @@ -530,6 +530,8 @@ public void testDiskThresholdWithAbsoluteSizes() { // Primary 
should initialize, even though both nodes are over the limit initialize assertThat(clusterState.getRoutingNodes().shardsWithState(INITIALIZING).size(), equalTo(1)); + // the checks below are unnecessary as the primary shard is always assigned to node2, since the BSA always picks that node + // first: node1 and node2 have equal weight because both of them contain zero shards. String nodeWithPrimary, nodeWithoutPrimary; if (clusterState.getRoutingNodes().node("node1").size() == 1) { nodeWithPrimary = "node1"; @@ -679,10 +681,12 @@ public void testDiskThresholdWithAbsoluteSizes() { clusterState = startInitializingShardsAndReroute(strategy, clusterState); logShardStates(clusterState); - // primary shard already has been relocated away - assertThat(clusterState.getRoutingNodes().node(nodeWithPrimary).size(), equalTo(0)); - // node with increased space still has its shard - assertThat(clusterState.getRoutingNodes().node(nodeWithoutPrimary).size(), equalTo(1)); + // "primary shard already has been relocated away" is a wrong expectation, as we don't really move the + // primary first unless explicitly configured via a setting. This is caught by PR + // https://github.com/opensearch-project/OpenSearch/pull/15239/ + // as it randomises the order of nodes checked for potential moves + // assertThat(clusterState.getRoutingNodes().node(nodeWithPrimary).size(), equalTo(0)); + // assertThat(clusterState.getRoutingNodes().node(nodeWithoutPrimary).size(), equalTo(1)); assertThat(clusterState.getRoutingNodes().node("node3").size(), equalTo(1)); assertThat(clusterState.getRoutingNodes().node("node4").size(), equalTo(1));
diff --git a/test/framework/src/main/java/org/opensearch/cluster/OpenSearchAllocationTestCase.java b/test/framework/src/main/java/org/opensearch/cluster/OpenSearchAllocationTestCase.java index 34b8c58a9c5b2..f54ba36203684 100644 --- a/test/framework/src/main/java/org/opensearch/cluster/OpenSearchAllocationTestCase.java +++ b/test/framework/src/main/java/org/opensearch/cluster/OpenSearchAllocationTestCase.java @@ -48,6 +48,7 @@ import org.opensearch.cluster.routing.allocation.decider.AllocationDecider; import org.opensearch.cluster.routing.allocation.decider.AllocationDeciders; import org.opensearch.cluster.routing.allocation.decider.Decision; +import org.opensearch.cluster.routing.allocation.decider.FilterAllocationDecider; import org.opensearch.cluster.routing.allocation.decider.SameShardAllocationDecider; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -213,6 +214,16 @@ protected static AllocationDeciders throttleAllocationDeciders() { ); } + protected static AllocationDeciders allocationDecidersForExcludeAPI(Settings settings) { + return new AllocationDeciders( + Arrays.asList( + new TestAllocateDecision(Decision.YES), + new SameShardAllocationDecider(settings, new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)), + new FilterAllocationDecider(settings, new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)) + ) + ); + } + protected ClusterState applyStartedShardsUntilNoChange(ClusterState clusterState, AllocationService service) { ClusterState lastClusterState; do {
From e5fadba7b82da4da714cac37aa335a3be230eace Mon Sep 17 00:00:00 2001 From: gaobinlong Date: Thu, 29 Aug 2024 23:34:25 +0800 Subject: [PATCH 4/7] Update version check for the fix of missing validation for the search_backpressure.mode setting (#15500) Signed-off-by: Gao Binlong --- .../rest-api-spec/test/cluster.put_settings/10_basic.yml | 4 
++-- 1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/cluster.put_settings/10_basic.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/cluster.put_settings/10_basic.yml index 2bc5e98465e16..107d298b597d3 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/test/cluster.put_settings/10_basic.yml +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/cluster.put_settings/10_basic.yml @@ -85,8 +85,8 @@ "Test set invalid search backpressure mode": - skip: - version: "- 2.99.99" - reason: "Parsing and validation of SearchBackpressureMode does not exist in versions < 3.0" + version: "- 2.8.99" + reason: "Fixed in 2.9.0" - do: catch: bad_request
From e146f13a69c8bee87a936416e8998710521518a3 Mon Sep 17 00:00:00 2001 From: Pranshu Shukla <55992439+Pranshu-S@users.noreply.github.com> Date: Thu, 29 Aug 2024 21:12:15 +0530 Subject: [PATCH 5/7] Optimize NodeIndicesStats output behind flag (#14454) * Optimize NodeIndicesStats output behind flag Signed-off-by: Pranshu Shukla --- CHANGELOG.md | 1 + .../org/opensearch/nodestats/NodeStatsIT.java | 309 ++++++++++++++ .../admin/indices/stats/CommonStatsFlags.java | 15 + .../opensearch/indices/IndicesService.java | 8 +- .../opensearch/indices/NodeIndicesStats.java | 199 +++++++-- .../admin/cluster/RestNodesStatsAction.java | 1 + .../rest/action/cat/RestNodesAction.java | 1 + .../cluster/node/stats/NodeStatsTests.java | 400 ++++++++++++++++++ 8 files changed, 904 insertions(+), 30 deletions(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md index f8b695205e789..cbfde6a1c1a80 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add concurrent search support for Derived Fields ([#15326](https://github.com/opensearch-project/OpenSearch/pull/15326)) - [Workload Management] Add query group stats constructs ([#15343](https://github.com/opensearch-project/OpenSearch/pull/15343)) - Add runAs to Subject interface and introduce IdentityAwarePlugin extension point ([#14630](https://github.com/opensearch-project/OpenSearch/pull/14630)) +- Optimize NodeIndicesStats output behind flag ([#14454](https://github.com/opensearch-project/OpenSearch/pull/14454)) ### Dependencies - Bump `netty` from 4.1.111.Final to 4.1.112.Final ([#15081](https://github.com/opensearch-project/OpenSearch/pull/15081))
diff --git a/server/src/internalClusterTest/java/org/opensearch/nodestats/NodeStatsIT.java b/server/src/internalClusterTest/java/org/opensearch/nodestats/NodeStatsIT.java index f270cb1399072..22c1679babb52 100644 --- a/server/src/internalClusterTest/java/org/opensearch/nodestats/NodeStatsIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/nodestats/NodeStatsIT.java @@ -10,6 +10,9 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.admin.cluster.node.stats.NodeStats; +import org.opensearch.action.admin.cluster.node.stats.NodesStatsResponse; +import org.opensearch.action.admin.indices.stats.CommonStatsFlags; import org.opensearch.action.bulk.BulkItemResponse; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; @@ -19,21 +22,35 @@ import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; +import org.opensearch.cluster.ClusterState; +import 
org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.engine.DocumentMissingException; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.index.shard.IndexingStats.Stats.DocStatusStats; +import org.opensearch.indices.NodeIndicesStats; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.test.OpenSearchIntegTestCase.ClusterScope; import org.opensearch.test.OpenSearchIntegTestCase.Scope; import org.hamcrest.MatcherAssert; +import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Comparator; +import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicLong; import static java.util.Collections.singletonMap; +import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -243,6 +260,280 @@ public void testNodeIndicesStatsDocStatusStatsCreateDeleteUpdate() { } } + public void testNodeIndicesStatsDocStatsWithAggregations() { + { // Testing Create + final String INDEX = "create_index"; + final String ID = "id"; + DocStatusStats expectedDocStatusStats = new DocStatusStats(); + + IndexResponse response = client().index(new IndexRequest(INDEX).id(ID).source(SOURCE).create(true)).actionGet(); + expectedDocStatusStats.inc(response.status()); + + CommonStatsFlags commonStatsFlags = new CommonStatsFlags(); + commonStatsFlags.setIncludeIndicesStatsByLevel(true); + + DocStatusStats docStatusStats = client().admin() + .cluster() + .prepareNodesStats() + .setIndices(commonStatsFlags) + .execute() + .actionGet() + .getNodes() + .get(0) + .getIndices() + .getIndexing() + .getTotal() + .getDocStatusStats(); + + assertTrue( + Arrays.equals( + docStatusStats.getDocStatusCounter(), + expectedDocStatusStats.getDocStatusCounter(), + Comparator.comparingLong(AtomicLong::longValue) + ) + ); + } + } + + /** + * Default behavior - without consideration of request level param on level, the NodeStatsRequest always + * returns ShardStats which is aggregated on the coordinator node when creating the XContent. 
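+ *
+ * A minimal sketch of the default flow exercised below (all calls appear in the test body):
+ * NodesStatsResponse response = client().admin().cluster().prepareNodesStats().get(); // no level param
+ * NodeStats nodeStats = response.getNodes().get(0);
+ * nodeStats.getIndices().toXContent(builder, ToXContent.EMPTY_PARAMS); // rolls shard stats up to node level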
+ */ + public void testNodeIndicesStatsXContentWithoutAggregationOnNodes() { + List testLevels = new ArrayList<>(); + testLevels.add("null"); + testLevels.add(NodeIndicesStats.StatsLevel.NODE.getRestName()); + testLevels.add(NodeIndicesStats.StatsLevel.INDICES.getRestName()); + testLevels.add(NodeIndicesStats.StatsLevel.SHARDS.getRestName()); + testLevels.add("unknown"); + + internalCluster().startNode(); + ensureGreen(); + String indexName = "test1"; + assertAcked( + prepareCreate( + indexName, + clusterService().state().getNodes().getSize(), + Settings.builder().put("number_of_shards", 2).put("number_of_replicas", clusterService().state().getNodes().getSize() - 1) + ) + ); + ensureGreen(); + ClusterState clusterState = client().admin().cluster().prepareState().get().getState(); + + testLevels.forEach(testLevel -> { + NodesStatsResponse response; + if (!testLevel.equals("null")) { + ArrayList level_arg = new ArrayList<>(); + level_arg.add(testLevel); + + CommonStatsFlags commonStatsFlags = new CommonStatsFlags(); + commonStatsFlags.setLevels(level_arg.toArray(new String[0])); + response = client().admin().cluster().prepareNodesStats().setIndices(commonStatsFlags).get(); + } else { + response = client().admin().cluster().prepareNodesStats().get(); + } + + NodeStats nodeStats = response.getNodes().get(0); + assertNotNull(nodeStats.getIndices().getShardStats(clusterState.metadata().index(indexName).getIndex())); + try { + // Without any param - default is level = nodes + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder = nodeStats.getIndices().toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + + Map xContentMap = xContentBuilderToMap(builder); + LinkedHashMap indicesStatsMap = (LinkedHashMap) xContentMap.get(NodeIndicesStats.StatsLevel.INDICES.getRestName()); + assertFalse(indicesStatsMap.containsKey(NodeIndicesStats.StatsLevel.INDICES)); + assertFalse(indicesStatsMap.containsKey(NodeIndicesStats.StatsLevel.SHARDS)); + + // With param containing level as 'indices', the indices stats are returned + builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder = nodeStats.getIndices() + .toXContent( + builder, + new ToXContent.MapParams(Collections.singletonMap("level", NodeIndicesStats.StatsLevel.INDICES.getRestName())) + ); + builder.endObject(); + + xContentMap = xContentBuilderToMap(builder); + indicesStatsMap = (LinkedHashMap) xContentMap.get(NodeIndicesStats.StatsLevel.INDICES.getRestName()); + assertTrue(indicesStatsMap.containsKey(NodeIndicesStats.StatsLevel.INDICES.getRestName())); + assertFalse(indicesStatsMap.containsKey(NodeIndicesStats.StatsLevel.SHARDS.getRestName())); + + LinkedHashMap indexLevelStats = (LinkedHashMap) indicesStatsMap.get(NodeIndicesStats.StatsLevel.INDICES.getRestName()); + assertTrue(indexLevelStats.containsKey(indexName)); + + // With param containing level as 'shards', the shard stats are returned + builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder = nodeStats.getIndices() + .toXContent( + builder, + new ToXContent.MapParams(Collections.singletonMap("level", NodeIndicesStats.StatsLevel.SHARDS.getRestName())) + ); + builder.endObject(); + + xContentMap = xContentBuilderToMap(builder); + indicesStatsMap = (LinkedHashMap) xContentMap.get(NodeIndicesStats.StatsLevel.INDICES.getRestName()); + assertFalse(indicesStatsMap.containsKey(NodeIndicesStats.StatsLevel.INDICES.getRestName())); + 
assertTrue(indicesStatsMap.containsKey(NodeIndicesStats.StatsLevel.SHARDS.getRestName())); + + LinkedHashMap shardLevelStats = (LinkedHashMap) indicesStatsMap.get(NodeIndicesStats.StatsLevel.SHARDS.getRestName()); + assertTrue(shardLevelStats.containsKey(indexName)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + /** + * Aggregated behavior - to avoid unnecessary IO in the form of shard-stats when not required, we now honor the levels on the + * individual data nodes instead and pre-compute information as required. + */ + public void testNodeIndicesStatsXContentWithAggregationOnNodes() { + List testLevels = new ArrayList<>(); + + testLevels.add(MockStatsLevel.NULL); + testLevels.add(MockStatsLevel.NODE); + testLevels.add(MockStatsLevel.INDICES); + testLevels.add(MockStatsLevel.SHARDS); + + internalCluster().startNode(); + ensureGreen(); + String indexName = "test1"; + assertAcked( + prepareCreate( + indexName, + clusterService().state().getNodes().getSize(), + Settings.builder().put("number_of_shards", 2).put("number_of_replicas", clusterService().state().getNodes().getSize() - 1) + ) + ); + ensureGreen(); + + testLevels.forEach(testLevel -> { + NodesStatsResponse response; + CommonStatsFlags commonStatsFlags = new CommonStatsFlags(); + commonStatsFlags.setIncludeIndicesStatsByLevel(true); + if (!testLevel.equals(MockStatsLevel.NULL)) { + ArrayList level_arg = new ArrayList<>(); + level_arg.add(testLevel.getRestName()); + + commonStatsFlags.setLevels(level_arg.toArray(new String[0])); + } + response = client().admin().cluster().prepareNodesStats().setIndices(commonStatsFlags).get(); + + NodeStats nodeStats = response.getNodes().get(0); + try { + XContentBuilder builder = XContentFactory.jsonBuilder(); + + builder.startObject(); + + if (!testLevel.equals(MockStatsLevel.SHARDS)) { + final XContentBuilder failedBuilder = builder; + assertThrows( + "Expected shard stats in response for generating [SHARDS] field", + AssertionError.class, + () -> nodeStats.getIndices() + .toXContent( + failedBuilder, + new ToXContent.MapParams( + Collections.singletonMap("level", NodeIndicesStats.StatsLevel.SHARDS.getRestName()) + ) + ) + ); + } else { + builder = nodeStats.getIndices() + .toXContent( + builder, + new ToXContent.MapParams(Collections.singletonMap("level", NodeIndicesStats.StatsLevel.SHARDS.getRestName())) + ); + builder.endObject(); + + Map xContentMap = xContentBuilderToMap(builder); + LinkedHashMap indicesStatsMap = (LinkedHashMap) xContentMap.get(NodeIndicesStats.StatsLevel.INDICES.getRestName()); + LinkedHashMap indicesStats = (LinkedHashMap) indicesStatsMap.get(NodeIndicesStats.StatsLevel.INDICES.getRestName()); + LinkedHashMap shardStats = (LinkedHashMap) indicesStatsMap.get(NodeIndicesStats.StatsLevel.SHARDS.getRestName()); + + assertFalse(shardStats.isEmpty()); + assertNull(indicesStats); + } + + builder = XContentFactory.jsonBuilder(); + builder.startObject(); + + if (!(testLevel.equals(MockStatsLevel.SHARDS) || testLevel.equals(MockStatsLevel.INDICES))) { + final XContentBuilder failedBuilder = builder; + assertThrows( + "Expected shard stats or index stats in response for generating INDICES field", + AssertionError.class, + () -> nodeStats.getIndices() + .toXContent( + failedBuilder, + new ToXContent.MapParams( + Collections.singletonMap("level", NodeIndicesStats.StatsLevel.INDICES.getRestName()) + ) + ) + ); + } else { + builder = nodeStats.getIndices() + .toXContent( + builder, + new ToXContent.MapParams(Collections.singletonMap("level",
NodeIndicesStats.StatsLevel.INDICES.getRestName())) + ); + builder.endObject(); + + Map xContentMap = xContentBuilderToMap(builder); + LinkedHashMap indicesStatsMap = (LinkedHashMap) xContentMap.get(NodeIndicesStats.StatsLevel.INDICES.getRestName()); + LinkedHashMap indicesStats = (LinkedHashMap) indicesStatsMap.get(NodeIndicesStats.StatsLevel.INDICES.getRestName()); + LinkedHashMap shardStats = (LinkedHashMap) indicesStatsMap.get(NodeIndicesStats.StatsLevel.SHARDS.getRestName()); + + switch (testLevel) { + case SHARDS: + case INDICES: + assertNull(shardStats); + assertFalse(indicesStats.isEmpty()); + break; + } + } + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + public void testNodeIndicesStatsUnknownLevelThrowsException() { + MockStatsLevel testLevel = MockStatsLevel.UNKNOWN; + internalCluster().startNode(); + ensureGreen(); + String indexName = "test1"; + assertAcked( + prepareCreate( + indexName, + clusterService().state().getNodes().getSize(), + Settings.builder().put("number_of_shards", 2).put("number_of_replicas", clusterService().state().getNodes().getSize() - 1) + ) + ); + ensureGreen(); + + NodesStatsResponse response; + CommonStatsFlags commonStatsFlags = new CommonStatsFlags(); + commonStatsFlags.setIncludeIndicesStatsByLevel(true); + ArrayList level_arg = new ArrayList<>(); + level_arg.add(testLevel.getRestName()); + + commonStatsFlags.setLevels(level_arg.toArray(new String[0])); + response = client().admin().cluster().prepareNodesStats().setIndices(commonStatsFlags).get(); + + assertTrue(response.hasFailures()); + assertEquals("Level provided is not supported by NodeIndicesStats", response.failures().get(0).getCause().getCause().getMessage()); + } + + private Map xContentBuilderToMap(XContentBuilder xContentBuilder) { + return XContentHelper.convertToMap(BytesReference.bytes(xContentBuilder), true, xContentBuilder.contentType()).v2(); + } + private void assertDocStatusStats() { DocStatusStats docStatusStats = client().admin() .cluster() @@ -273,4 +564,22 @@ private void updateExpectedDocStatusCounter(Exception e) { expectedDocStatusStats.inc(ExceptionsHelper.status(e)); } + private enum MockStatsLevel { + INDICES(NodeIndicesStats.StatsLevel.INDICES.getRestName()), + SHARDS(NodeIndicesStats.StatsLevel.SHARDS.getRestName()), + NODE(NodeIndicesStats.StatsLevel.NODE.getRestName()), + NULL("null"), + UNKNOWN("unknown"); + + private final String restName; + + MockStatsLevel(String restName) { + this.restName = restName; + } + + public String getRestName() { + return restName; + } + } + } diff --git a/server/src/main/java/org/opensearch/action/admin/indices/stats/CommonStatsFlags.java b/server/src/main/java/org/opensearch/action/admin/indices/stats/CommonStatsFlags.java index ca2685e093d3f..04f39d77ce6c8 100644 --- a/server/src/main/java/org/opensearch/action/admin/indices/stats/CommonStatsFlags.java +++ b/server/src/main/java/org/opensearch/action/admin/indices/stats/CommonStatsFlags.java @@ -67,6 +67,7 @@ public class CommonStatsFlags implements Writeable, Cloneable { // Used for metric CACHE_STATS, to determine which caches to report stats for private EnumSet includeCaches = EnumSet.noneOf(CacheType.class); private String[] levels = new String[0]; + private boolean includeIndicesStatsByLevel = false; /** * @param flags flags to set. If no flags are supplied, default flags will be set. 
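A minimal sketch of how a caller opts into the node-side aggregation added by this patch, using only the flag and level APIs exercised in the tests above:

    CommonStatsFlags flags = new CommonStatsFlags();
    flags.setLevels(new String[] { "indices" });   // aggregate per index on each data node
    flags.setIncludeIndicesStatsByLevel(true);     // pre-compute at the requested level before serialization
    NodesStatsResponse response = client().admin().cluster().prepareNodesStats().setIndices(flags).get();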
@@ -100,6 +101,9 @@ public CommonStatsFlags(StreamInput in) throws IOException { includeCaches = in.readEnumSet(CacheType.class); levels = in.readStringArray(); } + if (in.getVersion().onOrAfter(Version.V_3_0_0)) { + includeIndicesStatsByLevel = in.readBoolean(); + } } @Override @@ -124,6 +128,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeEnumSet(includeCaches); out.writeStringArrayNullable(levels); } + if (out.getVersion().onOrAfter(Version.V_3_0_0)) { + out.writeBoolean(includeIndicesStatsByLevel); + } } /** @@ -262,6 +269,14 @@ public boolean includeSegmentFileSizes() { return this.includeSegmentFileSizes; } + public void setIncludeIndicesStatsByLevel(boolean includeIndicesStatsByLevel) { + this.includeIndicesStatsByLevel = includeIndicesStatsByLevel; + } + + public boolean getIncludeIndicesStatsByLevel() { + return this.includeIndicesStatsByLevel; + } + public boolean isSet(Flag flag) { return flags.contains(flag); } diff --git a/server/src/main/java/org/opensearch/indices/IndicesService.java b/server/src/main/java/org/opensearch/indices/IndicesService.java index a78328e24c588..be16d4ea184fa 100644 --- a/server/src/main/java/org/opensearch/indices/IndicesService.java +++ b/server/src/main/java/org/opensearch/indices/IndicesService.java @@ -693,8 +693,12 @@ public NodeIndicesStats stats(CommonStatsFlags flags) { break; } } - - return new NodeIndicesStats(commonStats, statsByShard(this, flags), searchRequestStats); + if (flags.getIncludeIndicesStatsByLevel()) { + NodeIndicesStats.StatsLevel statsLevel = NodeIndicesStats.getAcceptedLevel(flags.getLevels()); + return new NodeIndicesStats(commonStats, statsByShard(this, flags), searchRequestStats, statsLevel); + } else { + return new NodeIndicesStats(commonStats, statsByShard(this, flags), searchRequestStats); + } } Map> statsByShard(final IndicesService indicesService, final CommonStatsFlags flags) { diff --git a/server/src/main/java/org/opensearch/indices/NodeIndicesStats.java b/server/src/main/java/org/opensearch/indices/NodeIndicesStats.java index 35b6fd395ee12..83a759cdb71c5 100644 --- a/server/src/main/java/org/opensearch/indices/NodeIndicesStats.java +++ b/server/src/main/java/org/opensearch/indices/NodeIndicesStats.java @@ -32,6 +32,7 @@ package org.opensearch.indices; +import org.opensearch.Version; import org.opensearch.action.admin.indices.stats.CommonStats; import org.opensearch.action.admin.indices.stats.IndexShardStats; import org.opensearch.action.admin.indices.stats.ShardStats; @@ -63,9 +64,11 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; /** * Global information on indices stats running on a specific node. 
@@ -74,26 +77,27 @@ */ @PublicApi(since = "1.0.0") public class NodeIndicesStats implements Writeable, ToXContentFragment { - private CommonStats stats; - private Map> statsByShard; + protected CommonStats stats; + protected Map statsByIndex; + protected Map> statsByShard; public NodeIndicesStats(StreamInput in) throws IOException { stats = new CommonStats(in); - if (in.readBoolean()) { - int entries = in.readVInt(); - statsByShard = new HashMap<>(); - for (int i = 0; i < entries; i++) { - Index index = new Index(in); - int indexShardListSize = in.readVInt(); - List indexShardStats = new ArrayList<>(indexShardListSize); - for (int j = 0; j < indexShardListSize; j++) { - indexShardStats.add(new IndexShardStats(in)); - } - statsByShard.put(index, indexShardStats); + if (in.getVersion().onOrAfter(Version.V_3_0_0)) { + // contains statsByIndex + if (in.readBoolean()) { + statsByIndex = readStatsByIndex(in); } } + if (in.readBoolean()) { + statsByShard = readStatsByShard(in); + } } + /** + * Without passing the information of the levels to the constructor, we return the Node-level aggregated stats as + * {@link CommonStats} along with a hash-map containing Index to List of Shard Stats. + */ public NodeIndicesStats(CommonStats oldStats, Map> statsByShard, SearchRequestStats searchRequestStats) { // this.stats = stats; this.statsByShard = statsByShard; @@ -112,6 +116,90 @@ public NodeIndicesStats(CommonStats oldStats, Map> } } + /** + * Passing the level information to the nodes allows us to aggregate the stats based on the level passed. This + * allows us to aggregate based on NodeLevel (default - if no level is passed) or Index level if `indices` level is + * passed and finally return the statsByShards map if `shards` level is passed. This allows us to reduce ser/de of + * stats and return only the information that is required while returning to the client. + */ + public NodeIndicesStats( + CommonStats oldStats, + Map> statsByShard, + SearchRequestStats searchRequestStats, + StatsLevel level + ) { + // make a total common stats from old ones and current ones + this.stats = oldStats; + for (List shardStatsList : statsByShard.values()) { + for (IndexShardStats indexShardStats : shardStatsList) { + for (ShardStats shardStats : indexShardStats.getShards()) { + stats.add(shardStats.getStats()); + } + } + } + + if (this.stats.search != null) { + this.stats.search.setSearchRequestStats(searchRequestStats); + } + + if (level != null) { + switch (level) { + case INDICES: + this.statsByIndex = createStatsByIndex(statsByShard); + break; + case SHARDS: + this.statsByShard = statsByShard; + break; + } + } + } + + /** + * By default, the levels passed from the transport action will be a list of strings, since NodeIndicesStats can + * only aggregate on one level, we pick the first accepted level else we ignore if no known level is passed. Level is + * selected based on enum defined in {@link StatsLevel} + * + * Note - we are picking the first level as multiple levels are not supported in the previous versions. + * @param levels - levels sent in the request. 
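+ * e.g. getAcceptedLevel(new String[] { "shards", "indices" }) resolves to StatsLevel.SHARDS, while an
+ * unrecognised first entry fails with an IllegalArgumentException (illustrative inputs).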
+ * + * @return Corresponding identified enum {@link StatsLevel} + */ + public static StatsLevel getAcceptedLevel(String[] levels) { + if (levels != null && levels.length > 0) { + Optional level = Arrays.stream(StatsLevel.values()) + .filter(field -> field.getRestName().equals(levels[0])) + .findFirst(); + return level.orElseThrow(() -> new IllegalArgumentException("Level provided is not supported by NodeIndicesStats")); + } + return null; + } + + private Map readStatsByIndex(StreamInput in) throws IOException { + Map statsByIndex = new HashMap<>(); + int indexEntries = in.readVInt(); + for (int i = 0; i < indexEntries; i++) { + Index index = new Index(in); + CommonStats commonStats = new CommonStats(in); + statsByIndex.put(index, commonStats); + } + return statsByIndex; + } + + private Map> readStatsByShard(StreamInput in) throws IOException { + Map> statsByShard = new HashMap<>(); + int entries = in.readVInt(); + for (int i = 0; i < entries; i++) { + Index index = new Index(in); + int indexShardListSize = in.readVInt(); + List indexShardStats = new ArrayList<>(indexShardListSize); + for (int j = 0; j < indexShardListSize; j++) { + indexShardStats.add(new IndexShardStats(in)); + } + statsByShard.put(index, indexShardStats); + } + return statsByShard; + } + @Nullable public StoreStats getStore() { return stats.getStore(); @@ -195,7 +283,31 @@ public RecoveryStats getRecoveryStats() { @Override public void writeTo(StreamOutput out) throws IOException { stats.writeTo(out); + + if (out.getVersion().onOrAfter(Version.V_3_0_0)) { + out.writeBoolean(statsByIndex != null); + if (statsByIndex != null) { + writeStatsByIndex(out); + } + } + out.writeBoolean(statsByShard != null); + if (statsByShard != null) { + writeStatsByShards(out); + } + } + + private void writeStatsByIndex(StreamOutput out) throws IOException { + if (statsByIndex != null) { + out.writeVInt(statsByIndex.size()); + for (Map.Entry entry : statsByIndex.entrySet()) { + entry.getKey().writeTo(out); + entry.getValue().writeTo(out); + } + } + } + + private void writeStatsByShards(StreamOutput out) throws IOException { if (statsByShard != null) { out.writeVInt(statsByShard.size()); for (Map.Entry> entry : statsByShard.entrySet()) { @@ -210,29 +322,46 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - final String level = params.param("level", "node"); - final boolean isLevelValid = "indices".equalsIgnoreCase(level) - || "node".equalsIgnoreCase(level) - || "shards".equalsIgnoreCase(level); + final String level = params.param("level", StatsLevel.NODE.getRestName()); + final boolean isLevelValid = StatsLevel.NODE.getRestName().equalsIgnoreCase(level) + || StatsLevel.INDICES.getRestName().equalsIgnoreCase(level) + || StatsLevel.SHARDS.getRestName().equalsIgnoreCase(level); if (!isLevelValid) { - throw new IllegalArgumentException("level parameter must be one of [indices] or [node] or [shards] but was [" + level + "]"); + throw new IllegalArgumentException( + "level parameter must be one of [" + + StatsLevel.INDICES.getRestName() + + "] or [" + + StatsLevel.NODE.getRestName() + + "] or [" + + StatsLevel.SHARDS.getRestName() + + "] but was [" + + level + + "]" + ); } // "node" level - builder.startObject(Fields.INDICES); + builder.startObject(StatsLevel.INDICES.getRestName()); stats.toXContent(builder, params); - if ("indices".equals(level)) { - Map indexStats = createStatsByIndex(); - builder.startObject(Fields.INDICES); - 
for (Map.Entry entry : indexStats.entrySet()) { + if (StatsLevel.INDICES.getRestName().equals(level)) { + assert statsByIndex != null || statsByShard != null : "Expected shard stats or index stats in response for generating [" + + StatsLevel.INDICES + + "] field"; + if (statsByIndex == null) { + statsByIndex = createStatsByIndex(statsByShard); + } + + builder.startObject(StatsLevel.INDICES.getRestName()); + for (Map.Entry entry : statsByIndex.entrySet()) { builder.startObject(entry.getKey().getName()); entry.getValue().toXContent(builder, params); builder.endObject(); } builder.endObject(); - } else if ("shards".equals(level)) { - builder.startObject("shards"); + } else if (StatsLevel.SHARDS.getRestName().equals(level)) { + builder.startObject(StatsLevel.SHARDS.getRestName()); + assert statsByShard != null : "Expected shard stats in response for generating [" + StatsLevel.SHARDS + "] field"; for (Map.Entry> entry : statsByShard.entrySet()) { builder.startArray(entry.getKey().getName()); for (IndexShardStats indexShardStats : entry.getValue()) { @@ -251,7 +380,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - private Map createStatsByIndex() { + private Map createStatsByIndex(Map> statsByShard) { Map statsMap = new HashMap<>(); for (Map.Entry> entry : statsByShard.entrySet()) { if (!statsMap.containsKey(entry.getKey())) { @@ -281,7 +410,21 @@ public List getShardStats(Index index) { * * @opensearch.internal */ - static final class Fields { - static final String INDICES = "indices"; + @PublicApi(since = "3.0.0") + public enum StatsLevel { + INDICES("indices"), + SHARDS("shards"), + NODE("node"); + + private final String restName; + + StatsLevel(String restName) { + this.restName = restName; + } + + public String getRestName() { + return restName; + } + } } diff --git a/server/src/main/java/org/opensearch/rest/action/admin/cluster/RestNodesStatsAction.java b/server/src/main/java/org/opensearch/rest/action/admin/cluster/RestNodesStatsAction.java index ed9c0b171aa56..0119731e4a0d7 100644 --- a/server/src/main/java/org/opensearch/rest/action/admin/cluster/RestNodesStatsAction.java +++ b/server/src/main/java/org/opensearch/rest/action/admin/cluster/RestNodesStatsAction.java @@ -233,6 +233,7 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC String[] levels = Strings.splitStringByCommaToArray(request.param("level")); nodesStatsRequest.indices().setLevels(levels); nodesStatsRequest.setIncludeDiscoveryNodes(false); + nodesStatsRequest.indices().setIncludeIndicesStatsByLevel(true); return channel -> client.admin().cluster().nodesStats(nodesStatsRequest, new NodesResponseRestListener<>(channel)); } diff --git a/server/src/main/java/org/opensearch/rest/action/cat/RestNodesAction.java b/server/src/main/java/org/opensearch/rest/action/cat/RestNodesAction.java index 0330fe627ccd0..1aa40b50290cd 100644 --- a/server/src/main/java/org/opensearch/rest/action/cat/RestNodesAction.java +++ b/server/src/main/java/org/opensearch/rest/action/cat/RestNodesAction.java @@ -148,6 +148,7 @@ public void processResponse(final NodesInfoResponse nodesInfoResponse) { NodesStatsRequest.Metric.PROCESS.metricName(), NodesStatsRequest.Metric.SCRIPT.metricName() ); + nodesStatsRequest.indices().setIncludeIndicesStatsByLevel(true); client.admin().cluster().nodesStats(nodesStatsRequest, new RestResponseListener(channel) { @Override public RestResponse buildResponse(NodesStatsResponse nodesStatsResponse) throws Exception { diff --git 
a/server/src/test/java/org/opensearch/action/admin/cluster/node/stats/NodeStatsTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/stats/NodeStatsTests.java index f7bc96bdfe769..a0225a0bf6193 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/stats/NodeStatsTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/stats/NodeStatsTests.java @@ -32,13 +32,19 @@ package org.opensearch.action.admin.cluster.node.stats; +import org.opensearch.Version; import org.opensearch.action.admin.indices.stats.CommonStats; import org.opensearch.action.admin.indices.stats.CommonStatsFlags; +import org.opensearch.action.admin.indices.stats.IndexShardStats; +import org.opensearch.action.admin.indices.stats.ShardStats; import org.opensearch.action.search.SearchRequestStats; import org.opensearch.cluster.coordination.PendingClusterStateStats; import org.opensearch.cluster.coordination.PersistedStateStats; import org.opensearch.cluster.coordination.PublishClusterStateStats; import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.cluster.routing.ShardRoutingState; +import org.opensearch.cluster.routing.TestShardRouting; import org.opensearch.cluster.routing.WeightedRoutingStats; import org.opensearch.cluster.service.ClusterManagerThrottlingStats; import org.opensearch.cluster.service.ClusterStateStats; @@ -52,17 +58,31 @@ import org.opensearch.common.metrics.OperationStats; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.indices.breaker.AllCircuitBreakerStats; import org.opensearch.core.indices.breaker.CircuitBreakerStats; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.discovery.DiscoveryStats; import org.opensearch.gateway.remote.RemotePersistenceStats; import org.opensearch.http.HttpStats; import org.opensearch.index.ReplicationStats; import org.opensearch.index.SegmentReplicationRejectionStats; +import org.opensearch.index.cache.query.QueryCacheStats; +import org.opensearch.index.engine.SegmentsStats; +import org.opensearch.index.fielddata.FieldDataStats; +import org.opensearch.index.flush.FlushStats; import org.opensearch.index.remote.RemoteSegmentStats; import org.opensearch.index.remote.RemoteTranslogTransferTracker; +import org.opensearch.index.shard.DocsStats; +import org.opensearch.index.shard.IndexingStats; +import org.opensearch.index.shard.ShardPath; +import org.opensearch.index.store.StoreStats; import org.opensearch.index.translog.RemoteTranslogStats; import org.opensearch.indices.NodeIndicesStats; import org.opensearch.ingest.IngestStats; @@ -82,17 +102,20 @@ import org.opensearch.ratelimitting.admissioncontrol.stats.AdmissionControllerStats; import org.opensearch.script.ScriptCacheStats; import org.opensearch.script.ScriptStats; +import org.opensearch.search.suggest.completion.CompletionStats; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.VersionUtils; import org.opensearch.threadpool.ThreadPoolStats; import org.opensearch.transport.TransportStats; import 
java.io.IOException; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.TreeMap; @@ -1065,4 +1088,381 @@ private static RemoteTranslogTransferTracker.Stats getRandomRemoteTranslogTransf private OperationStats getPipelineStats(List pipelineStats, String id) { return pipelineStats.stream().filter(p1 -> p1.getPipelineId().equals(id)).findFirst().map(p2 -> p2.getStats()).orElse(null); } + + public static class MockNodeIndicesStats extends NodeIndicesStats { + + public MockNodeIndicesStats(StreamInput in) throws IOException { + super(in); + } + + public MockNodeIndicesStats( + CommonStats oldStats, + Map> statsByShard, + SearchRequestStats searchRequestStats + ) { + super(oldStats, statsByShard, searchRequestStats); + } + + public MockNodeIndicesStats( + CommonStats oldStats, + Map> statsByShard, + SearchRequestStats searchRequestStats, + StatsLevel level + ) { + super(oldStats, statsByShard, searchRequestStats, level); + } + + public CommonStats getStats() { + return this.stats; + } + + public Map getStatsByIndex() { + return this.statsByIndex; + } + + public Map> getStatsByShard() { + return this.statsByShard; + } + } + + public void testOldVersionNodes() throws IOException { + long numDocs = randomLongBetween(0, 10000); + long numDeletedDocs = randomLongBetween(0, 100); + CommonStats commonStats = new CommonStats(CommonStatsFlags.NONE); + + commonStats.docs = new DocsStats(numDocs, numDeletedDocs, 0); + commonStats.store = new StoreStats(100, 0L); + commonStats.indexing = new IndexingStats(); + DocsStats hostDocStats = new DocsStats(numDocs, numDeletedDocs, 0); + + CommonStatsFlags commonStatsFlags = new CommonStatsFlags(); + commonStatsFlags.clear(); + commonStatsFlags.set(CommonStatsFlags.Flag.Docs, true); + commonStatsFlags.set(CommonStatsFlags.Flag.Store, true); + commonStatsFlags.set(CommonStatsFlags.Flag.Indexing, true); + + Index newIndex = new Index("index", "_na_"); + + MockNodeIndicesStats mockNodeIndicesStats = generateMockNodeIndicesStats(commonStats, newIndex, commonStatsFlags, null); + + // To test out scenario when the incoming node stats response is from a node with an older ES Version. 
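+ // A 2.13 stream predates the statsByIndex section (which writeTo gates on Version.V_3_0_0), so only
+ // the legacy statsByShard section is expected to survive the round trip below.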
+ try (BytesStreamOutput out = new BytesStreamOutput()) { + out.setVersion(Version.V_2_13_0); + mockNodeIndicesStats.writeTo(out); + try (StreamInput in = out.bytes().streamInput()) { + in.setVersion(Version.V_2_13_0); + MockNodeIndicesStats newNodeIndicesStats = new MockNodeIndicesStats(in); + + List incomingIndexStats = newNodeIndicesStats.getStatsByShard().get(newIndex); + incomingIndexStats.forEach(indexShardStats -> { + ShardStats shardStats = Arrays.stream(indexShardStats.getShards()).findFirst().get(); + DocsStats incomingDocStats = shardStats.getStats().docs; + + assertEquals(incomingDocStats.getCount(), hostDocStats.getCount()); + assertEquals(incomingDocStats.getTotalSizeInBytes(), hostDocStats.getTotalSizeInBytes()); + assertEquals(incomingDocStats.getAverageSizeInBytes(), hostDocStats.getAverageSizeInBytes()); + assertEquals(incomingDocStats.getDeleted(), hostDocStats.getDeleted()); + }); + } + } + } + + public void testNodeIndicesStatsSerialization() throws IOException { + long numDocs = randomLongBetween(0, 10000); + long numDeletedDocs = randomLongBetween(0, 100); + List levelParams = new ArrayList<>(); + levelParams.add(NodeIndicesStats.StatsLevel.INDICES); + levelParams.add(NodeIndicesStats.StatsLevel.SHARDS); + levelParams.add(NodeIndicesStats.StatsLevel.NODE); + CommonStats commonStats = new CommonStats(CommonStatsFlags.NONE); + + commonStats.docs = new DocsStats(numDocs, numDeletedDocs, 0); + commonStats.store = new StoreStats(100, 0L); + commonStats.indexing = new IndexingStats(); + + CommonStatsFlags commonStatsFlags = new CommonStatsFlags(); + commonStatsFlags.clear(); + commonStatsFlags.set(CommonStatsFlags.Flag.Docs, true); + commonStatsFlags.set(CommonStatsFlags.Flag.Store, true); + commonStatsFlags.set(CommonStatsFlags.Flag.Indexing, true); + commonStatsFlags.setIncludeIndicesStatsByLevel(true); + + levelParams.forEach(level -> { + Index newIndex = new Index("index", "_na_"); + + MockNodeIndicesStats mockNodeIndicesStats = generateMockNodeIndicesStats(commonStats, newIndex, commonStatsFlags, level); + + // Round-trip the stats over the wire at the current version to verify that the level-based aggregation survives serialization.
+ try (BytesStreamOutput out = new BytesStreamOutput()) { + mockNodeIndicesStats.writeTo(out); + try (StreamInput in = out.bytes().streamInput()) { + MockNodeIndicesStats newNodeIndicesStats = new MockNodeIndicesStats(in); + switch (level) { + case NODE: + assertNull(newNodeIndicesStats.getStatsByIndex()); + assertNull(newNodeIndicesStats.getStatsByShard()); + break; + case INDICES: + assertNull(newNodeIndicesStats.getStatsByShard()); + assertNotNull(newNodeIndicesStats.getStatsByIndex()); + break; + case SHARDS: + assertNull(newNodeIndicesStats.getStatsByIndex()); + assertNotNull(newNodeIndicesStats.getStatsByShard()); + break; + } + } + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + public void testNodeIndicesStatsToXContent() { + long numDocs = randomLongBetween(0, 10000); + long numDeletedDocs = randomLongBetween(0, 100); + List levelParams = new ArrayList<>(); + levelParams.add(NodeIndicesStats.StatsLevel.INDICES); + levelParams.add(NodeIndicesStats.StatsLevel.SHARDS); + levelParams.add(NodeIndicesStats.StatsLevel.NODE); + CommonStats commonStats = new CommonStats(CommonStatsFlags.NONE); + + commonStats.docs = new DocsStats(numDocs, numDeletedDocs, 0); + commonStats.store = new StoreStats(100, 0L); + commonStats.indexing = new IndexingStats(); + + CommonStatsFlags commonStatsFlags = new CommonStatsFlags(); + commonStatsFlags.clear(); + commonStatsFlags.set(CommonStatsFlags.Flag.Docs, true); + commonStatsFlags.set(CommonStatsFlags.Flag.Store, true); + commonStatsFlags.set(CommonStatsFlags.Flag.Indexing, true); + commonStatsFlags.setIncludeIndicesStatsByLevel(true); + + levelParams.forEach(level -> { + + Index newIndex = new Index("index", "_na_"); + + MockNodeIndicesStats mockNodeIndicesStats = generateMockNodeIndicesStats(commonStats, newIndex, commonStatsFlags, level); + + XContentBuilder builder = null; + try { + builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder = mockNodeIndicesStats.toXContent( + builder, + new ToXContent.MapParams(Collections.singletonMap("level", level.getRestName())) + ); + builder.endObject(); + + Map xContentMap = xContentBuilderToMap(builder); + LinkedHashMap indicesStatsMap = (LinkedHashMap) xContentMap.get(NodeIndicesStats.StatsLevel.INDICES.getRestName()); + + switch (level) { + case NODE: + assertFalse(indicesStatsMap.containsKey(NodeIndicesStats.StatsLevel.INDICES.getRestName())); + assertFalse(indicesStatsMap.containsKey(NodeIndicesStats.StatsLevel.SHARDS.getRestName())); + break; + case INDICES: + assertTrue(indicesStatsMap.containsKey(NodeIndicesStats.StatsLevel.INDICES.getRestName())); + assertFalse(indicesStatsMap.containsKey(NodeIndicesStats.StatsLevel.SHARDS.getRestName())); + break; + case SHARDS: + assertFalse(indicesStatsMap.containsKey(NodeIndicesStats.StatsLevel.INDICES.getRestName())); + assertTrue(indicesStatsMap.containsKey(NodeIndicesStats.StatsLevel.SHARDS.getRestName())); + break; + } + } catch (IOException e) { + throw new RuntimeException(e); + } + + }); + } + + public void testNodeIndicesStatsWithAndWithoutAggregations() throws IOException { + + CommonStatsFlags commonStatsFlags = new CommonStatsFlags( + CommonStatsFlags.Flag.Docs, + CommonStatsFlags.Flag.Store, + CommonStatsFlags.Flag.Indexing, + CommonStatsFlags.Flag.Completion, + CommonStatsFlags.Flag.Flush, + CommonStatsFlags.Flag.FieldData, + CommonStatsFlags.Flag.QueryCache, + CommonStatsFlags.Flag.Segments + ); + + int numberOfIndexes = randomIntBetween(1, 3); + List indexList = new ArrayList<>(); + for (int i = 0; i < 
numberOfIndexes; i++) { + Index index = new Index("test-index-" + i, "_na_"); + indexList.add(index); + } + + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + HashMap> statsByShards = createRandomShardByStats(indexList); + + final MockNodeIndicesStats nonAggregatedNodeIndicesStats = new MockNodeIndicesStats( + new CommonStats(commonStatsFlags), + statsByShards, + new SearchRequestStats(clusterSettings) + ); + + commonStatsFlags.setIncludeIndicesStatsByLevel(true); + + Arrays.stream(NodeIndicesStats.StatsLevel.values()).forEach(level -> { + MockNodeIndicesStats aggregatedNodeIndicesStats = new MockNodeIndicesStats( + new CommonStats(commonStatsFlags), + statsByShards, + new SearchRequestStats(clusterSettings), + level + ); + + XContentBuilder nonAggregatedBuilder = null; + XContentBuilder aggregatedBuilder = null; + try { + nonAggregatedBuilder = XContentFactory.jsonBuilder(); + nonAggregatedBuilder.startObject(); + nonAggregatedBuilder = nonAggregatedNodeIndicesStats.toXContent( + nonAggregatedBuilder, + new ToXContent.MapParams(Collections.singletonMap("level", level.getRestName())) + ); + nonAggregatedBuilder.endObject(); + Map nonAggregatedContentMap = xContentBuilderToMap(nonAggregatedBuilder); + + aggregatedBuilder = XContentFactory.jsonBuilder(); + aggregatedBuilder.startObject(); + aggregatedBuilder = aggregatedNodeIndicesStats.toXContent( + aggregatedBuilder, + new ToXContent.MapParams(Collections.singletonMap("level", level.getRestName())) + ); + aggregatedBuilder.endObject(); + Map aggregatedContentMap = xContentBuilderToMap(aggregatedBuilder); + + assertEquals(aggregatedContentMap, nonAggregatedContentMap); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + private CommonStats createRandomCommonStats() { + CommonStats commonStats = new CommonStats(CommonStatsFlags.NONE); + commonStats.docs = new DocsStats(randomLongBetween(0, 10000), randomLongBetween(0, 100), randomLongBetween(0, 1000)); + commonStats.store = new StoreStats(randomLongBetween(0, 100), randomLongBetween(0, 1000)); + commonStats.indexing = new IndexingStats(); + commonStats.completion = new CompletionStats(); + commonStats.flush = new FlushStats(randomLongBetween(0, 100), randomLongBetween(0, 100), randomLongBetween(0, 100)); + commonStats.fieldData = new FieldDataStats(randomLongBetween(0, 100), randomLongBetween(0, 100), null); + commonStats.queryCache = new QueryCacheStats( + randomLongBetween(0, 100), + randomLongBetween(0, 100), + randomLongBetween(0, 100), + randomLongBetween(0, 100), + randomLongBetween(0, 100) + ); + commonStats.segments = new SegmentsStats(); + + return commonStats; + } + + private HashMap> createRandomShardByStats(List indexes) { + DiscoveryNode localNode = new DiscoveryNode("node", buildNewFakeTransportAddress(), Version.CURRENT); + HashMap> statsByShards = new HashMap<>(); + indexes.forEach(index -> { + List indexShardStatsList = new ArrayList<>(); + + int numberOfShards = randomIntBetween(1, 4); + for (int i = 0; i < numberOfShards; i++) { + ShardRoutingState shardRoutingState = ShardRoutingState.fromValue((byte) randomIntBetween(2, 3)); + + ShardRouting shardRouting = TestShardRouting.newShardRouting( + index.getName(), + i, + localNode.getId(), + randomBoolean(), + shardRoutingState + ); + + Path path = createTempDir().resolve("indices") + .resolve(shardRouting.shardId().getIndex().getUUID()) + .resolve(String.valueOf(shardRouting.shardId().id())); + + ShardStats shardStats = new 
ShardStats( + shardRouting, + new ShardPath(false, path, path, shardRouting.shardId()), + createRandomCommonStats(), + null, + null, + null + ); + List shardStatsList = new ArrayList<>(); + shardStatsList.add(shardStats); + IndexShardStats indexShardStats = new IndexShardStats(shardRouting.shardId(), shardStatsList.toArray(new ShardStats[0])); + indexShardStatsList.add(indexShardStats); + } + statsByShards.put(index, indexShardStatsList); + }); + + return statsByShards; + } + + private Map xContentBuilderToMap(XContentBuilder xContentBuilder) { + return XContentHelper.convertToMap(BytesReference.bytes(xContentBuilder), true, xContentBuilder.contentType()).v2(); + } + + public MockNodeIndicesStats generateMockNodeIndicesStats( + CommonStats commonStats, + Index index, + CommonStatsFlags commonStatsFlags, + NodeIndicesStats.StatsLevel level + ) { + DiscoveryNode localNode = new DiscoveryNode("local", buildNewFakeTransportAddress(), Version.CURRENT); + Map> statsByShard = new HashMap<>(); + List indexShardStatsList = new ArrayList<>(); + Index statsIndex = null; + for (int i = 0; i < 2; i++) { + ShardRoutingState shardRoutingState = ShardRoutingState.fromValue((byte) randomIntBetween(2, 3)); + ShardRouting shardRouting = TestShardRouting.newShardRouting( + index.getName(), + i, + localNode.getId(), + randomBoolean(), + shardRoutingState + ); + + if (statsIndex == null) { + statsIndex = shardRouting.shardId().getIndex(); + } + + Path path = createTempDir().resolve("indices") + .resolve(shardRouting.shardId().getIndex().getUUID()) + .resolve(String.valueOf(shardRouting.shardId().id())); + + ShardStats shardStats = new ShardStats( + shardRouting, + new ShardPath(false, path, path, shardRouting.shardId()), + commonStats, + null, + null, + null + ); + IndexShardStats indexShardStats = new IndexShardStats(shardRouting.shardId(), new ShardStats[] { shardStats }); + indexShardStatsList.add(indexShardStats); + } + + statsByShard.put(statsIndex, indexShardStatsList); + + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + + if (commonStatsFlags.getIncludeIndicesStatsByLevel()) { + return new MockNodeIndicesStats( + new CommonStats(commonStatsFlags), + statsByShard, + new SearchRequestStats(clusterSettings), + level + ); + } else { + return new MockNodeIndicesStats(new CommonStats(commonStatsFlags), statsByShard, new SearchRequestStats(clusterSettings)); + } + } } From 0a10aca350c6fdcf0861f372683a2ee1f8f4b783 Mon Sep 17 00:00:00 2001 From: Andriy Redko Date: Thu, 29 Aug 2024 15:37:37 -0400 Subject: [PATCH 6/7] Enhance OpenSearch APIs annotation processor with OpenSearch version validation (#15502) Signed-off-by: Andriy Redko --- .../processor/ApiAnnotationProcessor.java | 61 +++++++++++++++++++ .../ApiAnnotationProcessorTests.java | 31 ++++++++++ .../PublicApiAnnotatedUnparseable.java | 16 +++++ .../PublicApiWithDeprecatedApiMethod.java | 20 ++++++ 4 files changed, 128 insertions(+) create mode 100644 libs/common/src/test/resources/org/opensearch/common/annotation/processor/PublicApiAnnotatedUnparseable.java create mode 100644 libs/common/src/test/resources/org/opensearch/common/annotation/processor/PublicApiWithDeprecatedApiMethod.java diff --git a/libs/common/src/main/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessor.java b/libs/common/src/main/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessor.java index 569f48a8465f3..94ec0db3a9712 100644 --- 
a/libs/common/src/main/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessor.java +++ b/libs/common/src/main/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessor.java @@ -59,6 +59,7 @@ public class ApiAnnotationProcessor extends AbstractProcessor { private static final String OPENSEARCH_PACKAGE = "org.opensearch"; private final Set reported = new HashSet<>(); + private final Set validated = new HashSet<>(); private final Set processed = new HashSet<>(); private Kind reportFailureAs = Kind.ERROR; @@ -85,6 +86,8 @@ public boolean process(Set annotations, RoundEnvironment ); for (var element : elements) { + validate(element); + if (!checkPackage(element)) { continue; } @@ -100,6 +103,64 @@ public boolean process(Set annotations, RoundEnvironment return false; } + private void validate(Element element) { + // The element was validated already + if (validated.contains(element)) { + return; + } + + validated.add(element); + + final PublicApi publicApi = element.getAnnotation(PublicApi.class); + if (publicApi != null) { + if (!validateVersion(publicApi.since())) { + processingEnv.getMessager() + .printMessage( + reportFailureAs, + "The type " + element + " has @PublicApi annotation with unparseable OpenSearch version: " + publicApi.since() + ); + } + } + + final DeprecatedApi deprecatedApi = element.getAnnotation(DeprecatedApi.class); + if (deprecatedApi != null) { + if (!validateVersion(deprecatedApi.since())) { + processingEnv.getMessager() + .printMessage( + reportFailureAs, + "The type " + + element + + " has @DeprecatedApi annotation with unparseable OpenSearch version: " + + deprecatedApi.since() + ); + } + } + } + + private boolean validateVersion(String version) { + String[] parts = version.split("[.-]"); + if (parts.length < 3 || parts.length > 4) { + return false; + } + + int major = Integer.parseInt(parts[0]); + if (major > 3 || major < 0) { + return false; + } + + int minor = Integer.parseInt(parts[1]); + if (minor < 0) { + return false; + } + + int patch = Integer.parseInt(parts[2]); + if (patch < 0) { + return false; + } + + return true; + } + /** * Check top level executable element * @param executable top level executable element diff --git a/libs/common/src/test/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessorTests.java b/libs/common/src/test/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessorTests.java index 52162e3df0c1c..716dcc3b9015f 100644 --- a/libs/common/src/test/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessorTests.java +++ b/libs/common/src/test/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessorTests.java @@ -486,4 +486,35 @@ public void testPublicApiConstructorAnnotatedInternalApi() { assertThat(failure.diagnotics(), not(hasItem(matching(Diagnostic.Kind.ERROR)))); } + + public void testPublicApiUnparseableVersion() { + final CompilerResult result = compile("PublicApiAnnotatedUnparseable.java"); + assertThat(result, instanceOf(Failure.class)); + + final Failure failure = (Failure) result; + assertThat(failure.diagnotics(), hasSize(3)); + + assertThat( + failure.diagnotics(), + hasItem( + matching( + Diagnostic.Kind.ERROR, + containsString( + "The type org.opensearch.common.annotation.processor.PublicApiAnnotatedUnparseable has @PublicApi annotation with unparseable OpenSearch version: 2.x" + ) + ) + ) + ); + } + + public void testPublicApiWithDeprecatedApiMethod() { + final CompilerResult result = compile("PublicApiWithDeprecatedApiMethod.java"); + 
assertThat(result, instanceOf(Failure.class)); + + final Failure failure = (Failure) result; + assertThat(failure.diagnotics(), hasSize(2)); + + assertThat(failure.diagnotics(), not(hasItem(matching(Diagnostic.Kind.ERROR)))); + } + } diff --git a/libs/common/src/test/resources/org/opensearch/common/annotation/processor/PublicApiAnnotatedUnparseable.java b/libs/common/src/test/resources/org/opensearch/common/annotation/processor/PublicApiAnnotatedUnparseable.java new file mode 100644 index 0000000000000..44779450c9fd1 --- /dev/null +++ b/libs/common/src/test/resources/org/opensearch/common/annotation/processor/PublicApiAnnotatedUnparseable.java @@ -0,0 +1,16 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.annotation.processor; + +import org.opensearch.common.annotation.PublicApi; + +@PublicApi(since = "2.x") +public class PublicApiAnnotatedUnparseable { + +} diff --git a/libs/common/src/test/resources/org/opensearch/common/annotation/processor/PublicApiWithDeprecatedApiMethod.java b/libs/common/src/test/resources/org/opensearch/common/annotation/processor/PublicApiWithDeprecatedApiMethod.java new file mode 100644 index 0000000000000..3cb28d3360830 --- /dev/null +++ b/libs/common/src/test/resources/org/opensearch/common/annotation/processor/PublicApiWithDeprecatedApiMethod.java @@ -0,0 +1,20 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.annotation.processor; + +import org.opensearch.common.annotation.DeprecatedApi; +import org.opensearch.common.annotation.PublicApi; + +@PublicApi(since = "1.0.0") +public class PublicApiWithDeprecatedApiMethod { + @DeprecatedApi(since = "0.1.0") + public void method() { + + } +} From c11d27575b87bd41f9e6512f48591ec0443eaf16 Mon Sep 17 00:00:00 2001 From: Kaushal Kumar Date: Thu, 29 Aug 2024 16:49:28 -0700 Subject: [PATCH 7/7] Add query group level rejection logic (#15428) * add rejection listener Signed-off-by: Kaushal Kumar * add rejection listener unit test Signed-off-by: Kaushal Kumar * add rejection logic for shard level requests Signed-off-by: Kaushal Kumar * add changelog entry Signed-off-by: Kaushal Kumar * apply spotless check Signed-off-by: Kaushal Kumar * remove unused files and fix precommit Signed-off-by: Kaushal Kumar * refactor code Signed-off-by: Kaushal Kumar * add package info file Signed-off-by: Kaushal Kumar * remove unused method from QueryGroupService stub Signed-off-by: Kaushal Kumar --------- Signed-off-by: Kaushal Kumar --- CHANGELOG.md | 1 + .../main/java/org/opensearch/node/Node.java | 18 ++++++- .../org/opensearch/wlm/QueryGroupService.java | 32 +++++++++++ ...orkloadManagementTransportInterceptor.java | 12 +++-- ...roupRequestRejectionOperationListener.java | 39 ++++++++++++++ .../wlm/listeners/package-info.java | 12 +++++ ...adManagementTransportInterceptorTests.java | 2 +- ...anagementTransportRequestHandlerTests.java | 20 +++++-- ...equestRejectionOperationListenerTests.java | 53 +++++++++++++++++++ 9 files changed, 180 insertions(+), 9 deletions(-) create mode 100644 server/src/main/java/org/opensearch/wlm/QueryGroupService.java create mode 100644 
server/src/main/java/org/opensearch/wlm/listeners/QueryGroupRequestRejectionOperationListener.java create mode 100644 server/src/main/java/org/opensearch/wlm/listeners/package-info.java create mode 100644 server/src/test/java/org/opensearch/wlm/listeners/QueryGroupRequestRejectionOperationListenerTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index cbfde6a1c1a80..fe1cee57279d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - [Workload Management] Add query group stats constructs ([#15343](https://github.com/opensearch-project/OpenSearch/pull/15343))) - Add runAs to Subject interface and introduce IdentityAwarePlugin extension point ([#14630](https://github.com/opensearch-project/OpenSearch/pull/14630)) - Optimize NodeIndicesStats output behind flag ([#14454](https://github.com/opensearch-project/OpenSearch/pull/14454)) +- [Workload Management] Add rejection logic for co-ordinator and shard level requests ([#15428](https://github.com/opensearch-project/OpenSearch/pull/15428))) ### Dependencies - Bump `netty` from 4.1.111.Final to 4.1.112.Final ([#15081](https://github.com/opensearch-project/OpenSearch/pull/15081)) diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 9c7dfe8850b85..ea656af6110e5 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -267,7 +267,9 @@ import org.opensearch.transport.TransportService; import org.opensearch.usage.UsageService; import org.opensearch.watcher.ResourceWatcherService; +import org.opensearch.wlm.QueryGroupService; import org.opensearch.wlm.WorkloadManagementTransportInterceptor; +import org.opensearch.wlm.listeners.QueryGroupRequestRejectionOperationListener; import javax.net.ssl.SNIHostName; @@ -1017,11 +1019,22 @@ protected Node( List identityAwarePlugins = pluginsService.filterPlugins(IdentityAwarePlugin.class); identityService.initializeIdentityAwarePlugins(identityAwarePlugins); + final QueryGroupRequestRejectionOperationListener queryGroupRequestRejectionListener = + new QueryGroupRequestRejectionOperationListener( + new QueryGroupService(), // We will need to replace this with actual instance of the queryGroupService + threadPool + ); + // register all standard SearchRequestOperationsCompositeListenerFactory to the SearchRequestOperationsCompositeListenerFactory final SearchRequestOperationsCompositeListenerFactory searchRequestOperationsCompositeListenerFactory = new SearchRequestOperationsCompositeListenerFactory( Stream.concat( - Stream.of(searchRequestStats, searchRequestSlowLog, searchTaskRequestOperationsListener), + Stream.of( + searchRequestStats, + searchRequestSlowLog, + searchTaskRequestOperationsListener, + queryGroupRequestRejectionListener + ), pluginComponents.stream() .filter(p -> p instanceof SearchRequestOperationsListener) .map(p -> (SearchRequestOperationsListener) p) @@ -1071,7 +1084,8 @@ protected Node( ); WorkloadManagementTransportInterceptor workloadManagementTransportInterceptor = new WorkloadManagementTransportInterceptor( - threadPool + threadPool, + new QueryGroupService() // We will need to replace this with actual implementation ); final Collection secureSettingsFactories = pluginsService.filterPlugins(Plugin.class) diff --git a/server/src/main/java/org/opensearch/wlm/QueryGroupService.java 
b/server/src/main/java/org/opensearch/wlm/QueryGroupService.java new file mode 100644 index 0000000000000..97c4e5169b4ed --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/QueryGroupService.java @@ -0,0 +1,32 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; + +/** + * This is a stub at this point in time and will be replaced by an actual one in a couple of days + */ +public class QueryGroupService { + /** + * + * @param queryGroupId query group identifier + */ + public void rejectIfNeeded(String queryGroupId) { + if (queryGroupId == null) return; + boolean reject = false; + final StringBuilder reason = new StringBuilder(); + // TODO: At this point this is a dummy and we need to decide whether to cancel the request based on the last + // reported resource usage for the queryGroup. We also need to increment the rejection count here for the + // query group + if (reject) { + throw new OpenSearchRejectedExecutionException("QueryGroup " + queryGroupId + " is already contended." + reason.toString()); + } + } +} diff --git a/server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java b/server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java index 848df8712549a..d382b4c729a38 100644 --- a/server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java +++ b/server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java @@ -20,9 +20,11 @@ */ public class WorkloadManagementTransportInterceptor implements TransportInterceptor { private final ThreadPool threadPool; + private final QueryGroupService queryGroupService; - public WorkloadManagementTransportInterceptor(ThreadPool threadPool) { + public WorkloadManagementTransportInterceptor(final ThreadPool threadPool, final QueryGroupService queryGroupService) { this.threadPool = threadPool; + this.queryGroupService = queryGroupService; } @Override @@ -32,7 +34,7 @@ public <T extends TransportRequest> TransportRequestHandler<T> interceptHandler( boolean forceExecution, TransportRequestHandler<T> actualHandler ) { - return new RequestHandler<T>(threadPool, actualHandler); + return new RequestHandler<T>(threadPool, actualHandler, queryGroupService); } /** @@ -43,16 +45,20 @@ public static class RequestHandler<T extends TransportRequest> implements Transp private final ThreadPool threadPool; TransportRequestHandler<T> actualHandler; + private final QueryGroupService queryGroupService; - public RequestHandler(ThreadPool threadPool, TransportRequestHandler<T> actualHandler) { + public RequestHandler(ThreadPool threadPool, TransportRequestHandler<T> actualHandler, QueryGroupService queryGroupService) { this.threadPool = threadPool; this.actualHandler = actualHandler; + this.queryGroupService = queryGroupService; } @Override public void messageReceived(T request, TransportChannel channel, Task task) throws Exception { if (isSearchWorkloadRequest(task)) { ((QueryGroupTask) task).setQueryGroupId(threadPool.getThreadContext()); + final String queryGroupId = ((QueryGroupTask) (task)).getQueryGroupId(); + queryGroupService.rejectIfNeeded(queryGroupId); } actualHandler.messageReceived(request, channel, task); } diff --git a/server/src/main/java/org/opensearch/wlm/listeners/QueryGroupRequestRejectionOperationListener.java
diff --git a/server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java b/server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java
index 848df8712549a..d382b4c729a38 100644
--- a/server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java
+++ b/server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java
@@ -20,9 +20,11 @@
  */
 public class WorkloadManagementTransportInterceptor implements TransportInterceptor {
     private final ThreadPool threadPool;
+    private final QueryGroupService queryGroupService;
 
-    public WorkloadManagementTransportInterceptor(ThreadPool threadPool) {
+    public WorkloadManagementTransportInterceptor(final ThreadPool threadPool, final QueryGroupService queryGroupService) {
         this.threadPool = threadPool;
+        this.queryGroupService = queryGroupService;
     }
 
     @Override
@@ -32,7 +34,7 @@ public <T extends TransportRequest> TransportRequestHandler<T> interceptHandler(
         boolean forceExecution,
         TransportRequestHandler<T> actualHandler
     ) {
-        return new RequestHandler(threadPool, actualHandler);
+        return new RequestHandler(threadPool, actualHandler, queryGroupService);
     }
 
     /**
@@ -43,16 +45,20 @@ public static class RequestHandler<T extends TransportRequest> implements Transp
 
         private final ThreadPool threadPool;
         TransportRequestHandler<T> actualHandler;
+        private final QueryGroupService queryGroupService;
 
-        public RequestHandler(ThreadPool threadPool, TransportRequestHandler<T> actualHandler) {
+        public RequestHandler(ThreadPool threadPool, TransportRequestHandler<T> actualHandler, QueryGroupService queryGroupService) {
             this.threadPool = threadPool;
             this.actualHandler = actualHandler;
+            this.queryGroupService = queryGroupService;
         }
 
         @Override
         public void messageReceived(T request, TransportChannel channel, Task task) throws Exception {
             if (isSearchWorkloadRequest(task)) {
                 ((QueryGroupTask) task).setQueryGroupId(threadPool.getThreadContext());
+                final String queryGroupId = ((QueryGroupTask) (task)).getQueryGroupId();
+                queryGroupService.rejectIfNeeded(queryGroupId);
             }
             actualHandler.messageReceived(request, channel, task);
         }
diff --git a/server/src/main/java/org/opensearch/wlm/listeners/QueryGroupRequestRejectionOperationListener.java b/server/src/main/java/org/opensearch/wlm/listeners/QueryGroupRequestRejectionOperationListener.java
new file mode 100644
index 0000000000000..89f6fe709667f
--- /dev/null
+++ b/server/src/main/java/org/opensearch/wlm/listeners/QueryGroupRequestRejectionOperationListener.java
@@ -0,0 +1,39 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.wlm.listeners;
+
+import org.opensearch.action.search.SearchRequestContext;
+import org.opensearch.action.search.SearchRequestOperationsListener;
+import org.opensearch.threadpool.ThreadPool;
+import org.opensearch.wlm.QueryGroupService;
+import org.opensearch.wlm.QueryGroupTask;
+
+/**
+ * This listener is used to perform rejections for incoming requests targeting a queryGroup
+ */
+public class QueryGroupRequestRejectionOperationListener extends SearchRequestOperationsListener {
+
+    private final QueryGroupService queryGroupService;
+    private final ThreadPool threadPool;
+
+    public QueryGroupRequestRejectionOperationListener(QueryGroupService queryGroupService, ThreadPool threadPool) {
+        this.queryGroupService = queryGroupService;
+        this.threadPool = threadPool;
+    }
+
+    /**
+     * This method assumes that the queryGroupId is already populated in the thread context
+     * @param searchRequestContext SearchRequestContext instance
+     */
+    @Override
+    protected void onRequestStart(SearchRequestContext searchRequestContext) {
+        final String queryGroupId = threadPool.getThreadContext().getHeader(QueryGroupTask.QUERY_GROUP_ID_HEADER);
+        queryGroupService.rejectIfNeeded(queryGroupId);
+    }
+}
diff --git a/server/src/main/java/org/opensearch/wlm/listeners/package-info.java b/server/src/main/java/org/opensearch/wlm/listeners/package-info.java
new file mode 100644
index 0000000000000..e900acf657085
--- /dev/null
+++ b/server/src/main/java/org/opensearch/wlm/listeners/package-info.java
@@ -0,0 +1,12 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/**
+ * WLM-related listener constructs
+ */
+package org.opensearch.wlm.listeners;
diff --git a/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportInterceptorTests.java b/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportInterceptorTests.java
index db4e5e45d49ed..4668b845150a9 100644
--- a/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportInterceptorTests.java
+++ b/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportInterceptorTests.java
@@ -25,7 +25,7 @@ public class WorkloadManagementTransportInterceptorTests extends OpenSearchTestC
     public void setUp() throws Exception {
         super.setUp();
         threadPool = new TestThreadPool(getTestName());
-        sut = new WorkloadManagementTransportInterceptor(threadPool);
+        sut = new WorkloadManagementTransportInterceptor(threadPool, new QueryGroupService());
     }
 
     public void tearDown() throws Exception {
diff --git a/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportRequestHandlerTests.java b/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportRequestHandlerTests.java
index 789c02345e774..59818ad3dbbd2 100644
--- a/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportRequestHandlerTests.java
+++ b/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportRequestHandlerTests.java
@@ -9,6 +9,7 @@
 package org.opensearch.wlm;
 
 import org.opensearch.action.index.IndexRequest;
+import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
 import org.opensearch.search.internal.ShardSearchRequest;
 import org.opensearch.tasks.Task;
 import org.opensearch.test.OpenSearchTestCase;
@@ -20,12 +21,16 @@
 import java.util.Collections;
 
+import static org.mockito.Mockito.anyString;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
 
 public class WorkloadManagementTransportRequestHandlerTests extends OpenSearchTestCase {
     private WorkloadManagementTransportInterceptor.RequestHandler<TransportRequest> sut;
     private ThreadPool threadPool;
+    private QueryGroupService queryGroupService;
 
     private TestTransportRequestHandler<TransportRequest> actualHandler;
 
@@ -33,8 +38,9 @@ public void setUp() throws Exception {
         super.setUp();
         threadPool = new TestThreadPool(getTestName());
         actualHandler = new TestTransportRequestHandler<>();
+        queryGroupService = mock(QueryGroupService.class);
 
-        sut = new WorkloadManagementTransportInterceptor.RequestHandler<>(threadPool, actualHandler);
+        sut = new WorkloadManagementTransportInterceptor.RequestHandler<>(threadPool, actualHandler, queryGroupService);
     }
 
     public void tearDown() throws Exception {
@@ -42,14 +48,22 @@ public void tearDown() throws Exception {
         threadPool.shutdown();
     }
 
-    public void testMessageReceivedForSearchWorkload() throws Exception {
+    public void testMessageReceivedForSearchWorkload_nonRejectionCase() throws Exception {
         ShardSearchRequest request = mock(ShardSearchRequest.class);
         QueryGroupTask spyTask = getSpyTask();
-
+        doNothing().when(queryGroupService).rejectIfNeeded(anyString());
         sut.messageReceived(request, mock(TransportChannel.class), spyTask);
         assertTrue(sut.isSearchWorkloadRequest(spyTask));
     }
 
+    public void testMessageReceivedForSearchWorkload_RejectionCase() throws Exception {
+        ShardSearchRequest request = mock(ShardSearchRequest.class);
+        QueryGroupTask spyTask = getSpyTask();
+        doThrow(OpenSearchRejectedExecutionException.class).when(queryGroupService).rejectIfNeeded(anyString());
+
+        assertThrows(OpenSearchRejectedExecutionException.class, () -> sut.messageReceived(request, mock(TransportChannel.class), spyTask));
+    }
+
     public void testMessageReceivedForNonSearchWorkload() throws Exception {
         IndexRequest indexRequest = mock(IndexRequest.class);
         Task task = mock(Task.class);
diff --git a/server/src/test/java/org/opensearch/wlm/listeners/QueryGroupRequestRejectionOperationListenerTests.java b/server/src/test/java/org/opensearch/wlm/listeners/QueryGroupRequestRejectionOperationListenerTests.java
new file mode 100644
index 0000000000000..19e82aca26153
--- /dev/null
+++ b/server/src/test/java/org/opensearch/wlm/listeners/QueryGroupRequestRejectionOperationListenerTests.java
@@ -0,0 +1,53 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.wlm.listeners;
+
+import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
+import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.threadpool.TestThreadPool;
+import org.opensearch.threadpool.ThreadPool;
+import org.opensearch.wlm.QueryGroupService;
+import org.opensearch.wlm.QueryGroupTask;
+
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+
+public class QueryGroupRequestRejectionOperationListenerTests extends OpenSearchTestCase {
+    ThreadPool testThreadPool;
+    QueryGroupService queryGroupService;
+    QueryGroupRequestRejectionOperationListener sut;
+
+    public void setUp() throws Exception {
+        super.setUp();
+        testThreadPool = new TestThreadPool("RejectionTestThreadPool");
+        queryGroupService = mock(QueryGroupService.class);
+        sut = new QueryGroupRequestRejectionOperationListener(queryGroupService, testThreadPool);
+    }
+
+    public void tearDown() throws Exception {
+        super.tearDown();
+        testThreadPool.shutdown();
+    }
+
+    public void testRejectionCase() {
+        final String testQueryGroupId = "asdgasgkajgkw3141_3rt4t";
+        testThreadPool.getThreadContext().putHeader(QueryGroupTask.QUERY_GROUP_ID_HEADER, testQueryGroupId);
+        doThrow(OpenSearchRejectedExecutionException.class).when(queryGroupService).rejectIfNeeded(testQueryGroupId);
+        assertThrows(OpenSearchRejectedExecutionException.class, () -> sut.onRequestStart(null));
+    }
+
+    public void testNonRejectionCase() {
+        final String testQueryGroupId = "asdgasgkajgkw3141_3rt4t";
+        testThreadPool.getThreadContext().putHeader(QueryGroupTask.QUERY_GROUP_ID_HEADER, testQueryGroupId);
+        doNothing().when(queryGroupService).rejectIfNeeded(testQueryGroupId);
+
+        sut.onRequestStart(null);
+    }
+}
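The coordinator-side listener and the shard-side RequestHandler meet through QueryGroupTask: the coordinator resolves a queryGroupId into the thread context under QueryGroupTask.QUERY_GROUP_ID_HEADER, and the transport layer copies it onto the task before running the rejection check. QueryGroupTask itself is not part of this patch; the sketch below only illustrates that assumed hand-off (the header key and field handling are guesses), not the real class:

package org.opensearch.wlm;

import org.opensearch.common.util.concurrent.ThreadContext;

// Hypothetical sketch of the contract RequestHandler.messageReceived relies on.
public class QueryGroupTaskContractSketch {
    // Assumed header key; the real constant is defined on QueryGroupTask.
    public static final String QUERY_GROUP_ID_HEADER = "queryGroupId";

    private volatile String queryGroupId;

    // Invoked by RequestHandler before queryGroupService.rejectIfNeeded(...).
    public void setQueryGroupId(ThreadContext threadContext) {
        this.queryGroupId = threadContext.getHeader(QUERY_GROUP_ID_HEADER);
    }

    public String getQueryGroupId() {
        return queryGroupId;
    }
}

Presumably this shared hand-off is also why Node.java carries two new QueryGroupService() placeholders with matching TODOs: once a real, stateful service lands, both enforcement points will need to share a single instance so that usage data and rejection counts stay consistent.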