Skip to content

Commit

Permalink
RestController should not consume request content (#66043)
Browse files Browse the repository at this point in the history
The change #37504 modifies the BaseRestHandler to make it reject all requests
that have an unconsumed body. The notion of consumed or unconsumed body
 is carried by the RestRequest object and its contentConsumed attribute, which
 is set to true when the content() or content(true) methods are used.

In our REST layer, we usually expect the RestHandlers to consume the request
content when needed, but it appears that the RestController always consumes
 the content upfront.

This commit changes the content() method used by the RestController so that it
does not mark the content as consumed.

Backport of #44902
Closes #65242

Co-authored-by: Tanguy Leroux <tlrx.dev@gmail.com>
  • Loading branch information
jaymode and tlrx authored Dec 14, 2020
1 parent 0c9662c commit 7096856
Show file tree
Hide file tree
Showing 12 changed files with 94 additions and 131 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,7 @@ static Request deleteJob(final DeleteRollupJobRequest deleteRollupJobRequest) th
.addPathPartAsIs("_rollup", "job")
.addPathPart(deleteRollupJobRequest.getId())
.build();
Request request = new Request(HttpDelete.METHOD_NAME, endpoint);
request.setEntity(createEntity(deleteRollupJobRequest, REQUEST_BODY_CONTENT_TYPE));
return request;
return new Request(HttpDelete.METHOD_NAME, endpoint);
}

static Request search(final SearchRequest request) throws IOException {
Expand All @@ -114,18 +112,14 @@ static Request getRollupCaps(final GetRollupCapsRequest getRollupCapsRequest) th
.addPathPartAsIs("_rollup", "data")
.addPathPart(getRollupCapsRequest.getIndexPattern())
.build();
Request request = new Request(HttpGet.METHOD_NAME, endpoint);
request.setEntity(createEntity(getRollupCapsRequest, REQUEST_BODY_CONTENT_TYPE));
return request;
return new Request(HttpGet.METHOD_NAME, endpoint);
}

static Request getRollupIndexCaps(final GetRollupIndexCapsRequest getRollupIndexCapsRequest) throws IOException {
String endpoint = new RequestConverters.EndpointBuilder()
.addCommaSeparatedPathParts(getRollupIndexCapsRequest.indices())
.addPathPartAsIs("_rollup", "data")
.build();
Request request = new Request(HttpGet.METHOD_NAME, endpoint);
request.setEntity(createEntity(getRollupIndexCapsRequest, REQUEST_BODY_CONTENT_TYPE));
return request;
return new Request(HttpGet.METHOD_NAME, endpoint);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,14 @@
package org.elasticsearch.client.rollup;

import org.elasticsearch.client.Validatable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Objects;


public class DeleteRollupJobRequest implements Validatable, ToXContentObject {
public class DeleteRollupJobRequest implements Validatable {

private static final ParseField ID_FIELD = new ParseField("id");
private final String id;


public DeleteRollupJobRequest(String id) {
this.id = Objects.requireNonNull(id, "id parameter must not be null");
}
Expand All @@ -43,27 +35,6 @@ public String getId() {
return id;
}

private static final ConstructingObjectParser<DeleteRollupJobRequest, Void> PARSER =
new ConstructingObjectParser<>("request", a -> {
return new DeleteRollupJobRequest((String) a[0]);
});

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), ID_FIELD);
}

public static DeleteRollupJobRequest fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ID_FIELD.getPreferredName(), this.id);
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,11 @@
import org.elasticsearch.client.Validatable;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

