From b9fd4386a2bcfb07e7eb59c7b34ee736844786a3 Mon Sep 17 00:00:00 2001 From: Adam Wenocur Date: Wed, 28 Aug 2024 08:57:18 -0400 Subject: [PATCH] add Python API for sample deletion (#759) --- .../src/tiledbvcf/binding/libtiledbvcf.cc | 4 + apis/python/src/tiledbvcf/binding/writer.cc | 12 +++ apis/python/src/tiledbvcf/binding/writer.h | 2 + apis/python/src/tiledbvcf/dataset.py | 8 ++ apis/python/tests/test_tiledbvcf.py | 78 ++++++++++++++----- libtiledbvcf/src/c_api/tiledbvcf.cc | 15 ++++ libtiledbvcf/src/c_api/tiledbvcf.h | 10 +++ libtiledbvcf/src/dataset/tiledbvcfdataset.cc | 4 +- libtiledbvcf/src/write/writer.cc | 5 ++ libtiledbvcf/src/write/writer.h | 5 ++ 10 files changed, 121 insertions(+), 22 deletions(-) diff --git a/apis/python/src/tiledbvcf/binding/libtiledbvcf.cc b/apis/python/src/tiledbvcf/binding/libtiledbvcf.cc index 7a4a4aae8..59fbda775 100644 --- a/apis/python/src/tiledbvcf/binding/libtiledbvcf.cc +++ b/apis/python/src/tiledbvcf/binding/libtiledbvcf.cc @@ -134,6 +134,10 @@ PYBIND11_MODULE(libtiledbvcf, m) { "ingest_samples", &Writer::ingest_samples, py::call_guard()) + .def( + "delete_samples", + &Writer::delete_samples, + py::call_guard()) .def("get_schema_version", &Writer::get_schema_version) .def("set_tiledb_config", &Writer::set_tiledb_config) .def("set_sample_batch_size", &Writer::set_sample_batch_size) diff --git a/apis/python/src/tiledbvcf/binding/writer.cc b/apis/python/src/tiledbvcf/binding/writer.cc index 01dd077ae..e980c07fa 100644 --- a/apis/python/src/tiledbvcf/binding/writer.cc +++ b/apis/python/src/tiledbvcf/binding/writer.cc @@ -232,6 +232,18 @@ void Writer::ingest_samples() { check_error(writer, tiledb_vcf_writer_store(writer)); } +void Writer::delete_samples(std::vector samples_to_delete) { + std::vector samples; + for (std::string& sample : samples_to_delete) { + samples.emplace_back(sample.c_str()); + } + + auto writer = ptr.get(); + check_error( + writer, + tiledb_vcf_writer_delete_samples(writer, samples.data(), samples.size())); +} + void Writer::deleter(tiledb_vcf_writer_t* w) { tiledb_vcf_writer_free(&w); } diff --git a/apis/python/src/tiledbvcf/binding/writer.h b/apis/python/src/tiledbvcf/binding/writer.h index a604c900e..c78f1af5e 100644 --- a/apis/python/src/tiledbvcf/binding/writer.h +++ b/apis/python/src/tiledbvcf/binding/writer.h @@ -162,6 +162,8 @@ class Writer { void ingest_samples(); + void delete_samples(std::vector samples); + /** Returns schema version number of the TileDB VCF dataset */ int32_t get_schema_version(); diff --git a/apis/python/src/tiledbvcf/dataset.py b/apis/python/src/tiledbvcf/dataset.py index 4243b9732..3312fafe5 100644 --- a/apis/python/src/tiledbvcf/dataset.py +++ b/apis/python/src/tiledbvcf/dataset.py @@ -851,6 +851,14 @@ def ingest_samples( self.writer.register_samples() self.writer.ingest_samples() + def delete_samples( + self, + sample_uris: List[str] = None, + ): + if self.mode != "w": + raise Exception("Dataset not open in write mode") + self.writer.delete_samples(sample_uris) + def tiledb_stats(self) -> str: """ Get TileDB stats as a string. diff --git a/apis/python/tests/test_tiledbvcf.py b/apis/python/tests/test_tiledbvcf.py index 67a5ef6bd..bd3689c65 100755 --- a/apis/python/tests/test_tiledbvcf.py +++ b/apis/python/tests/test_tiledbvcf.py @@ -1197,15 +1197,8 @@ def test_ingest_mode_merged(tmp_path): assert ds.count(regions=["chrX:9032893-9032893"]) == 0 -# Ok to skip is missing bcftools in Windows CI job -@pytest.mark.skipif( - os.environ.get("CI") == "true" - and platform.system() == "Windows" - and shutil.which("bcftools") is None, - reason="no bcftools", -) -def test_ingest_with_stats_v3(tmp_path): - # tiledbvcf.config_logging("debug") +@pytest.fixture +def test_stats_bgzipped_inputs(tmp_path): tmp_path_contents = os.listdir(tmp_path) if "stats" in tmp_path_contents: shutil.rmtree(os.path.join(tmp_path, "stats")) @@ -1221,23 +1214,46 @@ def test_ingest_with_stats_v3(tmp_path): check=True, ) bgzipped_inputs = glob.glob(os.path.join(tmp_path, "stats", "*.gz")) - # print(f"bgzipped inputs: {bgzipped_inputs}") for vcf_file in bgzipped_inputs: assert subprocess.run("bcftools index " + vcf_file, shell=True).returncode == 0 if "outputs" in tmp_path_contents: shutil.rmtree(os.path.join(tmp_path, "outputs")) if "stats_test" in tmp_path_contents: shutil.rmtree(os.path.join(tmp_path, "stats_test")) - # tiledbvcf.config_logging("trace") + return bgzipped_inputs + + +@pytest.fixture +def test_stats_sample_names(test_stats_bgzipped_inputs): + assert len(test_stats_bgzipped_inputs) == 8 + return [os.path.basename(file).split(".")[0] for file in test_stats_bgzipped_inputs] + + +@pytest.fixture +def test_stats_v3_ingestion(tmp_path, test_stats_bgzipped_inputs): + assert len(test_stats_bgzipped_inputs) == 8 + # print(f"bgzipped inputs: {test_stats_bgzipped_inputs}") ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="w") ds.create_dataset( enable_variant_stats=True, enable_allele_count=True, variant_stats_version=3 ) - ds.ingest_samples(bgzipped_inputs) + ds.ingest_samples(test_stats_bgzipped_inputs) ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="r") - sample_names = [os.path.basename(file).split(".")[0] for file in bgzipped_inputs] - data_frame = ds.read( - samples=sample_names, + return ds + + +# Ok to skip is missing bcftools in Windows CI job +@pytest.mark.skipif( + os.environ.get("CI") == "true" + and platform.system() == "Windows" + and shutil.which("bcftools") is None, + reason="no bcftools", +) +def test_ingest_with_stats_v3( + tmp_path, test_stats_v3_ingestion, test_stats_sample_names +): + data_frame = test_stats_v3_ingestion.read( + samples=test_stats_sample_names, attrs=["contig", "pos_start", "id", "qual", "info_TILEDB_IAF", "sample_name"], set_af_filter="<0.2", ) @@ -1249,8 +1265,8 @@ def test_ingest_with_stats_v3(tmp_path): data_frame[data_frame["sample_name"] == "second"]["info_TILEDB_IAF"].iloc[0][0] == 0.9375 ) - data_frame = ds.read( - samples=sample_names, + data_frame = test_stats_v3_ingestion.read( + samples=test_stats_sample_names, attrs=["contig", "pos_start", "id", "qual", "info_TILEDB_IAF", "sample_name"], scan_all_samples=True, ) @@ -1260,8 +1276,7 @@ def test_ingest_with_stats_v3(tmp_path): ]["info_TILEDB_IAF"].iloc[0][0] == 0.9375 ) - ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="r") - df = ds.read_variant_stats("chr1:1-10000") + df = test_stats_v3_ingestion.read_variant_stats("chr1:1-10000") assert df.shape == (13, 5) df = tiledbvcf.allele_frequency.read_allele_frequency( os.path.join(tmp_path, "stats_test"), "chr1:1-10000" @@ -1269,16 +1284,37 @@ def test_ingest_with_stats_v3(tmp_path): assert df.pos.is_monotonic_increasing df["an_check"] = (df.ac / df.af).round(0).astype("int32") assert df.an_check.equals(df.an) - df = ds.read_variant_stats("chr1:1-10000") + df = test_stats_v3_ingestion.read_variant_stats("chr1:1-10000") assert df.shape == (13, 5) df = df.to_pandas() - df = ds.read_allele_count("chr1:1-10000") + df = test_stats_v3_ingestion.read_allele_count("chr1:1-10000") assert df.shape == (7, 6) df = df.to_pandas() assert sum(df["pos"] == (0, 1, 1, 2, 2, 2, 3)) == 7 assert sum(df["count"] == (8, 5, 3, 4, 2, 2, 1)) == 7 +@pytest.mark.skipif( + os.environ.get("CI") == "true" + and platform.system() == "Windows" + and shutil.which("bcftools") is None, + reason="no bcftools", +) +def test_delete_samples(tmp_path, test_stats_v3_ingestion, test_stats_sample_names): + # assert test_stats_v3_ingestion.samples() == test_stats_sample_names + assert "second" in test_stats_sample_names + assert "fifth" in test_stats_sample_names + assert "third" in test_stats_sample_names + ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="w") + # tiledbvcf.config_logging("trace") + ds.delete_samples(["second", "fifth"]) + ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="r") + sample_names = ds.samples() + assert "second" not in sample_names + assert "fifth" not in sample_names + assert "third" in sample_names + + # Ok to skip is missing bcftools in Windows CI job @pytest.mark.skipif( os.environ.get("CI") == "true" diff --git a/libtiledbvcf/src/c_api/tiledbvcf.cc b/libtiledbvcf/src/c_api/tiledbvcf.cc index 160933bda..83f1881b5 100644 --- a/libtiledbvcf/src/c_api/tiledbvcf.cc +++ b/libtiledbvcf/src/c_api/tiledbvcf.cc @@ -1823,6 +1823,21 @@ int32_t tiledb_vcf_writer_set_variant_stats_version( return TILEDB_VCF_OK; } +int32_t tiledb_vcf_writer_delete_samples( + tiledb_vcf_writer_t* writer, const char** samples, size_t nsamples) { + std::vector encoded_samples; + for (size_t i = 0; i < nsamples; i++) + encoded_samples.emplace_back(samples[i]); + if (sanity_check(writer) == TILEDB_VCF_ERR) + return TILEDB_VCF_ERR; + + if (SAVE_ERROR_CATCH( + writer, writer->writer_->delete_samples(encoded_samples))) + return TILEDB_VCF_ERR; + + return TILEDB_VCF_OK; +} + /* ********************************* */ /* ERROR */ /* ********************************* */ diff --git a/libtiledbvcf/src/c_api/tiledbvcf.h b/libtiledbvcf/src/c_api/tiledbvcf.h index 34742e48d..c61ebce06 100644 --- a/libtiledbvcf/src/c_api/tiledbvcf.h +++ b/libtiledbvcf/src/c_api/tiledbvcf.h @@ -1706,6 +1706,16 @@ tiledb_vcf_writer_set_compression_level(tiledb_vcf_writer_t* writer, int level); TILEDBVCF_EXPORT int32_t tiledb_vcf_writer_set_variant_stats_version( tiledb_vcf_writer_t* writer, uint8_t version); +/** + * Deletes samples from dataset + * @param writer VCF writer object + * @param samples samples to delete + * @param nsamples number of samples to delete + */ +TILEDBVCF_EXPORT int32_t tiledb_vcf_writer_delete_samples( + + tiledb_vcf_writer_t* writer, const char** samples, size_t nsamples); + /* ********************************* */ /* ERROR */ /* ********************************* */ diff --git a/libtiledbvcf/src/dataset/tiledbvcfdataset.cc b/libtiledbvcf/src/dataset/tiledbvcfdataset.cc index 8c102eebc..012db126c 100644 --- a/libtiledbvcf/src/dataset/tiledbvcfdataset.cc +++ b/libtiledbvcf/src/dataset/tiledbvcfdataset.cc @@ -938,7 +938,9 @@ void TileDBVCFDataset::delete_samples( const std::vector& sample_names, const std::vector& tiledb_config) { // Open dataset in read mode, required before calling `sample_exists`. - open(uri); + if (!open_) { + open(uri, tiledb_config); + } // Define a function that deletes a sample from an array auto delete_sample = [&](Array& array, const std::string& sample) { diff --git a/libtiledbvcf/src/write/writer.cc b/libtiledbvcf/src/write/writer.cc index 169995800..2f6c0d127 100644 --- a/libtiledbvcf/src/write/writer.cc +++ b/libtiledbvcf/src/write/writer.cc @@ -1484,5 +1484,10 @@ void Writer::set_variant_stats_array_version(uint8_t version) { creation_params_.variant_stats_array_version = version; } +void Writer::delete_samples(std::vector samples) { + dataset_->delete_samples( + ingestion_params_.uri, samples, ingestion_params_.tiledb_config); +} + } // namespace vcf } // namespace tiledb diff --git a/libtiledbvcf/src/write/writer.h b/libtiledbvcf/src/write/writer.h index 62052cec0..aea576ce3 100644 --- a/libtiledbvcf/src/write/writer.h +++ b/libtiledbvcf/src/write/writer.h @@ -382,6 +382,11 @@ class Writer { /** Set variant stats array version */ void set_variant_stats_array_version(uint8_t version); + /** + * @brief Delete samples from the writer's dataset. + */ + void delete_samples(std::vector samples); + private: /* ********************************* */ /* PRIVATE ATTRIBUTES */