Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-44066: [Python] Add Python wrapper for JsonExtensionType #44070

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions cpp/src/arrow/extension/json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,13 @@
namespace arrow::extension {

bool JsonExtensionType::ExtensionEquals(const ExtensionType& other) const {
return other.extension_name() == this->extension_name();
return other.extension_name() == this->extension_name() &&
other.storage_type()->Equals(storage_type_);
}

Result<std::shared_ptr<DataType>> JsonExtensionType::Deserialize(
std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
if (storage_type->id() != Type::STRING && storage_type->id() != Type::STRING_VIEW &&
storage_type->id() != Type::LARGE_STRING) {
return Status::Invalid("Invalid storage type for JsonExtensionType: ",
storage_type->ToString());
}
return std::make_shared<JsonExtensionType>(storage_type);
return JsonExtensionType::Make(std::move(storage_type));
}

std::string JsonExtensionType::Serialize() const { return ""; }
Expand All @@ -51,11 +47,22 @@ std::shared_ptr<Array> JsonExtensionType::MakeArray(
return std::make_shared<ExtensionArray>(data);
}

std::shared_ptr<DataType> json(const std::shared_ptr<DataType> storage_type) {
ARROW_CHECK(storage_type->id() != Type::STRING ||
storage_type->id() != Type::STRING_VIEW ||
storage_type->id() != Type::LARGE_STRING);
return std::make_shared<JsonExtensionType>(storage_type);
bool JsonExtensionType::IsSupportedStorageType(Type::type type_id) {
  // The JSON extension type may only wrap one of the UTF-8 string storage
  // variants; every other type id is rejected.
  switch (type_id) {
    case Type::STRING:
    case Type::STRING_VIEW:
    case Type::LARGE_STRING:
      return true;
    default:
      return false;
  }
}

Result<std::shared_ptr<DataType>> JsonExtensionType::Make(
    std::shared_ptr<DataType> storage_type) {
  // Validate up front so a JsonExtensionType can never be constructed with a
  // non-string storage type; on success, ownership of storage_type is moved
  // into the new extension type.
  if (IsSupportedStorageType(storage_type->id())) {
    return std::make_shared<JsonExtensionType>(std::move(storage_type));
  }
  return Status::Invalid("Invalid storage type for JsonExtensionType: ",
                         storage_type->ToString());
}

std::shared_ptr<DataType> json(std::shared_ptr<DataType> storage_type) {
  // Convenience factory: delegates validation to Make() and aborts
  // (ValueOrDie) if the caller supplied an unsupported storage type.
  auto maybe_type = JsonExtensionType::Make(std::move(storage_type));
  return maybe_type.ValueOrDie();
}

} // namespace arrow::extension
4 changes: 4 additions & 0 deletions cpp/src/arrow/extension/json.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class ARROW_EXPORT JsonExtensionType : public ExtensionType {

std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;

static Result<std::shared_ptr<DataType>> Make(std::shared_ptr<DataType> storage_type);

static bool IsSupportedStorageType(Type::type type_id);

private:
std::shared_ptr<DataType> storage_type_;
};
Expand Down
14 changes: 14 additions & 0 deletions cpp/src/arrow/extension/json_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,18 @@ TEST_F(TestJsonExtensionType, InvalidUTF8) {
}
}

TEST_F(TestJsonExtensionType, StorageTypeValidation) {
// Equality of json() extension types takes the storage type into account:
// identical storage compares equal, differing storage compares unequal.
ASSERT_TRUE(json(utf8())->Equals(json(utf8())));
ASSERT_FALSE(json(large_utf8())->Equals(json(utf8())));
ASSERT_FALSE(json(utf8_view())->Equals(json(utf8())));
ASSERT_FALSE(json(utf8_view())->Equals(json(large_utf8())));

// Non-string storage types must be rejected by JsonExtensionType::Make with
// a descriptive Invalid status naming the offending type.
for (const auto& storage_type : {int16(), binary(), float64(), null()}) {
ASSERT_RAISES_WITH_MESSAGE(Invalid,
"Invalid: Invalid storage type for JsonExtensionType: " +
storage_type->ToString(),
extension::JsonExtensionType::Make(storage_type));
}
}

} // namespace arrow
26 changes: 19 additions & 7 deletions cpp/src/parquet/arrow/arrow_schema_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -757,23 +757,35 @@ TEST_F(TestConvertParquetSchema, ParquetSchemaArrowExtensions) {

{
// Parquet file does not contain Arrow schema.
// If Arrow extensions are enabled, both fields should be treated as json() extension
// fields.
// If Arrow extensions are enabled, fields will be interpreted as json(utf8())
// extension fields.
ArrowReaderProperties props;
props.set_arrow_extensions_enabled(true);
auto arrow_schema = ::arrow::schema(
{::arrow::field("json_1", ::arrow::extension::json(), true),
::arrow::field("json_2", ::arrow::extension::json(::arrow::large_utf8()),
true)});
::arrow::field("json_2", ::arrow::extension::json(::arrow::utf8()), true)});
std::shared_ptr<KeyValueMetadata> metadata{};
ASSERT_OK(ConvertSchema(parquet_fields, metadata, props));
CheckFlatSchema(arrow_schema);