public class GetRollupCapsRequest implements Validatable, ToXContentObject {
private static final String ID = "id";
public class GetRollupCapsRequest implements Validatable {

private final String indexPattern;

public GetRollupCapsRequest(final String indexPattern) {
Expand All @@ -43,14 +40,6 @@ public String getIndexPattern() {
return indexPattern;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ID, indexPattern);
builder.endObject();
return builder;
}

@Override
public int hashCode() {
return Objects.hash(indexPattern);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,11 @@
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.Validatable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;

public class GetRollupIndexCapsRequest implements Validatable, ToXContentObject {
private static final String INDICES = "indices";
private static final String INDICES_OPTIONS = "indices_options";
public class GetRollupIndexCapsRequest implements Validatable {

private String[] indices;
private IndicesOptions options;
Expand Down Expand Up @@ -60,21 +55,6 @@ public String[] indices() {
return indices;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
{
builder.array(INDICES, indices);
builder.startObject(INDICES_OPTIONS);
{
options.toXContent(builder, params);
}
builder.endObject();
}
builder.endObject();
return builder;
}

@Override
public int hashCode() {
return Objects.hash(Arrays.hashCode(indices), options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,39 +18,12 @@
*/
package org.elasticsearch.client.rollup;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.junit.Before;
import org.elasticsearch.test.ESTestCase;

import java.io.IOException;

public class DeleteRollupJobRequestTests extends AbstractXContentTestCase<DeleteRollupJobRequest> {

private String jobId;

@Before
public void setUpOptionalId() {
jobId = randomAlphaOfLengthBetween(1, 10);
}

@Override
protected DeleteRollupJobRequest createTestInstance() {
return new DeleteRollupJobRequest(jobId);
}

@Override
protected DeleteRollupJobRequest doParseInstance(final XContentParser parser) throws IOException {
return DeleteRollupJobRequest.fromXContent(parser);
}

@Override
protected boolean supportsUnknownFields() {
return false;
}
public class DeleteRollupJobRequestTests extends ESTestCase {

public void testRequireConfiguration() {
final NullPointerException e = expectThrows(NullPointerException.class, ()-> new DeleteRollupJobRequest(null));
assertEquals("id parameter must not be null", e.getMessage());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ public void dispatchBadRequest(final RestChannel channel, final ThreadContext th
}

private void dispatchRequest(RestRequest request, RestChannel channel, RestHandler handler) throws Exception {
final int contentLength = request.content().length();
final int contentLength = request.contentLength();
if (contentLength > 0) {
final XContentType xContentType = request.getXContentType();
if (xContentType == null) {
Expand Down
10 changes: 5 additions & 5 deletions server/src/main/java/org/elasticsearch/rest/RestRequest.java
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,15 @@ public final String path() {
}

public boolean hasContent() {
return content(false).length() > 0;
return contentLength() > 0;
}

public BytesReference content() {
return content(true);
public int contentLength() {
return httpRequest.content().length();
}

protected BytesReference content(final boolean contentConsumed) {
this.contentConsumed = this.contentConsumed | contentConsumed;
public BytesReference content() {
this.contentConsumed = true;
return httpRequest.content();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,29 @@

package org.elasticsearch.action.admin.indices.forcemerge;

import org.elasticsearch.client.node.NodeClient;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.rest.AbstractRestChannel;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.rest.action.admin.indices.RestForceMergeAction;
import org.elasticsearch.test.rest.FakeRestChannel;
import org.elasticsearch.test.rest.FakeRestRequest;
import org.elasticsearch.test.rest.RestActionTestCase;
import org.junit.Before;

import java.util.HashMap;
import java.util.Map;

import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;

public class RestForceMergeActionTests extends RestActionTestCase {

Expand All @@ -44,16 +50,27 @@ public void setUpAction() {
controller().registerHandler(new RestForceMergeAction());
}


public void testBodyRejection() throws Exception {
final RestForceMergeAction handler = new RestForceMergeAction();
String json = JsonXContent.contentBuilder().startObject().field("max_num_segments", 1).endObject().toString();
final FakeRestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
.withContent(new BytesArray(json), XContentType.JSON)
.withMethod(RestRequest.Method.POST)
.withPath("/_forcemerge")
.build();
IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
() -> handler.handleRequest(request, new FakeRestChannel(request, randomBoolean(), 1), mock(NodeClient.class)));
assertThat(e.getMessage(), equalTo("request [GET /_forcemerge] does not support having a body"));

final SetOnce<RestResponse> responseSetOnce = new SetOnce<>();
dispatchRequest(request, new AbstractRestChannel(request, true) {
@Override
public void sendResponse(RestResponse response) {
responseSetOnce.set(response);
}
});

final RestResponse response = responseSetOnce.get();
assertThat(response, notNullValue());
assertThat(response.status(), is(RestStatus.BAD_REQUEST));
assertThat(response.content().utf8ToString(), containsString("request [POST /_forcemerge] does not support having a body"));
}

public void testDeprecationMessage() {
Expand All @@ -75,4 +92,9 @@ public void testDeprecationMessage() {
assertWarnings("setting only_expunge_deletes and max_num_segments at the same time is deprecated " +
"and will be rejected in a future version");
}

protected void dispatchRequest(final RestRequest request, final RestChannel channel) {
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
controller().dispatchRequest(request, channel, threadContext);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.yaml.YamlXContent;
import org.elasticsearch.core.internal.io.IOUtils;
Expand Down Expand Up @@ -482,6 +483,38 @@ public void testDispatchBadRequest() {
assertThat(channel.getRestResponse().content().utf8ToString(), containsString("bad request"));
}

public void testDoesNotConsumeContent() throws Exception {
final RestRequest.Method method = randomFrom(RestRequest.Method.values());
restController.registerHandler(method, "/notconsumed", new RestHandler() {
@Override
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception {
channel.sendResponse(new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY));
}

@Override
public boolean canTripCircuitBreaker() {
return false;
}
});

final XContentBuilder content = XContentBuilder.builder(randomFrom(XContentType.values()).xContent())
.startObject().field("field", "value").endObject();
final FakeRestRequest restRequest = new FakeRestRequest.Builder(xContentRegistry())
.withPath("/notconsumed")
.withMethod(method)
.withContent(BytesReference.bytes(content), content.contentType())
.build();

final AssertingChannel channel = new AssertingChannel(restRequest, true, RestStatus.OK);
assertFalse(channel.getSendResponseCalled());
assertFalse(restRequest.isContentConsumed());

restController.dispatchRequest(restRequest, channel, new ThreadContext(Settings.EMPTY));

assertTrue(channel.getSendResponseCalled());
assertFalse("RestController must not consume request content", restRequest.isContentConsumed());
}

public void testDispatchBadRequestUnknownCause() {
final FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build();
final AssertingChannel channel = new AssertingChannel(fakeRestRequest, true, RestStatus.BAD_REQUEST);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ public void testHasContentDoesNotConsumesContent() {
runConsumesContentTest(RestRequest::hasContent, false);
}

public void testContentLengthDoesNotConsumesContent() {
runConsumesContentTest(RestRequest::contentLength, false);
}

private <T extends Exception> void runConsumesContentTest(
final CheckedConsumer<RestRequest, T> consumer, final boolean expected) {
final HttpRequest httpRequest = mock(HttpRequest.class);
Expand Down
Loading

0 comments on commit 7096856

Please sign in to comment.