diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index cd6736fbb..1e583f513 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -1050,11 +1050,6 @@ def _task_to_record_batches(
         fragment_scanner = ds.Scanner.from_fragment(
             fragment=fragment,
-            # With PyArrow 16.0.0 there is an issue with casting record-batches:
-            # https://github.com/apache/arrow/issues/41884
-            # https://github.com/apache/arrow/issues/43183
-            # Would be good to remove this later on
-            schema=_pyarrow_schema_ensure_large_types(physical_schema),
             # This will push down the query to Arrow.
             # But in case there are positional deletes, we have to apply them first
             filter=pyarrow_filter if not positional_deletes else None,
@@ -1070,12 +1065,7 @@
                 batch = batch.take(indices)
             # Apply the user filter
             if pyarrow_filter is not None:
-                # we need to switch back and forth between RecordBatch and Table
-                # as Expression filter isn't yet supported in RecordBatch
-                # https://github.com/apache/arrow/issues/39220
-                arrow_table = pa.Table.from_batches([batch])
-                arrow_table = arrow_table.filter(pyarrow_filter)
-                batch = arrow_table.to_batches()[0]
+                batch = batch.filter(pyarrow_filter)
             yield _to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True)
             current_index += len(batch)
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 0b211e673..2843dc1d1 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -2020,7 +2020,7 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
                 case_sensitive=self.case_sensitive,
                 limit=self.limit,
             ),
-        )
+        ).cast(target_schema=target_schema)
 
     def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
         return self.to_arrow().to_pandas(**kwargs)
diff --git a/pyproject.toml b/pyproject.toml
index a86617192..87449c281 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -58,7 +58,7 @@ fsspec = ">=2023.1.0,<2025.1.0"
 pyparsing = ">=3.1.0,<4.0.0"
 zstandard = ">=0.13.0,<1.0.0"
 tenacity = ">=8.2.3,<9.0.0"
-pyarrow = { version = ">=9.0.0,<18.0.0", optional = true }
+pyarrow = { version = ">=17.0.0,<18.0.0", optional = true }
 pandas = { version = ">=1.0.0,<3.0.0", optional = true }
 duckdb = { version = ">=0.5.0,<2.0.0", optional = true }
 ray = { version = ">=2.0.0,<2.10.0", optional = true }
diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py
index 3703a9e0b..30167c6b6 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -549,7 +549,7 @@ def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_ca
     tbl.add_files([file_path])
 
     table_schema = tbl.scan().to_arrow().schema
-    assert table_schema == arrow_schema_large
+    assert table_schema == arrow_schema
 
     file_path_large = f"s3://warehouse/default/unpartitioned_with_large_types/v{format_version}/test-1.parquet"
     _write_parquet(
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index 09fe654d2..69bfb3068 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -357,7 +357,8 @@ def test_python_writes_dictionary_encoded_column_with_spark_reads(
     tbl.overwrite(arrow_table)
     spark_df = spark.sql(f"SELECT * FROM {identifier}").toPandas()
     pyiceberg_df = tbl.scan().to_pandas()
-    assert spark_df.equals(pyiceberg_df)
+    assert spark_df['id'].equals(pyiceberg_df['id'])
+    assert all(spark_df['name'].values == pyiceberg_df['name'].values)
 
 
 @pytest.mark.integration
@@ -401,12 +402,12 @@ def test_python_writes_with_small_and_large_types_spark_reads(
     assert arrow_table_on_read.schema == pa.schema([
         pa.field("foo", pa.large_string()),
         pa.field("id", pa.int32()),
-        pa.field("name", pa.large_string()),
+        pa.field("name", pa.string()),
         pa.field(
             "address",
             pa.struct([
-                pa.field("street", pa.large_string()),
-                pa.field("city", pa.large_string()),
+                pa.field("street", pa.string()),
+                pa.field("city", pa.string()),
                 pa.field("zip", pa.int32()),
                 pa.field("bar", pa.large_string()),
             ]),
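
For context, a minimal sketch (not part of the patch) of the two PyArrow behaviours the change relies on, assuming pyarrow >= 17.0.0 as required by the pyproject.toml bump above; the column names and data are invented for illustration:

import pyarrow as pa
import pyarrow.compute as pc

# RecordBatch.filter accepts a compute Expression directly (apache/arrow#39220),
# so _task_to_record_batches no longer needs the Table.from_batches round-trip.
batch = pa.record_batch({"id": [1, 2, 3], "name": ["a", "b", "c"]})
filtered = batch.filter(pc.field("id") > 1)
print(filtered.to_pydict())  # {'id': [2, 3], 'name': ['b', 'c']}

# RecordBatchReader.cast re-projects every batch to a target schema, which is how
# to_arrow_batch_reader can keep returning large types even though the scanner no
# longer forces a large-type schema at read time.
reader = pa.RecordBatchReader.from_batches(batch.schema, [batch])
target_schema = pa.schema([pa.field("id", pa.int64()), pa.field("name", pa.large_string())])
print(reader.cast(target_schema).read_all().schema)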