Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQL: forward warning headers to JDBC driver #84499

Merged
merged 11 commits into from
Mar 15, 2022
5 changes: 5 additions & 0 deletions docs/changelog/84499.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 84499
summary: Forward warning headers to JDBC driver
area: SQL
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,6 @@ default int columnSize() {
int batchSize();

void close() throws SQLException;

List<String> warnings();
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,24 @@ class DefaultCursor implements Cursor {

private final List<JdbcColumnInfo> columnInfos;
private List<List<Object>> rows;
private final List<String> warnings;
private int row = -1;
private String cursor;

DefaultCursor(JdbcHttpClient client, String cursor, List<JdbcColumnInfo> columnInfos, List<List<Object>> rows, RequestMeta meta) {
DefaultCursor(
JdbcHttpClient client,
String cursor,
List<JdbcColumnInfo> columnInfos,
List<List<Object>> rows,
RequestMeta meta,
List<String> warnings
) {
this.client = client;
this.meta = meta;
this.cursor = cursor;
this.columnInfos = columnInfos;
this.rows = rows;
this.warnings = warnings;
}

@Override
Expand Down Expand Up @@ -67,4 +76,8 @@ public void close() throws SQLException {
client.queryClose(cursor);
}
}

public List<String> warnings() {
return warnings;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.sql.RowIdLifetime;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static java.sql.JDBCType.BIGINT;
Expand Down Expand Up @@ -1420,5 +1421,10 @@ public int batchSize() {
public void close() throws SQLException {
// this cursor doesn't hold any resource - no need to clean up
}

@Override
public List<String> warnings() {
return Collections.emptyList();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,15 @@ Cursor query(String sql, List<SqlTypedParamValue> params, RequestMeta meta) thro
conCfg.indexIncludeFrozen(),
conCfg.binaryCommunication()
);
SqlQueryResponse response = httpClient.query(sqlRequest);
return new DefaultCursor(this, response.cursor(), toJdbcColumnInfo(response.columns()), response.rows(), meta);
Tuple<SqlQueryResponse, List<String>> response = httpClient.query(sqlRequest);
return new DefaultCursor(
this,
response.v1().cursor(),
toJdbcColumnInfo(response.v1().columns()),
response.v1().rows(),
meta,
response.v2()
);
}

/**
Expand All @@ -91,7 +98,7 @@ Tuple<String, List<List<Object>>> nextPage(String cursor, RequestMeta meta) thro
new RequestInfo(Mode.JDBC),
conCfg.binaryCommunication()
);
SqlQueryResponse response = httpClient.query(sqlRequest);
SqlQueryResponse response = httpClient.query(sqlRequest).v1();
return new Tuple<>(response.cursor(), response.rows());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,15 @@ public InputStream getBinaryStream(String columnLabel) throws SQLException {
@Override
public SQLWarning getWarnings() throws SQLException {
checkOpen();
return null;
SQLWarning sqlWarning = null;
for (String warning : cursor.warnings()) {
if (sqlWarning == null) {
sqlWarning = new SQLWarning(warning);
} else {
sqlWarning.setNextWarning(new SQLWarning(warning));
}
}
return sqlWarning;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,73 @@

package org.elasticsearch.xpack.sql.qa.jdbc;

import org.elasticsearch.Version;
import org.junit.Before;

import java.io.IOException;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.SQLWarning;
import java.sql.Statement;
import java.util.LinkedList;
import java.util.List;
import java.util.Properties;

import static org.elasticsearch.xpack.sql.qa.jdbc.JdbcTestUtils.JDBC_DRIVER_VERSION;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.containsString;

public abstract class JdbcWarningsTestCase extends JdbcIntegrationTestCase {

public void testDeprecationWarningsDoNotReachJdbcDriver() throws Exception {
private static Version WARNING_HANDLING_ADDED_VERSION = Version.V_8_2_0;

@Before
public void setupData() throws IOException {
index("test_data", b -> b.field("foo", 1));
}

public void testNoWarnings() throws SQLException {
try (Connection connection = esJdbc(); Statement statement = connection.createStatement()) {
ResultSet rs = statement.executeQuery("SELECT * FROM FROZEN \"test_*\"");
ResultSet rs = statement.executeQuery("SELECT * FROM test_data");
assertNull(rs.getWarnings());
}
}

public void testSingleDeprecationWarning() throws SQLException {
assumeTrue("Driver does not yet handle deprecation warnings", JDBC_DRIVER_VERSION.onOrAfter(WARNING_HANDLING_ADDED_VERSION));

try (Connection connection = esJdbc(); Statement statement = connection.createStatement()) {
ResultSet rs = statement.executeQuery("SELECT * FROM FROZEN test_data");
SQLWarning warning = rs.getWarnings();
assertThat(warning.getMessage(), containsString("[FROZEN] syntax is deprecated because frozen indices have been deprecated."));
assertNull(warning.getNextWarning());
}
}

public void testMultipleDeprecationWarnings() throws SQLException {
assumeTrue("Driver does not yet handle deprecation warnings", JDBC_DRIVER_VERSION.onOrAfter(WARNING_HANDLING_ADDED_VERSION));

Properties props = connectionProperties();
props.setProperty("index.include.frozen", "true");

try (Connection connection = esJdbc(props); Statement statement = connection.createStatement()) {
ResultSet rs = statement.executeQuery("SELECT * FROM FROZEN test_data");
List<String> warnings = new LinkedList<>();
SQLWarning warning = rs.getWarnings();
while (warning != null) {
warnings.add(warning.getMessage());
warning = warning.getNextWarning();
}

assertThat(
warnings,
containsInAnyOrder(
containsString("[FROZEN] syntax is deprecated because frozen indices have been deprecated."),
containsString("[index_include_frozen] parameter is deprecated because frozen indices have been deprecated.")
)
);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import java.io.InputStream;
import java.security.PrivilegedAction;
import java.sql.SQLException;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;

import static java.util.Collections.emptyList;
Expand Down Expand Up @@ -82,10 +84,10 @@ public SqlQueryResponse basicQuery(String query, int fetchSize, boolean fieldMul
false,
cfg.binaryCommunication()
);
return query(sqlRequest);
return query(sqlRequest).v1();
}

public SqlQueryResponse query(SqlQueryRequest sqlRequest) throws SQLException {
public Tuple<SqlQueryResponse, List<String>> query(SqlQueryRequest sqlRequest) throws SQLException {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not too fond of working with tuples everywhere I have to pass the warnings but this was the least intrusive way to get the warnings from the response headers. Since it's an internal client it probably doesn't matter too much but suggestions for solving this differently are welcome.

return post(CoreProtocol.SQL_QUERY_REST_ENDPOINT, sqlRequest, Payloads::parseQueryResponse);
}

Expand All @@ -98,28 +100,28 @@ public SqlQueryResponse nextPage(String cursor) throws SQLException {
new RequestInfo(Mode.CLI),
cfg.binaryCommunication()
);
return post(CoreProtocol.SQL_QUERY_REST_ENDPOINT, sqlRequest, Payloads::parseQueryResponse);
return post(CoreProtocol.SQL_QUERY_REST_ENDPOINT, sqlRequest, Payloads::parseQueryResponse).v1();
}

public boolean queryClose(String cursor, Mode mode) throws SQLException {
SqlClearCursorResponse response = post(
Tuple<SqlClearCursorResponse, List<String>> response = post(
CoreProtocol.CLEAR_CURSOR_REST_ENDPOINT,
new SqlClearCursorRequest(cursor, new RequestInfo(mode)),
Payloads::parseClearCursorResponse
);
return response.isSucceeded();
return response.v1().isSucceeded();
}

@SuppressWarnings({ "removal" })
private <Request extends AbstractSqlRequest, Response> Response post(
private <Request extends AbstractSqlRequest, Response> Tuple<Response, List<String>> post(
String path,
Request request,
CheckedFunction<JsonParser, Response, IOException> responseParser
) throws SQLException {
byte[] requestBytes = toContent(request);
String query = "error_trace";
Tuple<ContentType, byte[]> response = java.security.AccessController.doPrivileged(
(PrivilegedAction<ResponseOrException<Tuple<ContentType, byte[]>>>) () -> JreHttpUrlConnection.http(
Tuple<Function<String, List<String>>, byte[]> response = java.security.AccessController.doPrivileged(
(PrivilegedAction<ResponseOrException<Tuple<Function<String, List<String>>, byte[]>>>) () -> JreHttpUrlConnection.http(
path,
query,
cfg,
Expand All @@ -131,7 +133,11 @@ private <Request extends AbstractSqlRequest, Response> Response post(
)
)
).getResponseOrThrowException();
return fromContent(response.v1(), response.v2(), responseParser);
List<String> warnings = response.v1().apply("Warning");
return new Tuple<>(
fromContent(contentType(response.v1()), response.v2(), responseParser),
warnings == null ? Collections.emptyList() : warnings
);
}

@SuppressWarnings({ "removal" })
Expand Down Expand Up @@ -162,15 +168,15 @@ private boolean head(String path, long timeoutInMs) throws SQLException {

@SuppressWarnings({ "removal" })
private <Response> Response get(String path, CheckedFunction<JsonParser, Response, IOException> responseParser) throws SQLException {
Tuple<ContentType, byte[]> response = java.security.AccessController.doPrivileged(
(PrivilegedAction<ResponseOrException<Tuple<ContentType, byte[]>>>) () -> JreHttpUrlConnection.http(
Tuple<Function<String, List<String>>, byte[]> response = java.security.AccessController.doPrivileged(
(PrivilegedAction<ResponseOrException<Tuple<Function<String, List<String>>, byte[]>>>) () -> JreHttpUrlConnection.http(
path,
"error_trace",
cfg,
con -> con.request(null, this::readFrom, "GET")
)
).getResponseOrThrowException();
return fromContent(response.v1(), response.v2(), responseParser);
return fromContent(contentType(response.v1()), response.v2(), responseParser);
}

private <Request extends AbstractSqlRequest> byte[] toContent(Request request) {
Expand All @@ -184,31 +190,35 @@ private <Request extends AbstractSqlRequest> byte[] toContent(Request request) {
}
}

private Tuple<ContentType, byte[]> readFrom(InputStream inputStream, Function<String, String> headers) {
String contentType = headers.apply("Content-Type");
ContentType type = ContentFactory.parseMediaType(contentType);
if (type == null) {
throw new IllegalStateException("Unsupported Content-Type: " + contentType);
}
private Tuple<Function<String, List<String>>, byte[]> readFrom(InputStream inputStream, Function<String, List<String>> headers) {
ByteArrayOutputStream out = new ByteArrayOutputStream();
try {
Streams.copy(inputStream, out);
} catch (Exception ex) {
throw new ClientException("Cannot deserialize response", ex);
}
return new Tuple<>(type, out.toByteArray());
return new Tuple<>(headers, out.toByteArray());

}

private ContentType contentType(Function<String, List<String>> headers) {
List<String> contentTypeHeaders = headers.apply("Content-Type");

String contentType = contentTypeHeaders == null || contentTypeHeaders.isEmpty() ? null : contentTypeHeaders.get(0);
ContentType type = ContentFactory.parseMediaType(contentType);
if (type == null) {
throw new IllegalStateException("Unsupported Content-Type: " + contentType);
} else {
return type;
}
}

private <Response> Response fromContent(
ContentType contentType,
ContentType type,
byte[] bytesReference,
CheckedFunction<JsonParser, Response, IOException> responseParser
) {
try (
InputStream stream = new ByteArrayInputStream(bytesReference);
JsonParser parser = ContentFactory.parser(contentType, stream)
) {
try (InputStream stream = new ByteArrayInputStream(bytesReference); JsonParser parser = ContentFactory.parser(type, stream)) {
return responseParser.apply(parser);
} catch (Exception ex) {
throw new ClientException("Cannot parse response", ex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLConnection;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.sql.SQLClientInfoException;
Expand All @@ -31,6 +32,8 @@
import java.sql.SQLSyntaxErrorException;
import java.sql.SQLTimeoutException;
import java.util.Base64;
import java.util.LinkedList;
import java.util.List;
import java.util.function.Function;
import java.util.zip.GZIPInputStream;

Expand Down Expand Up @@ -150,15 +153,15 @@ public boolean head() throws ClientException {

public <R> ResponseOrException<R> request(
CheckedConsumer<OutputStream, IOException> doc,
CheckedBiFunction<InputStream, Function<String, String>, R, IOException> parser,
CheckedBiFunction<InputStream, Function<String, List<String>>, R, IOException> parser,
String requestMethod
) throws ClientException {
return request(doc, parser, requestMethod, "application/json");
}

public <R> ResponseOrException<R> request(
CheckedConsumer<OutputStream, IOException> doc,
CheckedBiFunction<InputStream, Function<String, String>, R, IOException> parser,
CheckedBiFunction<InputStream, Function<String, List<String>>, R, IOException> parser,
String requestMethod,
String contentTypeHeader
) throws ClientException {
Expand All @@ -174,7 +177,7 @@ public <R> ResponseOrException<R> request(
}
if (shouldParseBody(con.getResponseCode())) {
try (InputStream stream = getStream(con, con.getInputStream())) {
return new ResponseOrException<>(parser.apply(new BufferedInputStream(stream), con::getHeaderField));
return new ResponseOrException<>(parser.apply(new BufferedInputStream(stream), getHeaderFields(con)));
}
}
return parserError();
Expand All @@ -183,6 +186,23 @@ public <R> ResponseOrException<R> request(
}
}

private Function<String, List<String>> getHeaderFields(URLConnection con) {
return header -> {
// HTTP headers are case-insensitive but the map returned by `URLConnection.getHeaderFields` is case-sensitive.
// The linear scan below replicates the linear case-insensitive lookup of `URLConnection.getHeaderField(String)`.
List<String> values = new LinkedList<>();
int i = 0;
String value = con.getHeaderField(i);
while (value != null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just noting that the function would stop iterating on an empty header value? This seems to be legal, though very corner case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope that's never the case but the API does not forbid null headers. Changing the implementation to use URLConnection.getHeaderFields() 👍

if (header.equalsIgnoreCase(con.getHeaderFieldKey(i))) {
values.add(value);
}
value = con.getHeaderField(++i);
}
return values;
};
}

private boolean shouldParseBody(int responseCode) {
return responseCode == 200 || responseCode == 201 || responseCode == 202;
}
Expand Down