Skip to content

Commit

Permalink
Pyarrow IO property for configuring large v small types on read (#986)
Browse files Browse the repository at this point in the history
* upyarrow IO property for configuring large v small types on read

* tests

* adopt feedback

* use property_as_bool

* fix

* docs

* nits

* respect flag on promotion

* lint

---------

Co-authored-by: Sung Yun <107272191+syun64@users.noreply.github.com>
  • Loading branch information
sungwy and sungwy committed Aug 7, 2024
1 parent ba85dd1 commit 8aeab49
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 11 deletions.
10 changes: 10 additions & 0 deletions mkdocs/docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,16 @@ For the FileIO there are several configuration options available:

<!-- markdown-link-check-enable-->

### PyArrow

<!-- markdown-link-check-disable -->

| Key | Example | Description |
| ------------------------------- | ------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| pyarrow.use-large-types-on-read | True | Use large PyArrow types i.e. [large_string](https://arrow.apache.org/docs/python/generated/pyarrow.large_string.html), [large_binary](https://arrow.apache.org/docs/python/generated/pyarrow.large_binary.html) and [large_list](https://arrow.apache.org/docs/python/generated/pyarrow.large_list.html) field types on table scans. The default value is True. |

<!-- markdown-link-check-enable-->

## Catalogs

PyIceberg currently has native catalog type support for REST, SQL, Hive, Glue and DynamoDB.
Expand Down
1 change: 1 addition & 0 deletions pyiceberg/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
GCS_ENDPOINT = "gcs.endpoint"
GCS_DEFAULT_LOCATION = "gcs.default-bucket-location"
GCS_VERSION_AWARE = "gcs.version-aware"
PYARROW_USE_LARGE_TYPES_ON_READ = "pyarrow.use-large-types-on-read"


@runtime_checkable
Expand Down
81 changes: 71 additions & 10 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
HDFS_KERB_TICKET,
HDFS_PORT,
HDFS_USER,
PYARROW_USE_LARGE_TYPES_ON_READ,
S3_ACCESS_KEY_ID,
S3_CONNECT_TIMEOUT,
S3_ENDPOINT,
Expand Down Expand Up @@ -158,7 +159,7 @@
from pyiceberg.utils.config import Config
from pyiceberg.utils.datetime import millis_to_datetime
from pyiceberg.utils.deprecated import deprecated
from pyiceberg.utils.properties import get_first_property_value, property_as_int
from pyiceberg.utils.properties import get_first_property_value, property_as_bool, property_as_int
from pyiceberg.utils.singleton import Singleton
from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string

Expand Down Expand Up @@ -835,6 +836,10 @@ def _pyarrow_schema_ensure_large_types(schema: pa.Schema) -> pa.Schema:
return visit_pyarrow(schema, _ConvertToLargeTypes())


def _pyarrow_schema_ensure_small_types(schema: pa.Schema) -> pa.Schema:
return visit_pyarrow(schema, _ConvertToSmallTypes())


@singledispatch
def visit_pyarrow(obj: Union[pa.DataType, pa.Schema], visitor: PyArrowSchemaVisitor[T]) -> T:
"""Apply a pyarrow schema visitor to any point within a schema.
Expand Down Expand Up @@ -876,7 +881,6 @@ def _(obj: Union[pa.ListType, pa.LargeListType, pa.FixedSizeListType], visitor:
visitor.before_list_element(obj.value_field)
result = visit_pyarrow(obj.value_type, visitor)
visitor.after_list_element(obj.value_field)

return visitor.list(obj, result)


Expand Down Expand Up @@ -1145,6 +1149,30 @@ def primitive(self, primitive: pa.DataType) -> pa.DataType:
return primitive


class _ConvertToSmallTypes(PyArrowSchemaVisitor[Union[pa.DataType, pa.Schema]]):
def schema(self, schema: pa.Schema, struct_result: pa.StructType) -> pa.Schema:
return pa.schema(struct_result)

def struct(self, struct: pa.StructType, field_results: List[pa.Field]) -> pa.StructType:
return pa.struct(field_results)

def field(self, field: pa.Field, field_result: pa.DataType) -> pa.Field:
return field.with_type(field_result)

def list(self, list_type: pa.ListType, element_result: pa.DataType) -> pa.DataType:
return pa.list_(element_result)

def map(self, map_type: pa.MapType, key_result: pa.DataType, value_result: pa.DataType) -> pa.DataType:
return pa.map_(key_result, value_result)

def primitive(self, primitive: pa.DataType) -> pa.DataType:
if primitive == pa.large_string():
return pa.string()
elif primitive == pa.large_binary():
return pa.binary()
return primitive


class _ConvertToIcebergWithoutIDs(_ConvertToIceberg):
"""
Converts PyArrowSchema to Iceberg Schema with all -1 ids.
Expand All @@ -1169,6 +1197,7 @@ def _task_to_record_batches(
positional_deletes: Optional[List[ChunkedArray]],
case_sensitive: bool,
name_mapping: Optional[NameMapping] = None,
use_large_types: bool = True,
) -> Iterator[pa.RecordBatch]:
_, _, path = PyArrowFileIO.parse_location(task.file.file_path)
arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
Expand Down Expand Up @@ -1197,7 +1226,9 @@ def _task_to_record_batches(
# https://github.com/apache/arrow/issues/41884
# https://github.com/apache/arrow/issues/43183
# Would be good to remove this later on
schema=_pyarrow_schema_ensure_large_types(physical_schema),
schema=_pyarrow_schema_ensure_large_types(physical_schema)
if use_large_types
else (_pyarrow_schema_ensure_small_types(physical_schema)),
# This will push down the query to Arrow.
# But in case there are positional deletes, we have to apply them first
filter=pyarrow_filter if not positional_deletes else None,
Expand All @@ -1219,7 +1250,9 @@ def _task_to_record_batches(
arrow_table = pa.Table.from_batches([batch])
arrow_table = arrow_table.filter(pyarrow_filter)
batch = arrow_table.to_batches()[0]
yield _to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True)
yield _to_requested_schema(
projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True, use_large_types=use_large_types
)
current_index += len(batch)


Expand All @@ -1232,10 +1265,19 @@ def _task_to_table(
positional_deletes: Optional[List[ChunkedArray]],
case_sensitive: bool,
name_mapping: Optional[NameMapping] = None,
use_large_types: bool = True,
) -> Optional[pa.Table]:
batches = list(
_task_to_record_batches(
fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping
fs,
task,
bound_row_filter,
projected_schema,
projected_field_ids,
positional_deletes,
case_sensitive,
name_mapping,
use_large_types,
)
)

Expand Down Expand Up @@ -1303,6 +1345,8 @@ def project_table(
# When FsSpec is not installed
raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") from e

use_large_types = property_as_bool(io.properties, PYARROW_USE_LARGE_TYPES_ON_READ, True)

bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)

projected_field_ids = {
Expand All @@ -1322,6 +1366,7 @@ def project_table(
deletes_per_file.get(task.file.file_path),
case_sensitive,
table_metadata.name_mapping(),
use_large_types,
)
for task in tasks
]
Expand Down Expand Up @@ -1394,6 +1439,8 @@ def project_batches(
# When FsSpec is not installed
raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") from e

use_large_types = property_as_bool(io.properties, PYARROW_USE_LARGE_TYPES_ON_READ, True)

bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)

projected_field_ids = {
Expand All @@ -1414,6 +1461,7 @@ def project_batches(
deletes_per_file.get(task.file.file_path),
case_sensitive,
table_metadata.name_mapping(),
use_large_types,
)
for batch in batches:
if limit is not None:
Expand Down Expand Up @@ -1447,12 +1495,13 @@ def _to_requested_schema(
batch: pa.RecordBatch,
downcast_ns_timestamp_to_us: bool = False,
include_field_ids: bool = False,
use_large_types: bool = True,
) -> pa.RecordBatch:
# We could re-use some of these visitors
struct_array = visit_with_partner(
requested_schema,
batch,
ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us, include_field_ids),
ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us, include_field_ids, use_large_types),
ArrowAccessor(file_schema),
)
return pa.RecordBatch.from_struct_array(struct_array)
Expand All @@ -1462,20 +1511,31 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
_file_schema: Schema
_include_field_ids: bool
_downcast_ns_timestamp_to_us: bool
_use_large_types: bool

def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False, include_field_ids: bool = False) -> None:
def __init__(
self,
file_schema: Schema,
downcast_ns_timestamp_to_us: bool = False,
include_field_ids: bool = False,
use_large_types: bool = True,
) -> None:
self._file_schema = file_schema
self._include_field_ids = include_field_ids
self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
self._use_large_types = use_large_types

def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
file_field = self._file_schema.find_field(field.field_id)

if field.field_type.is_primitive:
if field.field_type != file_field.field_type:
return values.cast(
schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=self._include_field_ids)
target_schema = schema_to_pyarrow(
promote(file_field.field_type, field.field_type), include_field_ids=self._include_field_ids
)
if not self._use_large_types:
target_schema = _pyarrow_schema_ensure_small_types(target_schema)
return values.cast(target_schema)
elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)) != values.type:
if field.field_type == TimestampType():
# Downcasting of nanoseconds to microseconds
Expand Down Expand Up @@ -1547,12 +1607,13 @@ def field(self, field: NestedField, _: Optional[pa.Array], field_array: Optional

def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array: Optional[pa.Array]) -> Optional[pa.Array]:
if isinstance(list_array, (pa.ListArray, pa.LargeListArray, pa.FixedSizeListArray)) and value_array is not None:
list_initializer = pa.large_list if isinstance(list_array, pa.LargeListArray) else pa.list_
if isinstance(value_array, pa.StructArray):
# This can be removed once this has been fixed:
# https://github.com/apache/arrow/issues/38809
list_array = pa.LargeListArray.from_arrays(list_array.offsets, value_array)
value_array = self._cast_if_needed(list_type.element_field, value_array)
arrow_field = pa.large_list(self._construct_field(list_type.element_field, value_array.type))
arrow_field = list_initializer(self._construct_field(list_type.element_field, value_array.type))
return list_array.cast(arrow_field)
else:
return None
Expand Down
87 changes: 86 additions & 1 deletion tests/integration/test_reads.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,14 @@
NotEqualTo,
NotNaN,
)
from pyiceberg.io.pyarrow import pyarrow_to_schema
from pyiceberg.io import PYARROW_USE_LARGE_TYPES_ON_READ
from pyiceberg.io.pyarrow import (
pyarrow_to_schema,
)
from pyiceberg.schema import Schema
from pyiceberg.table import Table
from pyiceberg.types import (
BinaryType,
BooleanType,
IntegerType,
NestedField,
Expand Down Expand Up @@ -665,6 +669,87 @@ def another_task() -> None:
assert table.properties.get("lock") == "xxx"


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_table_scan_default_to_large_types(catalog: Catalog) -> None:
identifier = "default.test_table_scan_default_to_large_types"
arrow_table = pa.Table.from_arrays(
[
pa.array(["a", "b", "c"]),
pa.array(["a", "b", "c"]),
pa.array([b"a", b"b", b"c"]),
pa.array([["a", "b"], ["c", "d"], ["e", "f"]]),
],
names=["string", "string-to-binary", "binary", "list"],
)

try:
catalog.drop_table(identifier)
except NoSuchTableError:
pass

tbl = catalog.create_table(
identifier,
schema=arrow_table.schema,
)

tbl.append(arrow_table)

with tbl.update_schema() as update_schema:
update_schema.update_column("string-to-binary", BinaryType())

result_table = tbl.scan().to_arrow()

expected_schema = pa.schema([
pa.field("string", pa.large_string()),
pa.field("string-to-binary", pa.large_binary()),
pa.field("binary", pa.large_binary()),
pa.field("list", pa.large_list(pa.large_string())),
])
assert result_table.schema.equals(expected_schema)


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_table_scan_override_with_small_types(catalog: Catalog) -> None:
identifier = "default.test_table_scan_override_with_small_types"
arrow_table = pa.Table.from_arrays(
[
pa.array(["a", "b", "c"]),
pa.array(["a", "b", "c"]),
pa.array([b"a", b"b", b"c"]),
pa.array([["a", "b"], ["c", "d"], ["e", "f"]]),
],
names=["string", "string-to-binary", "binary", "list"],
)

try:
catalog.drop_table(identifier)
except NoSuchTableError:
pass

tbl = catalog.create_table(
identifier,
schema=arrow_table.schema,
)

tbl.append(arrow_table)

with tbl.update_schema() as update_schema:
update_schema.update_column("string-to-binary", BinaryType())

tbl.io.properties[PYARROW_USE_LARGE_TYPES_ON_READ] = "False"
result_table = tbl.scan().to_arrow()

expected_schema = pa.schema([
pa.field("string", pa.string()),
pa.field("string-to-binary", pa.binary()),
pa.field("binary", pa.binary()),
pa.field("list", pa.list_(pa.string())),
])
assert result_table.schema.equals(expected_schema)


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_empty_scan_ordered_str(catalog: Catalog) -> None:
Expand Down
6 changes: 6 additions & 0 deletions tests/io/test_pyarrow_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
_HasIds,
_NullNaNUnmentionedTermsCollector,
_pyarrow_schema_ensure_large_types,
_pyarrow_schema_ensure_small_types,
pyarrow_to_schema,
schema_to_pyarrow,
visit_pyarrow,
Expand Down Expand Up @@ -596,6 +597,11 @@ def test_pyarrow_schema_ensure_large_types(pyarrow_schema_nested_without_ids: pa
assert _pyarrow_schema_ensure_large_types(pyarrow_schema_nested_without_ids) == expected_schema


def test_pyarrow_schema_round_trip_ensure_large_types_and_then_small_types(pyarrow_schema_nested_without_ids: pa.Schema) -> None:
schema_with_large_types = _pyarrow_schema_ensure_large_types(pyarrow_schema_nested_without_ids)
assert _pyarrow_schema_ensure_small_types(schema_with_large_types) == pyarrow_schema_nested_without_ids


@pytest.fixture
def bound_reference_str() -> BoundReference[Any]:
return BoundReference(
Expand Down

0 comments on commit 8aeab49

Please sign in to comment.