Add optional Arrow deserialization support #2632

Merged: 1 commit, Aug 7, 2024
2 changes: 1 addition & 1 deletion docs/guide/configuration.asciidoc
@@ -359,7 +359,7 @@ The calculation is equal to `min(dead_node_backoff_factor * (2 ** (consecutive_f
[[serializer]]
=== Serializers

-Serializers transform bytes on the wire into native Python objects and vice-versa. By default the client ships with serializers for `application/json`, `application/x-ndjson`, `text/*`, and `application/mapbox-vector-tile`.
+Serializers transform bytes on the wire into native Python objects and vice-versa. By default the client ships with serializers for `application/json`, `application/x-ndjson`, `text/*`, `application/vnd.apache.arrow.stream` and `application/mapbox-vector-tile`.

You can define custom serializers via the `serializers` parameter:

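For illustration only (not the snippet from the file itself), a minimal sketch of such a registration, assuming the `serializers` constructor argument accepts a mapping of mimetype to serializer instance as this section describes; the class name and URL are placeholders:

from elasticsearch import Elasticsearch
from elasticsearch.serializer import JsonSerializer


# Hypothetical serializer: reuse the stock JSON handling under a vendor mimetype.
class VendorJsonSerializer(JsonSerializer):
    mimetype = "application/x-vendor-json"


client = Elasticsearch(
    "http://localhost:9200",
    serializers={VendorJsonSerializer.mimetype: VendorJsonSerializer()},
)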
34 changes: 34 additions & 0 deletions elasticsearch/serializer.py
@@ -49,6 +49,14 @@
    _OrjsonSerializer = None  # type: ignore[assignment,misc]


try:
    import pyarrow as pa

    __all__.append("PyArrowSerializer")
except ImportError:
    pa = None


class JsonSerializer(_JsonSerializer):
    mimetype: ClassVar[str] = "application/json"

@@ -114,6 +122,29 @@ def dumps(self, data: bytes) -> bytes:
        raise SerializationError(f"Cannot serialize {data!r} into a MapBox vector tile")


if pa is not None:

    class PyArrowSerializer(Serializer):
        """PyArrow serializer for deserializing Arrow Stream data."""

        mimetype: ClassVar[str] = "application/vnd.apache.arrow.stream"

        def loads(self, data: bytes) -> pa.Table:
            try:
                with pa.ipc.open_stream(data) as reader:
                    return reader.read_all()
            except pa.ArrowException as e:
                raise SerializationError(
                    message=f"Unable to deserialize as Arrow stream: {data!r}",
                    errors=(e,),
                )

        def dumps(self, data: Any) -> bytes:
            raise SerializationError(
                message="Elasticsearch does not accept Arrow input data"
            )


DEFAULT_SERIALIZERS: Dict[str, Serializer] = {
    JsonSerializer.mimetype: JsonSerializer(),
    MapboxVectorTileSerializer.mimetype: MapboxVectorTileSerializer(),
@@ -122,6 +153,9 @@ def dumps(self, data: bytes) -> bytes:
    CompatibilityModeNdjsonSerializer.mimetype: CompatibilityModeNdjsonSerializer(),
}

if pa is not None:
    DEFAULT_SERIALIZERS[PyArrowSerializer.mimetype] = PyArrowSerializer()

# Alias for backwards compatibility
JSONSerializer = JsonSerializer

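To see the registration above in action, a minimal sketch (the record batch is synthetic stand-in data; `to_pydict()` and `to_pandas()` are pyarrow's own APIs, the latter requiring pandas): with pyarrow installed, the Arrow stream mimetype resolves to `PyArrowSerializer`, whose `loads()` returns a `pyarrow.Table`.

import pyarrow as pa

from elasticsearch.serializer import DEFAULT_SERIALIZERS

# With pyarrow installed, the Arrow mimetype is registered alongside JSON, NDJSON, etc.
arrow_serializer = DEFAULT_SERIALIZERS["application/vnd.apache.arrow.stream"]

# Build a stand-in response body: an Arrow IPC stream containing one record batch.
batch = pa.record_batch([pa.array([1, 2, 3])], names=["count"])
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, batch.schema) as writer:
    writer.write_batch(batch)

table = arrow_serializer.loads(sink.getvalue())
print(table.to_pydict())  # {'count': [1, 2, 3]}
df = table.to_pandas()    # optional hand-off to pandas for analysis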
2 changes: 1 addition & 1 deletion noxfile.py
@@ -94,7 +94,7 @@ def lint(session):
    session.run("flake8", *SOURCE_FILES)
    session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)

-    session.install(".[async,requests,orjson,vectorstore_mmr]", env=INSTALL_ENV)
+    session.install(".[async,requests,orjson,pyarrow,vectorstore_mmr]", env=INSTALL_ENV)

    # Run mypy on the package and then the type examples separately for
    # the two different mypy use-cases, ourselves and our users.
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -48,6 +48,7 @@ dependencies = [
async = ["aiohttp>=3,<4"]
requests = ["requests>=2.4.0, !=2.32.2, <3.0.0"]
orjson = ["orjson>=3"]
pyarrow = ["pyarrow>=1"]
# Maximal Marginal Relevance (MMR) for search results
vectorstore_mmr = ["numpy>=1", "simsimd>=3"]
dev = [
@@ -69,6 +70,7 @@ dev = [
"orjson",
"numpy",
"simsimd",
"pyarrow",
"pandas",
"mapbox-vector-tile",
]
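Because pyarrow ships as an optional extra, the serializer is only exported when the dependency can be imported. A sketch of a guard an application might use; the messages and the pip hint are illustrative:

from elasticsearch import serializer

# The module appends "PyArrowSerializer" to __all__ only when pyarrow imports cleanly.
if "PyArrowSerializer" in getattr(serializer, "__all__", []):
    print("Arrow deserialization is available")
else:
    # e.g. install the optional extra:  pip install 'elasticsearch[pyarrow]'
    print("pyarrow is not installed; Arrow responses cannot be decoded")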
2 changes: 2 additions & 0 deletions test_elasticsearch/test_client/test_deprecated_options.py
@@ -134,6 +134,7 @@ class CustomSerializer(JsonSerializer):
"application/x-ndjson",
"application/json",
"text/*",
"application/vnd.apache.arrow.stream",
"application/vnd.elasticsearch+json",
"application/vnd.elasticsearch+x-ndjson",
}
@@ -154,6 +155,7 @@ class CustomSerializer(JsonSerializer):
"application/x-ndjson",
"application/json",
"text/*",
"application/vnd.apache.arrow.stream",
"application/vnd.elasticsearch+json",
"application/vnd.elasticsearch+x-ndjson",
"application/cbor",
3 changes: 3 additions & 0 deletions test_elasticsearch/test_client/test_serializers.py
@@ -94,6 +94,7 @@ class CustomSerializer:
"application/json",
"text/*",
"application/x-ndjson",
"application/vnd.apache.arrow.stream",
"application/vnd.mapbox-vector-tile",
"application/vnd.elasticsearch+json",
"application/vnd.elasticsearch+x-ndjson",
@@ -121,6 +122,7 @@ class CustomSerializer:
"application/json",
"text/*",
"application/x-ndjson",
"application/vnd.apache.arrow.stream",
"application/vnd.mapbox-vector-tile",
"application/vnd.elasticsearch+json",
"application/vnd.elasticsearch+x-ndjson",
@@ -140,6 +142,7 @@ class CustomSerializer:
"application/json",
"text/*",
"application/x-ndjson",
"application/vnd.apache.arrow.stream",
"application/vnd.mapbox-vector-tile",
"application/vnd.elasticsearch+json",
"application/vnd.elasticsearch+x-ndjson",
27 changes: 26 additions & 1 deletion test_elasticsearch/test_serializer.py
@@ -19,6 +19,7 @@
from datetime import datetime
from decimal import Decimal

import pyarrow as pa
import pytest

try:
@@ -31,7 +32,12 @@

from elasticsearch import Elasticsearch
from elasticsearch.exceptions import SerializationError
-from elasticsearch.serializer import JSONSerializer, OrjsonSerializer, TextSerializer
+from elasticsearch.serializer import (
+    JSONSerializer,
+    OrjsonSerializer,
+    PyArrowSerializer,
+    TextSerializer,
+)

requires_numpy_and_pandas = pytest.mark.skipif(
    np is None or pd is None, reason="Test requires numpy and pandas to be available"
@@ -157,6 +163,25 @@ def test_serializes_pandas_category(json_serializer):
    assert b'{"d":[1,2,3]}' == json_serializer.dumps({"d": cat})


def test_pyarrow_loads():
    data = [
        pa.array([1, 2, 3, 4]),
        pa.array(["foo", "bar", "baz", None]),
        pa.array([True, None, False, True]),
    ]
    batch = pa.record_batch(data, names=["f0", "f1", "f2"])
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, batch.schema) as writer:
        writer.write_batch(batch)

    serializer = PyArrowSerializer()
    assert serializer.loads(sink.getvalue()).to_pydict() == {
        "f0": [1, 2, 3, 4],
        "f1": ["foo", "bar", "baz", None],
        "f2": [True, None, False, True],
    }


def test_json_raises_serialization_error_on_dump_error(json_serializer):
    with pytest.raises(SerializationError):
        json_serializer.dumps(object())
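A companion check one could add next to `test_pyarrow_loads`, sketched here only to mirror the `dumps()` behavior introduced in `elasticsearch/serializer.py` above (the test name is hypothetical):

import pytest

from elasticsearch.exceptions import SerializationError
from elasticsearch.serializer import PyArrowSerializer


def test_pyarrow_dumps_raises():
    # Elasticsearch never accepts Arrow input, so serialization always fails.
    with pytest.raises(SerializationError):
        PyArrowSerializer().dumps({"any": "payload"})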