// If original data was e.g. json(large_utf8()) it will be interpreted as json(utf8())
// in absence of Arrow schema.
arrow_schema = ::arrow::schema(
{::arrow::field("json_1", ::arrow::extension::json(), true),
::arrow::field("json_2", ::arrow::extension::json(::arrow::large_utf8()),
true)});
metadata = std::shared_ptr<KeyValueMetadata>{};
ASSERT_OK(ConvertSchema(parquet_fields, metadata, props));
EXPECT_TRUE(result_schema_->field(1)->type()->Equals(
::arrow::extension::json(::arrow::utf8())));
EXPECT_FALSE(
result_schema_->field(1)->type()->Equals(arrow_schema->field(1)->type()));
}

{
// Parquet file contains Arrow schema.
// Both json_1 and json_2 should be returned as a json() field
// even though extensions are not enabled.
// json_1 and json_2 will be interpreted as json(utf8()) and json(large_utf8())
// fields even though extensions are not enabled.
ArrowReaderProperties props;
props.set_arrow_extensions_enabled(false);
std::shared_ptr<KeyValueMetadata> field_metadata =
Expand All @@ -791,7 +803,7 @@ TEST_F(TestConvertParquetSchema, ParquetSchemaArrowExtensions) {

{
// Parquet file contains Arrow schema. Extensions are enabled.
// Both json_1 and json_2 should be returned as a json() field
// json_1 and json_2 will be interpreted as json(utf8()) and json(large_utf8()).
ArrowReaderProperties props;
props.set_arrow_extensions_enabled(true);
std::shared_ptr<KeyValueMetadata> field_metadata =
Expand Down
10 changes: 6 additions & 4 deletions cpp/src/parquet/arrow/schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -997,9 +997,8 @@ Result<bool> ApplyOriginalMetadata(const Field& origin_field, SchemaField* infer
const auto& ex_type = checked_cast<const ::arrow::ExtensionType&>(*origin_type);
if (inferred_type->id() != ::arrow::Type::EXTENSION &&
ex_type.extension_name() == std::string("arrow.json") &&
(inferred_type->id() == ::arrow::Type::STRING ||
inferred_type->id() == ::arrow::Type::LARGE_STRING ||
inferred_type->id() == ::arrow::Type::STRING_VIEW)) {
::arrow::extension::JsonExtensionType::IsSupportedStorageType(
inferred_type->id())) {
// Schema mismatch.
//
// Arrow extensions are DISABLED in Parquet.
Expand All @@ -1017,7 +1016,10 @@ Result<bool> ApplyOriginalMetadata(const Field& origin_field, SchemaField* infer

// Restore extension type, if the storage type is the same as inferred
// from the Parquet type
if (ex_type.storage_type()->Equals(*inferred->field->type())) {
if (ex_type.storage_type()->Equals(*inferred->field->type()) ||
((ex_type.extension_name() == "arrow.json") &&
::arrow::extension::JsonExtensionType::IsSupportedStorageType(
inferred->field->type()->storage_id()))) {
inferred->field = inferred->field->WithType(origin_type);
}
}
Expand Down
8 changes: 4 additions & 4 deletions python/pyarrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def print_entry(label, value):
union, sparse_union, dense_union,
dictionary,
run_end_encoded,
bool8, fixed_shape_tensor, opaque, uuid,
bool8, fixed_shape_tensor, json_, opaque, uuid,
field,
type_for_alias,
DataType, DictionaryType, StructType,
Expand All @@ -183,7 +183,7 @@ def print_entry(label, value):
FixedSizeBinaryType, Decimal128Type, Decimal256Type,
BaseExtensionType, ExtensionType,
RunEndEncodedType, Bool8Type, FixedShapeTensorType,
OpaqueType, UuidType,
JsonType, OpaqueType, UuidType,
PyExtensionType, UnknownExtensionType,
register_extension_type, unregister_extension_type,
DictionaryMemo,
Expand Down Expand Up @@ -218,7 +218,7 @@ def print_entry(label, value):
MonthDayNanoIntervalArray,
Decimal128Array, Decimal256Array, StructArray, ExtensionArray,
RunEndEncodedArray, Bool8Array, FixedShapeTensorArray,
OpaqueArray, UuidArray,
JsonArray, OpaqueArray, UuidArray,
scalar, NA, _NULL as NULL, Scalar,
NullScalar, BooleanScalar,
Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar,
Expand All @@ -236,7 +236,7 @@ def print_entry(label, value):
FixedSizeBinaryScalar, DictionaryScalar,
MapScalar, StructScalar, UnionScalar,
RunEndEncodedScalar, Bool8Scalar, ExtensionScalar,
FixedShapeTensorScalar, OpaqueScalar, UuidScalar)
FixedShapeTensorScalar, JsonScalar, OpaqueScalar, UuidScalar)

# Buffers, allocation
from pyarrow.lib import (DeviceAllocationType, Device, MemoryManager,
Expand Down
27 changes: 27 additions & 0 deletions python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -4343,6 +4343,33 @@ cdef class ExtensionArray(Array):
return result


class JsonArray(ExtensionArray):
"""
Concrete class for Arrow arrays of JSON data type.

This does not guarantee that the JSON data actually
is valid JSON.

jorisvandenbossche marked this conversation as resolved.
Show resolved Hide resolved
Examples
--------
Define the extension type for JSON array

>>> import pyarrow as pa
>>> json_type = pa.json_(pa.large_utf8())

Create an extension array

>>> arr = [None, '{ "id":30, "values":["a", "b"] }']
>>> storage = pa.array(arr, pa.large_utf8())
>>> pa.ExtensionArray.from_storage(json_type, storage)
<pyarrow.lib.JsonArray object at ...>
[
null,
"{ "id":30, "values":["a", "b"] }"
]
"""


class UuidArray(ExtensionArray):
"""
Concrete class for Arrow arrays of UUID data type.
Expand Down
7 changes: 7 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2867,6 +2867,13 @@ cdef extern from "arrow/extension_type.h" namespace "arrow":
shared_ptr[CArray] storage()


cdef extern from "arrow/extension/json.h" namespace "arrow::extension" nogil:
cdef cppclass CJsonType" arrow::extension::JsonExtensionType"(CExtensionType):

@staticmethod
CResult[shared_ptr[CDataType]] Make(shared_ptr[CDataType]& storage_type)


cdef extern from "arrow/extension/uuid.h" namespace "arrow::extension" nogil:
cdef cppclass CUuidType" arrow::extension::UuidType"(CExtensionType):

Expand Down
5 changes: 5 additions & 0 deletions python/pyarrow/lib.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,11 @@ cdef class UuidType(BaseExtensionType):
cdef:
const CUuidType* uuid_ext_type

cdef class JsonType(BaseExtensionType):
cdef:
const CJsonType* json_ext_type

rok marked this conversation as resolved.
Show resolved Hide resolved

cdef class PyExtensionType(ExtensionType):
pass

Expand Down
2 changes: 2 additions & 0 deletions python/pyarrow/public-api.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ cdef api object pyarrow_wrap_data_type(
out = OpaqueType.__new__(OpaqueType)
elif extension_name == b"arrow.uuid":
out = UuidType.__new__(UuidType)
elif extension_name == b"arrow.json":
out = JsonType.__new__(JsonType)
else:
out = BaseExtensionType.__new__(BaseExtensionType)
else:
Expand Down
6 changes: 6 additions & 0 deletions python/pyarrow/scalar.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1044,6 +1044,12 @@ cdef class ExtensionScalar(Scalar):
return pyarrow_wrap_scalar(<shared_ptr[CScalar]> sp_scalar)


class JsonScalar(ExtensionScalar):
    """
    Concrete class for scalars of the JSON extension type (``arrow.json``).

    The wrapped value is the underlying storage scalar (a UTF-8 string).
    No validation that the text is well-formed JSON is performed here.
    """

class UuidScalar(ExtensionScalar):
"""
Concrete class for Uuid extension scalar.
Expand Down
13 changes: 13 additions & 0 deletions python/pyarrow/tests/parquet/test_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,3 +510,16 @@ def test_large_binary_overflow():
pa.ArrowInvalid,
match="Parquet cannot store strings with size 2GB or more"):
_write_table(table, writer, use_dictionary=use_dictionary)


@pytest.mark.parametrize("storage_type", (
    pa.string(), pa.large_string()))
def test_json_extension_type(storage_type):
    # Build a json(storage_type) extension array and verify the table
    # containing it survives a Parquet write/read round trip.
    values = ['{"a": 1}', '{"b": 2}', None]
    ext_type = pa.json_(storage_type)
    ext_array = pa.ExtensionArray.from_storage(
        ext_type, pa.array(values, type=storage_type))
    _simple_table_roundtrip(pa.table([ext_array], names=["ext"]))
54 changes: 54 additions & 0 deletions python/pyarrow/tests/test_extension_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1926,3 +1926,57 @@ def test_bool8_scalar():
assert pa.scalar(1, type=pa.bool8()).as_py() is True
assert pa.scalar(2, type=pa.bool8()).as_py() is True
assert pa.scalar(None, type=pa.bool8()).as_py() is None


@pytest.mark.parametrize("storage_type", (
pa.string(), pa.large_string(), pa.string_view()))
def test_json(storage_type, pickle_module):
data = ['{"a": 1}', '{"b": 2}', None]
storage = pa.array(data, type=storage_type)
json_type = pa.json_(storage_type)
json_arr_class = json_type.__arrow_ext_class__()

assert pa.json_() == pa.json_(pa.utf8())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this means that the storage type is not taken into account for checking equality of the JSON extension type? Should it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant to test that pa.json_() sets storage to pa.utf8() by default and I'd keep this test. However as noted in #13901 (review) disregarding storage type here is not desirable.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant to test that pa.json_() sets storage to pa.utf8() by default and I'd keep this test.

Ah, yes of course, that's good to test

assert json_type.extension_name == "arrow.json"
assert json_type.storage_type == storage_type
assert json_type.__class__ is pa.JsonType
jorisvandenbossche marked this conversation as resolved.
Show resolved Hide resolved

assert json_type == pa.json_(storage_type)
assert json_type != storage_type

array = pa.ExtensionArray.from_storage(json_type, storage)
assert isinstance(array, pa.JsonArray)

assert array.to_pylist() == data
assert array[0].as_py() == data[0]
assert array[2].as_py() is None

# Pickle roundtrip
result = pickle_module.loads(pickle_module.dumps(json_type))
assert result == json_type

# IPC roundtrip
buf = ipc_write_batch(pa.RecordBatch.from_arrays([array], ["ext"]))
batch = ipc_read_batch(buf)
reconstructed_array = batch.column(0)
assert reconstructed_array.type == json_type
assert reconstructed_array == array
assert isinstance(array, json_arr_class)

assert json_type.__arrow_ext_scalar_class__() == pa.JsonScalar
assert isinstance(array[0], pa.JsonScalar)

# cast storage -> extension type
result = storage.cast(json_type)
assert result == array

# cast extension type -> storage type
inner = array.cast(storage_type)
assert inner == storage
rok marked this conversation as resolved.
Show resolved Hide resolved

for storage_type in (pa.int32(), pa.large_binary(), pa.float32()):
with pytest.raises(
pa.ArrowInvalid,
match="Invalid storage type for JsonExtensionType: " +
str(storage_type)):
pa.json_(storage_type)
3 changes: 3 additions & 0 deletions python/pyarrow/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,9 @@ def test_set_timezone_db_path_non_windows():
pa.Bool8Array,
pa.Bool8Scalar,
pa.Bool8Type,
pa.JsonArray,
pa.JsonScalar,
pa.JsonType,
])
def test_extension_type_constructor_errors(klass):
# ARROW-2638: prevent calling extension class constructors directly
Expand Down
Loading
Loading