Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Request labeling service in opensearch #14282

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add remote routing table for remote state publication with experimental feature flag ([#13304](https://github.com/opensearch-project/OpenSearch/pull/13304))
- [Remote Store] Add support to disable flush based on translog reader count ([#14027](https://github.com/opensearch-project/OpenSearch/pull/14027))
- [Query Insights] Add exporter support for top n queries ([#12982](https://github.com/opensearch-project/OpenSearch/pull/12982))
- Support rule-based labeling for search queries ([#13374](https://github.com/opensearch-project/OpenSearch/pull/13374))

### Dependencies
- Bump `com.github.spullara.mustache.java:compiler` from 0.9.10 to 0.9.13 ([#13329](https://github.com/opensearch-project/OpenSearch/pull/13329), [#13559](https://github.com/opensearch-project/OpenSearch/pull/13559))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public Collection<Object> createComponents(
) {
// create top n queries service
final QueryInsightsService queryInsightsService = new QueryInsightsService(clusterService.getClusterSettings(), threadPool, client);
return List.of(queryInsightsService, new QueryInsightsListener(clusterService, queryInsightsService));
return List.of(queryInsightsService, new QueryInsightsListener(threadPool, clusterService, queryInsightsService));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import org.opensearch.plugin.insights.rules.model.Attribute;
import org.opensearch.plugin.insights.rules.model.MetricType;
import org.opensearch.plugin.insights.rules.model.SearchQueryRecord;
import org.opensearch.search.labels.RequestLabelingService;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;

import java.util.Collections;
import java.util.HashMap;
Expand All @@ -45,15 +48,21 @@ public final class QueryInsightsListener extends SearchRequestOperationsListener
private static final Logger log = LogManager.getLogger(QueryInsightsListener.class);

private final QueryInsightsService queryInsightsService;
private final ThreadPool threadPool;

/**
* Constructor for QueryInsightsListener
*
* @param threadPool the OpenSearch internal threadPool
* @param clusterService The Node's cluster service.
* @param queryInsightsService The topQueriesByLatencyService associated with this listener
*/
@Inject
public QueryInsightsListener(final ClusterService clusterService, final QueryInsightsService queryInsightsService) {
public QueryInsightsListener(
final ThreadPool threadPool,
final ClusterService clusterService,
final QueryInsightsService queryInsightsService
) {
this.queryInsightsService = queryInsightsService;
clusterService.getClusterSettings()
.addSettingsUpdateConsumer(TOP_N_LATENCY_QUERIES_ENABLED, v -> this.setEnableTopQueries(MetricType.LATENCY, v));
Expand All @@ -74,6 +83,7 @@ public QueryInsightsListener(final ClusterService clusterService, final QueryIns
.setTopNSize(clusterService.getClusterSettings().get(TOP_N_LATENCY_QUERIES_SIZE));
this.queryInsightsService.getTopQueriesService(MetricType.LATENCY)
.setWindowSize(clusterService.getClusterSettings().get(TOP_N_LATENCY_QUERIES_WINDOW_SIZE));
this.threadPool = threadPool;
}

/**
Expand Down Expand Up @@ -138,6 +148,21 @@ public void onRequestEnd(final SearchPhaseContext context, final SearchRequestCo
attributes.put(Attribute.TOTAL_SHARDS, context.getNumShards());
attributes.put(Attribute.INDICES, request.indices());
attributes.put(Attribute.PHASE_LATENCY_MAP, searchRequestContext.phaseTookMap());

// Get internal computed and user provided labels
Map<String, Object> labels = new HashMap<>();
// Retrieve user provided label if exists
String userProvidedLabel = RequestLabelingService.getUserProvidedTag(threadPool);
if (userProvidedLabel != null) {
labels.put(Task.X_OPAQUE_ID, userProvidedLabel);
}
// Retrieve computed labels if exists
Map<String, Object> computedLabels = RequestLabelingService.getRuleBasedLabels(threadPool);
if (computedLabels != null) {
labels.putAll(computedLabels);
}
attributes.put(Attribute.LABELS, labels);
// construct SearchQueryRecord from attributes and measurements
SearchQueryRecord record = new SearchQueryRecord(request.getOrCreateAbsoluteStartMillis(), measurements, attributes);
queryInsightsService.addRecord(record);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ public enum Attribute {
/**
* The node id for this request
*/
NODE_ID;
NODE_ID,
/**
* Custom search request labels
*/
LABELS;

/**
* Read an Attribute from a StreamInput
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,37 @@
import org.opensearch.action.search.SearchRequestContext;
import org.opensearch.action.search.SearchType;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.plugin.insights.core.service.QueryInsightsService;
import org.opensearch.plugin.insights.core.service.TopQueriesService;
import org.opensearch.plugin.insights.rules.model.Attribute;
import org.opensearch.plugin.insights.rules.model.MetricType;
import org.opensearch.plugin.insights.rules.model.SearchQueryRecord;
import org.opensearch.plugin.insights.settings.QueryInsightsSettings;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.opensearch.search.aggregations.support.ValueType;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.labels.RequestLabelingService;
import org.opensearch.tasks.Task;
import org.opensearch.test.ClusterServiceUtils;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.junit.Before;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Phaser;

import org.mockito.ArgumentCaptor;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
Expand All @@ -48,6 +59,7 @@ public class QueryInsightsListenerTests extends OpenSearchTestCase {
private final SearchRequest searchRequest = mock(SearchRequest.class);
private final QueryInsightsService queryInsightsService = mock(QueryInsightsService.class);
private final TopQueriesService topQueriesService = mock(TopQueriesService.class);
private final ThreadPool threadPool = mock(ThreadPool.class);
private ClusterService clusterService;

@Before
Expand All @@ -61,8 +73,14 @@ public void setup() {
clusterService = ClusterServiceUtils.createClusterService(settings, clusterSettings, null);
when(queryInsightsService.isCollectionEnabled(MetricType.LATENCY)).thenReturn(true);
when(queryInsightsService.getTopQueriesService(MetricType.LATENCY)).thenReturn(topQueriesService);

ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
threadContext.setHeaders(new Tuple<>(Collections.singletonMap(Task.X_OPAQUE_ID, "userLabel"), new HashMap<>()));
threadContext.putTransient(RequestLabelingService.RULE_BASED_LABELS, Map.of("labelKey", "labelValue"));
when(threadPool.getThreadContext()).thenReturn(threadContext);
}

@SuppressWarnings("unchecked")
public void testOnRequestEnd() throws InterruptedException {
Long timestamp = System.currentTimeMillis() - 100L;
SearchType searchType = SearchType.QUERY_THEN_FETCH;
Expand All @@ -80,7 +98,7 @@ public void testOnRequestEnd() throws InterruptedException {

int numberOfShards = 10;

QueryInsightsListener queryInsightsListener = new QueryInsightsListener(clusterService, queryInsightsService);
QueryInsightsListener queryInsightsListener = new QueryInsightsListener(threadPool, clusterService, queryInsightsService);

when(searchRequest.getOrCreateAbsoluteStartMillis()).thenReturn(timestamp);
when(searchRequest.searchType()).thenReturn(searchType);
Expand All @@ -89,10 +107,19 @@ public void testOnRequestEnd() throws InterruptedException {
when(searchRequestContext.phaseTookMap()).thenReturn(phaseLatencyMap);
when(searchPhaseContext.getRequest()).thenReturn(searchRequest);
when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards);
ArgumentCaptor<SearchQueryRecord> captor = ArgumentCaptor.forClass(SearchQueryRecord.class);

queryInsightsListener.onRequestEnd(searchPhaseContext, searchRequestContext);

verify(queryInsightsService, times(1)).addRecord(any());
verify(queryInsightsService, times(1)).addRecord(captor.capture());
SearchQueryRecord generatedRecord = captor.getValue();
assertEquals(timestamp.longValue(), generatedRecord.getTimestamp());
assertEquals(numberOfShards, generatedRecord.getAttributes().get(Attribute.TOTAL_SHARDS));
assertEquals(searchType.toString().toLowerCase(Locale.ROOT), generatedRecord.getAttributes().get(Attribute.SEARCH_TYPE));
assertEquals(searchSourceBuilder.toString(), generatedRecord.getAttributes().get(Attribute.SOURCE));
Map<String, String> labels = (Map<String, String>) generatedRecord.getAttributes().get(Attribute.LABELS);
assertEquals("labelValue", labels.get("labelKey"));
assertEquals("userLabel", labels.get(Task.X_OPAQUE_ID));
}

public void testConcurrentOnRequestEnd() throws InterruptedException {
Expand Down Expand Up @@ -128,7 +155,7 @@ public void testConcurrentOnRequestEnd() throws InterruptedException {
CountDownLatch countDownLatch = new CountDownLatch(numRequests);

for (int i = 0; i < numRequests; i++) {
searchListenersList.add(new QueryInsightsListener(clusterService, queryInsightsService));
searchListenersList.add(new QueryInsightsListener(threadPool, clusterService, queryInsightsService));
}

for (int i = 0; i < numRequests; i++) {
Expand All @@ -149,7 +176,7 @@ public void testConcurrentOnRequestEnd() throws InterruptedException {

public void testSetEnabled() {
when(queryInsightsService.isCollectionEnabled(MetricType.LATENCY)).thenReturn(true);
QueryInsightsListener queryInsightsListener = new QueryInsightsListener(clusterService, queryInsightsService);
QueryInsightsListener queryInsightsListener = new QueryInsightsListener(threadPool, clusterService, queryInsightsService);
queryInsightsListener.setEnableTopQueries(MetricType.LATENCY, true);
assertTrue(queryInsightsListener.isEnabled());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ String formattedShardStats() {
);
}
}

/**
 * Exposes the search request held by this context.
 *
 * @return the value of the {@code searchRequest} field
 */
public SearchRequest getRequest() {
    return this.searchRequest;
}
}

enum ShardStatsFieldNames {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ protected SearchRequestOperationsListener(final boolean enabled) {
this.enabled = enabled;
}

protected abstract void onPhaseStart(SearchPhaseContext context);
/**
 * Hook with an empty default implementation; subclasses override it when they need it.
 *
 * @param context the search phase context passed to the listener
 */
protected void onPhaseStart(SearchPhaseContext context) {}

protected abstract void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext);
/**
 * Hook with an empty default implementation; subclasses override it when they need it.
 *
 * @param context the search phase context passed to the listener
 * @param searchRequestContext the request-level context passed to the listener
 */
protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {}

protected abstract void onPhaseFailure(SearchPhaseContext context, Throwable cause);
/**
 * Hook with an empty default implementation; subclasses override it when they need it.
 *
 * @param context the search phase context passed to the listener
 * @param cause the throwable passed to the listener
 */
protected void onPhaseFailure(SearchPhaseContext context, Throwable cause) {}

protected void onRequestStart(SearchRequestContext searchRequestContext) {}

Expand Down
11 changes: 10 additions & 1 deletion server/src/main/java/org/opensearch/node/Node.java
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@
import org.opensearch.search.backpressure.SearchBackpressureService;
import org.opensearch.search.backpressure.settings.SearchBackpressureSettings;
import org.opensearch.search.fetch.FetchPhase;
import org.opensearch.search.labels.RequestLabelingService;
import org.opensearch.search.labels.SearchRequestLabelingListener;
import org.opensearch.search.labels.rules.Rule;
import org.opensearch.search.pipeline.SearchPipelineService;
import org.opensearch.search.query.QueryPhase;
import org.opensearch.snapshots.InternalSnapshotsInfoService;
Expand Down Expand Up @@ -962,11 +965,17 @@ protected Node(
// Add the telemetryAwarePlugin components to the existing pluginComponents collection.
pluginComponents.addAll(telemetryAwarePluginComponents);

final SearchRequestLabelingListener searchRequestLabelingListener = new SearchRequestLabelingListener(
new RequestLabelingService(
threadPool,
pluginComponents.stream().filter(p -> p instanceof Rule).map(p -> (Rule) p).collect(toList())
)
);
// register all standard SearchRequestOperationsCompositeListenerFactory to the SearchRequestOperationsCompositeListenerFactory
final SearchRequestOperationsCompositeListenerFactory searchRequestOperationsCompositeListenerFactory =
new SearchRequestOperationsCompositeListenerFactory(
Stream.concat(
Stream.of(searchRequestStats, searchRequestSlowLog),
Stream.of(searchRequestStats, searchRequestSlowLog, searchRequestLabelingListener),
pluginComponents.stream()
.filter(p -> p instanceof SearchRequestOperationsListener)
.map(p -> (SearchRequestOperationsListener) p)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.labels;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.search.labels.rules.Rule;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;

/**
* Service to attach labels to a search request based on pre-defined rules.
* It evaluates all available rules and stores the generated labels in the thread context.
*/

// Rule: allowed user provided labels, internally supported labels. -> do we allow conflicts?
// take json string (or one key-value pair for each label?), parse and check if user-provided labels are valid based on allowed labels from different rules, if valid, put into all_labels.
// evaluate all rules, put into all_labels thread context
// for user defined labels, do we want to keep in another thread context?

// or add a label for each separate key?

// Each rule defines its own "allowed labels" and should only use the labels it defines.
// This service should be responsible for checking conflicts.
public class RequestLabelingService {
/**
* Field name for computed labels
*/
public static final String RULE_BASED_LABELS = "rule_based_labels";
private final ThreadPool threadPool;
private final List<Rule> rules;

public RequestLabelingService(final ThreadPool threadPool, final List<Rule> rules) {
this.threadPool = threadPool;
this.rules = rules;
}

public void parseUserLabels() {
// parse user provided labels into a map, also validate them based on allowed_user_labels from each rule

threadPool.getThreadContext().putTransient(USER_PROVIDED_LABELS, userLabels);
}

/**
* Evaluate all labeling rules and store the computed rules into thread context
*
* @param searchRequest {@link SearchRequest}
*/
public void applyAllRules(final SearchRequest searchRequest) {
Map<String, Object> computedLabels = rules.stream()
.map(rule -> rule.evaluate(threadPool.getThreadContext(), searchRequest))
.flatMap(m -> m.entrySet().stream())
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (existing, replacement) -> replacement));
// String userProvidedTag = getUserProvidedTag(threadPool);

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check if labels are valid here

threadPool.getThreadContext().putTransient(RULE_BASED_LABELS, computedLabels);
}

/**
* Get the user provided tag from the X-Opaque-Id header
*
* @return user provided tag
*/
public static String getUserProvidedTag(ThreadPool threadPool) {
return threadPool.getThreadContext().getTransient(USER_PROVIDED_LABELS);
}

public static Map<String, Object> getRuleBasedLabels(ThreadPool threadPool) {
return threadPool.getThreadContext().getTransient(RULE_BASED_LABELS);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.labels;

import org.opensearch.action.search.SearchRequestContext;
import org.opensearch.action.search.SearchRequestOperationsListener;

/**
* SearchRequestOperationsListener subscriber for labeling search requests
*
* @opensearch.internal
*/
public final class SearchRequestLabelingListener extends SearchRequestOperationsListener {
    /** Service this listener delegates all labeling work to. */
    private final RequestLabelingService requestLabelingService;

    /**
     * Creates a listener backed by the given labeling service.
     *
     * @param requestLabelingService service called when a search request starts
     */
    public SearchRequestLabelingListener(final RequestLabelingService requestLabelingService) {
        this.requestLabelingService = requestLabelingService;
    }

    /**
     * Labels the request as it starts: first applies every rule to the request,
     * then parses the user-provided labels.
     *
     * @param searchRequestContext context from which the search request is obtained
     */
    @Override
    public void onRequestStart(SearchRequestContext searchRequestContext) {
        // rule-based labels first, user-provided labels second
        requestLabelingService.applyAllRules(searchRequestContext.getRequest());
        requestLabelingService.parseUserLabels();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

/** Search labeling service. */
package org.opensearch.search.labels;
Loading
Loading