Sanitize special character column names when writing (#590)
* write with sanitized column names

* push down to when parquet writes

* add test for writing special character column name

* parameterize format_version

* use to_requested_schema

* refactor to_requested_schema

* more refactor

* test nested schema

* special character inside nested field

* comment on why arrow is enabled

* use existing variable

* move spark config to conftest

* pyspark arrow turns pandas df from tuple to dict

* Revert refactor to_requested_schema

* reorder args

* refactor

* pushdown schema

* only transform when necessary
kevinjqliu committed Apr 17, 2024
1 parent 2ee2d19 commit 62b527e
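For context, here is a minimal sketch (not part of this commit) of what column-name sanitization does to an Iceberg schema. It assumes pyiceberg's Schema, NestedField, and StringType types plus the sanitize_column_names helper that this change imports into pyiceberg/io/pyarrow.py; the sanitized spelling mentioned in the comment is an assumption based on Iceberg's Avro-compatible naming rules, not quoted from the diff.

```python
# Sketch: sanitize a schema whose column name contains a character ("/") that
# is not valid in Avro-style field names.
from pyiceberg.schema import Schema, sanitize_column_names
from pyiceberg.types import NestedField, StringType

schema = Schema(NestedField(field_id=1, name="letter/abc", field_type=StringType(), required=False))
sanitized = sanitize_column_names(schema)

# The field keeps its ID but gets a sanitized name (assumed to look roughly
# like "letter_x2Fabc"); the user-facing table schema is left untouched.
assert sanitized != schema
print(sanitized.find_field(1).name)
```

The write path below applies this transformation only when the sanitized schema actually differs from the table schema.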
Showing 4 changed files with 35 additions and 13 deletions.
24 changes: 15 additions & 9 deletions pyiceberg/io/pyarrow.py
@@ -122,6 +122,7 @@
     pre_order_visit,
     promote,
     prune_columns,
+    sanitize_column_names,
     visit,
     visit_with_partner,
 )
@@ -1016,7 +1017,6 @@ def _task_to_table(
 
         if len(arrow_table) < 1:
             return None
-
     return to_requested_schema(projected_schema, file_project_schema, arrow_table)
 
 
@@ -1769,27 +1769,33 @@ def data_file_statistics_from_parquet_metadata(
 
 
 def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
-    schema = table_metadata.schema()
-    arrow_file_schema = schema.as_arrow()
     parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties)
-
     row_group_size = PropertyUtil.property_as_int(
         properties=table_metadata.properties,
         property_name=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES,
         default=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT,
     )
 
     def write_parquet(task: WriteTask) -> DataFile:
+        table_schema = task.schema
+        arrow_table = pa.Table.from_batches(task.record_batches)
+        # if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly
+        # otherwise use the original schema
+        if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
+            file_schema = sanitized_schema
+            arrow_table = to_requested_schema(requested_schema=file_schema, file_schema=table_schema, table=arrow_table)
+        else:
+            file_schema = table_schema
+
         file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
         fo = io.new_output(file_path)
         with fo.create(overwrite=True) as fos:
-            with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer:
-                writer.write(pa.Table.from_batches(task.record_batches), row_group_size=row_group_size)
-
+            with pq.ParquetWriter(fos, schema=file_schema.as_arrow(), **parquet_writer_kwargs) as writer:
+                writer.write(arrow_table, row_group_size=row_group_size)
         statistics = data_file_statistics_from_parquet_metadata(
             parquet_metadata=writer.writer.metadata,
-            stats_columns=compute_statistics_plan(schema, table_metadata.properties),
-            parquet_column_mapping=parquet_path_to_id_mapping(schema),
+            stats_columns=compute_statistics_plan(file_schema, table_metadata.properties),
+            parquet_column_mapping=parquet_path_to_id_mapping(file_schema),
         )
         data_file = DataFile(
             content=DataFileContent.DATA,
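To make the new write_parquet logic concrete, here is a hedged, self-contained sketch of the rename step it performs before handing data to the Parquet writer. It calls to_requested_schema with the same keyword arguments as the diff, but the schema and Arrow table here are made up for illustration.

```python
# Sketch: rename Arrow columns to the sanitized schema before writing, and
# skip the transformation when no column name needs sanitizing.
import pyarrow as pa

from pyiceberg.io.pyarrow import to_requested_schema
from pyiceberg.schema import Schema, sanitize_column_names
from pyiceberg.types import NestedField, StringType

table_schema = Schema(NestedField(field_id=1, name="letter/abc", field_type=StringType(), required=False))
arrow_table = pa.table({"letter/abc": ["a", None, "z"]})

if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
    file_schema = sanitized_schema
    arrow_table = to_requested_schema(requested_schema=file_schema, file_schema=table_schema, table=arrow_table)
else:
    file_schema = table_schema

print(arrow_table.column_names)  # sanitized physical name; the Iceberg schema keeps "letter/abc"
print(file_schema.as_arrow())    # the schema handed to pq.ParquetWriter in the diff
```

Passing file_schema to compute_statistics_plan and parquet_path_to_id_mapping keeps the statistics lookups aligned with the sanitized column paths that actually appear in the written Parquet file.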
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -2060,6 +2060,7 @@ def spark() -> "SparkSession":
         .config("spark.sql.catalog.hive.warehouse", "s3://warehouse/hive/")
         .config("spark.sql.catalog.hive.s3.endpoint", "http://localhost:9000")
         .config("spark.sql.catalog.hive.s3.path-style-access", "true")
+        .config("spark.sql.execution.arrow.pyspark.enabled", "true")
         .getOrCreate()
     )
 
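The commit message notes that "pyspark arrow turns pandas df from tuple to dict", which is why this setting is enabled and why the asDict() calls are dropped in the next file. A hedged illustration, where spark is the fixture's SparkSession and the table name and its data_file struct column are placeholders echoing the inspect-table tests:

```python
# Illustration only: with spark.sql.execution.arrow.pyspark.enabled=true,
# struct columns come back from toPandas() as plain Python dicts rather than
# Row objects, so no .asDict(recursive=True) conversion is needed.
pandas_df = spark.table("db.placeholder_table").toPandas()  # hypothetical table
first_entry = pandas_df["data_file"][0]                     # a struct column
assert isinstance(first_entry, dict)  # would be a Row without the Arrow conversion
```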
3 changes: 0 additions & 3 deletions tests/integration/test_inspect_table.py
@@ -171,7 +171,6 @@ def check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> None:
     for column in df.column_names:
         for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
             if column == 'data_file':
-                right = right.asDict(recursive=True)
                 for df_column in left.keys():
                     if df_column == 'partition':
                         # Spark leaves out the partition if the table is unpartitioned
@@ -185,8 +184,6 @@ def check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> None:
 
                     assert df_lhs == df_rhs, f"Difference in data_file column {df_column}: {df_lhs} != {df_rhs}"
             elif column == 'readable_metrics':
-                right = right.asDict(recursive=True)
-
                 assert list(left.keys()) == [
                     'bool',
                     'string',
20 changes: 19 additions & 1 deletion tests/integration/test_writes/test_writes.py
@@ -280,9 +280,27 @@ def test_python_writes_special_character_column_with_spark_reads(
     column_name_with_special_character = "letter/abc"
     TEST_DATA_WITH_SPECIAL_CHARACTER_COLUMN = {
         column_name_with_special_character: ['a', None, 'z'],
+        'id': [1, 2, 3],
+        'name': ['AB', 'CD', 'EF'],
+        'address': [
+            {'street': '123', 'city': 'SFO', 'zip': 12345, column_name_with_special_character: 'a'},
+            {'street': '456', 'city': 'SW', 'zip': 67890, column_name_with_special_character: 'b'},
+            {'street': '789', 'city': 'Random', 'zip': 10112, column_name_with_special_character: 'c'},
+        ],
     }
     pa_schema = pa.schema([
-        (column_name_with_special_character, pa.string()),
+        pa.field(column_name_with_special_character, pa.string()),
+        pa.field('id', pa.int32()),
+        pa.field('name', pa.string()),
+        pa.field(
+            'address',
+            pa.struct([
+                pa.field('street', pa.string()),
+                pa.field('city', pa.string()),
+                pa.field('zip', pa.int32()),
+                pa.field(column_name_with_special_character, pa.string()),
+            ]),
+        ),
     ])
     arrow_table_with_special_character_column = pa.Table.from_pydict(TEST_DATA_WITH_SPECIAL_CHARACTER_COLUMN, schema=pa_schema)
     tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=pa_schema)
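For completeness, a hedged sketch of the round trip this test sets up, reusing the objects defined above. The real test reads the data back through Spark (as its name says) and its assertions may differ; this just shows the intended outcome: user-facing column names keep the special character, including inside the nested address struct, while only the physical Parquet column names are sanitized.

```python
# Sketch: append the Arrow table with the special-character column and read it
# back through PyIceberg; the projected schema keeps the original names.
tbl.append(arrow_table_with_special_character_column)
roundtripped = tbl.scan().to_arrow()
assert roundtripped.column_names == arrow_table_with_special_character_column.column_names
```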
