Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
Signed-off-by: Ruirui Zhang <mariazrr@amazon.com>
  • Loading branch information
ruai0511 committed Aug 15, 2024
1 parent afb596e commit c0b3e82
Show file tree
Hide file tree
Showing 15 changed files with 129 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,23 @@
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.List;
import java.util.Collection;

/**
* Response for the get API for QueryGroup
*
* @opensearch.experimental
*/
public class GetQueryGroupResponse extends ActionResponse implements ToXContent, ToXContentObject {
private final List<QueryGroup> queryGroups;
private final Collection<QueryGroup> queryGroups;
private final RestStatus restStatus;

/**
* Constructor for GetQueryGroupResponse
* @param queryGroups - The QueryGroup list to be fetched
* @param restStatus - The rest status of the request
*/
public GetQueryGroupResponse(final List<QueryGroup> queryGroups, RestStatus restStatus) {
public GetQueryGroupResponse(final Collection<QueryGroup> queryGroups, RestStatus restStatus) {
this.queryGroups = queryGroups;
this.restStatus = restStatus;
}
Expand All @@ -50,7 +50,7 @@ public GetQueryGroupResponse(StreamInput in) throws IOException {

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeList(queryGroups);
out.writeCollection(queryGroups);
RestStatus.writeTo(out, restStatus);
}

Expand All @@ -69,7 +69,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
/**
* queryGroups getter
*/
public List<QueryGroup> getQueryGroups() {
public Collection<QueryGroup> getQueryGroups() {
return queryGroups;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.clustermanager.TransportClusterManagerNodeReadAction;
import org.opensearch.cluster.ClusterState;
Expand All @@ -28,7 +29,7 @@
import org.opensearch.transport.TransportService;

import java.io.IOException;
import java.util.List;
import java.util.Collection;

/**
* Transport action to get QueryGroup
Expand Down Expand Up @@ -85,12 +86,12 @@ protected ClusterBlockException checkBlock(GetQueryGroupRequest request, Cluster
@Override
protected void clusterManagerOperation(GetQueryGroupRequest request, ClusterState state, ActionListener<GetQueryGroupResponse> listener)
throws Exception {
String name = request.getName();
List<QueryGroup> resultGroups = QueryGroupPersistenceService.getFromClusterStateMetadata(name, state);
final String name = request.getName();
final Collection<QueryGroup> resultGroups = QueryGroupPersistenceService.getFromClusterStateMetadata(name, state);

if (resultGroups.isEmpty() && name != null && !name.isEmpty()) {
logger.warn("No QueryGroup exists with the provided name: {}", name);
throw new IllegalArgumentException("No QueryGroup exists with the provided name: " + name);
throw new ResourceNotFoundException("No QueryGroup exists with the provided name: " + name);
}
listener.onResponse(new GetQueryGroupResponse(resultGroups, RestStatus.OK));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public List<Route> routes() {

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
GetQueryGroupRequest getQueryGroupRequest = new GetQueryGroupRequest(request.param("name"));
final GetQueryGroupRequest getQueryGroupRequest = new GetQueryGroupRequest(request.param("name"));
return channel -> client.execute(GetQueryGroupAction.INSTANCE, getQueryGroupRequest, getQueryGroupResponse(channel));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
import org.opensearch.plugin.wlm.action.CreateQueryGroupResponse;
import org.opensearch.search.ResourceType;

import java.util.ArrayList;
import java.util.Collection;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

/**
* This class defines the functions for QueryGroup persistence
Expand Down Expand Up @@ -199,14 +199,17 @@ private Map<ResourceType, Double> calculateTotalUsage(Map<String, QueryGroup> ex
* @param name - the QueryGroup name we are getting
* @param currentState - current cluster state
*/
public static List<QueryGroup> getFromClusterStateMetadata(String name, ClusterState currentState) {
Map<String, QueryGroup> currentGroups = currentState.getMetadata().queryGroups();
public static Collection<QueryGroup> getFromClusterStateMetadata(String name, ClusterState currentState) {
final Map<String, QueryGroup> currentGroups = currentState.getMetadata().queryGroups();
if (name == null || name.isEmpty()) {
return new ArrayList<>(currentGroups.values());
return currentGroups.values();
}
List<QueryGroup> resultGroups = new ArrayList<>();
currentGroups.values().stream().filter(group -> group.getName().equals(name)).findFirst().ifPresent(resultGroups::add);
return resultGroups;
return currentGroups.values()
.stream()
.filter(group -> group.getName().equals(name))
.findAny()
.stream()
.collect(Collectors.toList());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.threadpool.ThreadPool;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -130,8 +131,10 @@ public static Tuple<QueryGroupPersistenceService, ClusterState> preparePersisten
return new Tuple<QueryGroupPersistenceService, ClusterState>(queryGroupPersistenceService, clusterState);
}

public static void assertEqualQueryGroups(List<QueryGroup> listOne, List<QueryGroup> listTwo) {
assertEquals(listOne.size(), listTwo.size());
public static void assertEqualQueryGroups(Collection<QueryGroup> collectionOne, Collection<QueryGroup> collectionTwo) {
assertEquals(collectionOne.size(), collectionTwo.size());
List<QueryGroup> listOne = new ArrayList<>(collectionOne);
List<QueryGroup> listTwo = new ArrayList<>(collectionTwo);
listOne.sort(Comparator.comparing(QueryGroup::getName));
listTwo.sort(Comparator.comparing(QueryGroup::getName));
for (int i = 0; i < listOne.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

public class GetQueryGroupRequestTests extends OpenSearchTestCase {

/**
* Test case to verify the serialization and deserialization of GetQueryGroupRequest.
*/
public void testSerialization() throws IOException {
GetQueryGroupRequest request = new GetQueryGroupRequest(QueryGroupTestUtils.NAME_ONE);
assertEquals(QueryGroupTestUtils.NAME_ONE, request.getName());
Expand All @@ -27,6 +30,9 @@ public void testSerialization() throws IOException {
assertEquals(request.getName(), otherRequest.getName());
}

/**
* Test case to verify the serialization and deserialization of GetQueryGroupRequest when name is null.
*/
public void testSerializationWithNull() throws IOException {
GetQueryGroupRequest request = new GetQueryGroupRequest((String) null);
assertNull(request.getName());
Expand All @@ -37,6 +43,9 @@ public void testSerializationWithNull() throws IOException {
assertEquals(request.getName(), otherRequest.getName());
}

/**
* Test case the validation function of GetQueryGroupRequest
*/
public void testValidation() {
GetQueryGroupRequest request = new GetQueryGroupRequest("a".repeat(51));
assertThrows(IllegalArgumentException.class, request::validate);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@

public class GetQueryGroupResponseTests extends OpenSearchTestCase {

/**
* Test case to verify the serialization and deserialization of GetQueryGroupResponse.
*/
public void testSerializationSingleQueryGroup() throws IOException {
List<QueryGroup> list = new ArrayList<>();
list.add(QueryGroupTestUtils.queryGroupOne);
Expand All @@ -41,6 +44,9 @@ public void testSerializationSingleQueryGroup() throws IOException {
QueryGroupTestUtils.assertEqualQueryGroups(response.getQueryGroups(), otherResponse.getQueryGroups());
}

/**
* Test case to verify the serialization and deserialization of GetQueryGroupResponse when the result contains multiple QueryGroups.
*/
public void testSerializationMultipleQueryGroup() throws IOException {
GetQueryGroupResponse response = new GetQueryGroupResponse(QueryGroupTestUtils.queryGroupList(), RestStatus.OK);
assertEquals(response.getQueryGroups(), QueryGroupTestUtils.queryGroupList());
Expand All @@ -55,6 +61,9 @@ public void testSerializationMultipleQueryGroup() throws IOException {
QueryGroupTestUtils.assertEqualQueryGroups(response.getQueryGroups(), otherResponse.getQueryGroups());
}

/**
* Test case to verify the serialization and deserialization of GetQueryGroupResponse when the result is empty.
*/
public void testSerializationNull() throws IOException {
List<QueryGroup> list = new ArrayList<>();
GetQueryGroupResponse response = new GetQueryGroupResponse(list, RestStatus.OK);
Expand All @@ -69,6 +78,9 @@ public void testSerializationNull() throws IOException {
assertEquals(0, otherResponse.getQueryGroups().size());
}

/**
* Test case to verify the toXContent of GetQueryGroupResponse.
*/
public void testToXContentGetSingleQueryGroup() throws IOException {
List<QueryGroup> queryGroupList = new ArrayList<>();
queryGroupList.add(QueryGroupTestUtils.queryGroupOne);
Expand All @@ -91,6 +103,9 @@ public void testToXContentGetSingleQueryGroup() throws IOException {
assertEquals(expected, actual);
}

/**
* Test case to verify the toXContent of GetQueryGroupResponse when the result contains multiple QueryGroups.
*/
public void testToXContentGetMultipleQueryGroup() throws IOException {
List<QueryGroup> queryGroupList = new ArrayList<>();
queryGroupList.add(QueryGroupTestUtils.queryGroupOne);
Expand Down Expand Up @@ -123,6 +138,9 @@ public void testToXContentGetMultipleQueryGroup() throws IOException {
assertEquals(expected, actual);
}

/**
* Test case to verify toXContent of GetQueryGroupResponse when the result contains zero QueryGroup.
*/
public void testToXContentGetZeroQueryGroup() throws IOException {
XContentBuilder builder = JsonXContent.contentBuilder().prettyPrint();
GetQueryGroupResponse otherResponse = new GetQueryGroupResponse(new ArrayList<>(), RestStatus.OK);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package org.opensearch.plugin.wlm.action;

import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
Expand All @@ -23,6 +24,9 @@

public class TransportGetQueryGroupActionTests extends OpenSearchTestCase {

/**
* Test case for ClusterManagerOperation function
*/
@SuppressWarnings("unchecked")
public void testClusterManagerOperation() throws Exception {
GetQueryGroupRequest getQueryGroupRequest1 = new GetQueryGroupRequest(NAME_NONE_EXISTED);
Expand All @@ -35,7 +39,7 @@ public void testClusterManagerOperation() throws Exception {
mock(IndexNameExpressionResolver.class)
);
assertThrows(
IllegalArgumentException.class,
ResourceNotFoundException.class,
() -> transportGetQueryGroupAction.clusterManagerOperation(getQueryGroupRequest1, clusterState(), mock(ActionListener.class))
);
transportGetQueryGroupAction.clusterManagerOperation(getQueryGroupRequest2, clusterState(), mock(ActionListener.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import static org.opensearch.plugin.wlm.QueryGroupTestUtils.assertEqualQueryGroups;
import static org.opensearch.plugin.wlm.QueryGroupTestUtils.clusterSettings;
import static org.opensearch.plugin.wlm.QueryGroupTestUtils.clusterSettingsSet;
import static org.opensearch.plugin.wlm.QueryGroupTestUtils.clusterState;
import static org.opensearch.plugin.wlm.QueryGroupTestUtils.preparePersistenceServiceSetup;
import static org.opensearch.plugin.wlm.QueryGroupTestUtils.queryGroupList;
import static org.opensearch.plugin.wlm.QueryGroupTestUtils.queryGroupOne;
Expand Down Expand Up @@ -255,11 +256,12 @@ public void testPersistInClusterStateMetadataFailure() {
verify(listener).onFailure(any(RuntimeException.class));
}

/**
* Tests getting a single QueryGroup
*/
public void testGetSingleQueryGroup() {
List<QueryGroup> groups = QueryGroupPersistenceService.getFromClusterStateMetadata(
QueryGroupTestUtils.NAME_ONE,
QueryGroupTestUtils.clusterState()
);
Collection<QueryGroup> groupsCollections = QueryGroupPersistenceService.getFromClusterStateMetadata(NAME_ONE, clusterState());
List<QueryGroup> groups = new ArrayList<>(groupsCollections);
assertEquals(1, groups.size());
QueryGroup queryGroup = groups.get(0);
List<QueryGroup> listOne = new ArrayList<>();
Expand All @@ -269,32 +271,35 @@ public void testGetSingleQueryGroup() {
QueryGroupTestUtils.assertEqualQueryGroups(listOne, listTwo);
}

/**
* Tests getting all QueryGroups
*/
public void testGetAllQueryGroups() {
assertEquals(2, QueryGroupTestUtils.clusterState().metadata().queryGroups().size());
List<QueryGroup> res = QueryGroupPersistenceService.getFromClusterStateMetadata(null, QueryGroupTestUtils.clusterState());
Collection<QueryGroup> groupsCollections = QueryGroupPersistenceService.getFromClusterStateMetadata(null, clusterState());
List<QueryGroup> res = new ArrayList<>(groupsCollections);
assertEquals(2, res.size());
Set<String> currentNAME = res.stream().map(QueryGroup::getName).collect(Collectors.toSet());
assertTrue(currentNAME.contains(QueryGroupTestUtils.NAME_ONE));
assertTrue(currentNAME.contains(QueryGroupTestUtils.NAME_TWO));
QueryGroupTestUtils.assertEqualQueryGroups(QueryGroupTestUtils.queryGroupList(), res);
}

public void testGetZeroQueryGroups() {
List<QueryGroup> res = QueryGroupPersistenceService.getFromClusterStateMetadata(
QueryGroupTestUtils.NAME_NONE_EXISTED,
QueryGroupTestUtils.clusterState()
);
assertEquals(0, res.size());
}

/**
* Tests getting a QueryGroup with invalid name
*/
public void testGetNonExistedQueryGroups() {
List<QueryGroup> groups = QueryGroupPersistenceService.getFromClusterStateMetadata(
QueryGroupTestUtils.NAME_NONE_EXISTED,
QueryGroupTestUtils.clusterState()
Collection<QueryGroup> groupsCollections = QueryGroupPersistenceService.getFromClusterStateMetadata(
NAME_NONE_EXISTED,
clusterState()
);
List<QueryGroup> groups = new ArrayList<>(groupsCollections);
assertEquals(0, groups.size());
}

/**
* Tests setting maxQueryGroupCount
*/
public void testMaxQueryGroupCount() {
assertThrows(IllegalArgumentException.class, () -> QueryGroupTestUtils.queryGroupPersistenceService().setMaxQueryGroupCount(-1));
QueryGroupPersistenceService queryGroupPersistenceService = QueryGroupTestUtils.queryGroupPersistenceService();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"create_query_group_context": {
"stability": "experimental",
"url": {
"paths": [
{
"path": "/_wlm/query_group",
"methods": ["PUT", "POST"],
"parts": {}
}
]
},
"params":{},
"body":{
"description":"The QueryGroup schema"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@
"paths": [
{
"path": "/_wlm/query_group",
"methods": ["PUT", "POST", "GET"],
"methods": ["GET"],
"parts": {}
},
{
"path": "/_wlm/query_group/{name}",
"methods": [
"GET"
],
"methods": ["GET"],
"parts": {
"name": {
"type": "string",
Expand Down
Loading

0 comments on commit c0b3e82

Please sign in to comment.