prevent adding duplicate files (#1036)
* prevent add_files from adding a file that's already referenced by the iceberg table

* fix method that searches for files that are already referenced + docs

* move function to locate duplicate files into add_files

* add check_duplicate_files flag to add_files api to match the behaviour of the java api

* add check_duplicate_files flag to table level api and add tests

* fix tests to check the newly added check_duplicate_files flag and fix checks

* fix linting
amitgilad3 committed Aug 26, 2024
1 parent 0b487f2 commit 53a0b73
Showing 2 changed files with 120 additions and 4 deletions.
29 changes: 25 additions & 4 deletions pyiceberg/table/__init__.py
@@ -621,7 +621,9 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
if not delete_snapshot.files_affected and not delete_snapshot.rewrites_needed:
warnings.warn("Delete operation did not match any records")

def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
def add_files(
self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True
) -> None:
"""
Shorthand API for adding files as data files to the table transaction.
@@ -630,7 +632,21 @@ def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] =
Raises:
FileNotFoundError: If the file does not exist.
ValueError: If file_paths contains duplicate files.
ValueError: If file_paths contains files that are already referenced by the table.
"""
if len(file_paths) != len(set(file_paths)):
raise ValueError("File paths must be unique")

if check_duplicate_files:
import pyarrow.compute as pc

expr = pc.field("file_path").isin(file_paths)
referenced_files = [file["file_path"] for file in self._table.inspect.files().filter(expr).to_pylist()]

if referenced_files:
raise ValueError(f"Cannot add files that are already referenced by table, files: {', '.join(referenced_files)}")

if self.table_metadata.name_mapping() is None:
self.set_properties(**{
TableProperties.DEFAULT_NAME_MAPPING: self.table_metadata.schema().name_mapping.model_dump_json()
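
The check above leans on the files metadata table, which inspect.files() exposes as a PyArrow table. A minimal sketch of the same filter in isolation (not part of the commit; the one-column table below is a stand-in for the much wider real metadata table):

import pyarrow as pa
import pyarrow.compute as pc

# Stand-in for tbl.inspect.files(); only file_path matters for the check.
files_metadata = pa.table({
    "file_path": [
        "s3://warehouse/default/test-0.parquet",
        "s3://warehouse/default/test-1.parquet",
    ]
})

incoming = [
    "s3://warehouse/default/test-1.parquet",
    "s3://warehouse/default/test-2.parquet",
]

# Same expression add_files builds: keep only rows whose file_path
# collides with the incoming batch.
expr = pc.field("file_path").isin(incoming)
referenced = [row["file_path"] for row in files_metadata.filter(expr).to_pylist()]
assert referenced == ["s3://warehouse/default/test-1.parquet"]
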
@@ -1632,7 +1648,9 @@ def delete(
with self.transaction() as tx:
tx.delete(delete_filter=delete_filter, snapshot_properties=snapshot_properties)

def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
def add_files(
self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True
) -> None:
"""
Shorthand API for adding files as data files to the table.
@@ -1643,7 +1661,9 @@ def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] =
FileNotFoundError: If the file does not exist.
"""
with self.transaction() as tx:
tx.add_files(file_paths=file_paths, snapshot_properties=snapshot_properties)
tx.add_files(
file_paths=file_paths, snapshot_properties=snapshot_properties, check_duplicate_files=check_duplicate_files
)

def update_spec(self, case_sensitive: bool = True) -> UpdateSpec:
return UpdateSpec(Transaction(self, autocommit=True), case_sensitive=case_sensitive)
@@ -2260,7 +2280,8 @@ def union_by_name(self, new_schema: Union[Schema, "pa.Schema"]) -> UpdateSchema:
visit_with_partner(
Catalog._convert_schema_if_needed(new_schema),
-1,
UnionByNameVisitor(update_schema=self, existing_schema=self._schema, case_sensitive=self._case_sensitive), # type: ignore
UnionByNameVisitor(update_schema=self, existing_schema=self._schema, case_sensitive=self._case_sensitive),
# type: ignore
PartnerIdByNameAccessor(partner_schema=self._schema, case_sensitive=self._case_sensitive),
)
return self
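
At the table level the new flag simply forwards to the transaction, so callers can opt out per call. A usage sketch, assuming a configured catalog named "default"; the table identifier and file path are hypothetical:

from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")  # assumes a configured catalog
tbl = catalog.load_table("default.events")  # hypothetical table

paths = ["s3://warehouse/default/events/data-0.parquet"]  # hypothetical path
tbl.add_files(file_paths=paths)

# Re-adding the same path now raises by default ...
try:
    tbl.add_files(file_paths=paths)
except ValueError as err:
    print(err)  # Cannot add files that are already referenced by table, files: ...

# ... unless the check is explicitly disabled, mirroring the Java API's opt-out.
tbl.add_files(file_paths=paths, check_duplicate_files=False)
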
95 changes: 95 additions & 0 deletions tests/integration/test_add_files.py
@@ -732,3 +732,98 @@ def test_add_files_subset_of_schema(spark: SparkSession, session_catalog: Catalo
for column in written_arrow_table.column_names:
for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
assert left == right


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_add_files_with_duplicate_files_in_file_paths(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
identifier = f"default.test_table_duplicate_add_files_v{format_version}"
tbl = _create_table(session_catalog, identifier, format_version)
file_path = "s3://warehouse/default/unpartitioned/v{format_version}/test-1.parquet"
file_paths = [file_path, file_path]

# add the parquet files as data files
with pytest.raises(ValueError) as exc_info:
tbl.add_files(file_paths=file_paths)
assert "File paths must be unique" in str(exc_info.value)


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_add_files_that_referenced_by_current_snapshot(
spark: SparkSession, session_catalog: Catalog, format_version: int
) -> None:
identifier = f"default.test_table_add_referenced_file_v{format_version}"
tbl = _create_table(session_catalog, identifier, format_version)

file_paths = [f"s3://warehouse/default/unpartitioned/v{format_version}/test-{i}.parquet" for i in range(5)]

# write parquet files
for file_path in file_paths:
fo = tbl.io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer:
writer.write_table(ARROW_TABLE)

# add the parquet files as data files
tbl.add_files(file_paths=file_paths)
existing_files_in_table = tbl.inspect.files().to_pylist().pop()["file_path"]

with pytest.raises(ValueError) as exc_info:
tbl.add_files(file_paths=[existing_files_in_table])
assert f"Cannot add files that are already referenced by table, files: {existing_files_in_table}" in str(exc_info.value)


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_add_files_that_referenced_by_current_snapshot_with_check_duplicate_files_false(
spark: SparkSession, session_catalog: Catalog, format_version: int
) -> None:
identifier = f"default.test_table_add_referenced_file_v{format_version}"
tbl = _create_table(session_catalog, identifier, format_version)

file_paths = [f"s3://warehouse/default/unpartitioned/v{format_version}/test-{i}.parquet" for i in range(5)]
# write parquet files
for file_path in file_paths:
fo = tbl.io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer:
writer.write_table(ARROW_TABLE)

# add the parquet files as data files
tbl.add_files(file_paths=file_paths)
existing_files_in_table = tbl.inspect.files().to_pylist().pop()["file_path"]
tbl.add_files(file_paths=[existing_files_in_table], check_duplicate_files=False)
rows = spark.sql(
f"""
SELECT added_data_files_count, existing_data_files_count, deleted_data_files_count
FROM {identifier}.all_manifests
"""
).collect()
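# all_manifests lists every manifest of every snapshot: the first snapshot's
# manifest (5 added files), then the second snapshot's new manifest (1 added
# file) plus the first manifest carried over into it (still reporting 5 added).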
assert [row.added_data_files_count for row in rows] == [5, 1, 5]
assert [row.existing_data_files_count for row in rows] == [0, 0, 0]
assert [row.deleted_data_files_count for row in rows] == [0, 0, 0]


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_add_files_that_referenced_by_current_snapshot_with_check_duplicate_files_true(
spark: SparkSession, session_catalog: Catalog, format_version: int
) -> None:
identifier = f"default.test_table_add_referenced_file_v{format_version}"
tbl = _create_table(session_catalog, identifier, format_version)

file_paths = [f"s3://warehouse/default/unpartitioned/v{format_version}/test-{i}.parquet" for i in range(5)]
# write parquet files
for file_path in file_paths:
fo = tbl.io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer:
writer.write_table(ARROW_TABLE)

# add the parquet files as data files
tbl.add_files(file_paths=file_paths)
existing_files_in_table = tbl.inspect.files().to_pylist().pop()["file_path"]
with pytest.raises(ValueError) as exc_info:
tbl.add_files(file_paths=[existing_files_in_table], check_duplicate_files=True)
assert f"Cannot add files that are already referenced by table, files: {existing_files_in_table}" in str(exc_info.value)
