Skip to content

Commit

Permalink
Disable Spark Catalog caching for integration tests (#501)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinjqliu committed Mar 7, 2024
1 parent 29fd42c commit e56326d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1965,6 +1965,7 @@ def spark() -> SparkSession:
.config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
.config("spark.sql.catalog.integration", "org.apache.iceberg.spark.SparkCatalog")
.config("spark.sql.catalog.integration.catalog-impl", "org.apache.iceberg.rest.RESTCatalog")
.config("spark.sql.catalog.integration.cache-enabled", "false")
.config("spark.sql.catalog.integration.uri", "http://localhost:8181")
.config("spark.sql.catalog.integration.io-impl", "org.apache.iceberg.aws.s3.S3FileIO")
.config("spark.sql.catalog.integration.warehouse", "s3://warehouse/wh/")
Expand Down
22 changes: 22 additions & 0 deletions tests/integration/test_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,28 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w
assert [row.deleted_data_files_count for row in rows] == [0, 0, 1, 0, 0]


@pytest.mark.integration
def test_python_writes_with_spark_snapshot_reads(
spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table
) -> None:
identifier = "default.python_writes_with_spark_snapshot_reads"
tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, [])

def get_current_snapshot_id(identifier: str) -> int:
return (
spark.sql(f"SELECT snapshot_id FROM {identifier}.snapshots order by committed_at desc limit 1")
.collect()[0]
.snapshot_id
)

tbl.overwrite(arrow_table_with_null)
assert tbl.current_snapshot().snapshot_id == get_current_snapshot_id(identifier) # type: ignore
tbl.overwrite(arrow_table_with_null)
assert tbl.current_snapshot().snapshot_id == get_current_snapshot_id(identifier) # type: ignore
tbl.append(arrow_table_with_null)
assert tbl.current_snapshot().snapshot_id == get_current_snapshot_id(identifier) # type: ignore


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
@pytest.mark.parametrize(
Expand Down

0 comments on commit e56326d

Please sign in to comment.