diff --git a/.github/workflows/cargo-publish.yml b/.github/workflows/cargo-publish.yml index 2eda1d9242..4275f554c8 100644 --- a/.github/workflows/cargo-publish.yml +++ b/.github/workflows/cargo-publish.yml @@ -5,6 +5,12 @@ on: # Use released instead of published, since we don't publish preview/beta # versions. Users instead install them from the git repo. types: [released] + workflow_dispatch: + inputs: + tag: + description: 'Tag to publish (e.g., v1.0.0)' + required: true + type: string env: # This env var is used by Swatinem/rust-cache@v2 for the cache @@ -14,7 +20,7 @@ env: jobs: build: runs-on: ubuntu-latest - timeout-minutes: 30 + timeout-minutes: 60 env: # Need up-to-date compilers for kernels CC: gcc-12 @@ -27,6 +33,19 @@ jobs: - uses: Swatinem/rust-cache@v2 with: workspaces: rust + - name: Verify and checkout specified tag + if: github.event_name == 'workflow_dispatch' + run: | + git fetch --all --tags + if git rev-parse ${{ github.event.inputs.tag }} >/dev/null 2>&1; then + git checkout ${{ github.event.inputs.tag }} + echo "Successfully checked out tag ${{ github.event.inputs.tag }}" + else + echo "Error: Tag ${{ github.event.inputs.tag }} does not exist" + echo "Available tags:" + git tag -l + exit 1 + fi - name: Install dependencies run: | sudo apt update diff --git a/.github/workflows/pr-title.yml b/.github/workflows/pr-title.yml index eb8a4785ed..80d75b0c30 100644 --- a/.github/workflows/pr-title.yml +++ b/.github/workflows/pr-title.yml @@ -41,12 +41,11 @@ jobs: - uses: actions/setup-python@v5 with: python-version: "3.12" + - run: pip install PyGithub - env: PR_NUMBER: ${{ github.event.pull_request.number }} working-directory: pr - run: | - pip install PyGithub - python ../base/ci/check_versions.py + run: python ../base/ci/check_versions.py commitlint: permissions: pull-requests: write diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 094fa028cb..7b6718c783 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -99,7 +99,7 @@ jobs: fail_ci_if_error: false linux-arm: runs-on: warp-ubuntu-latest-arm64-4x - timeout-minutes: 30 + timeout-minutes: 45 steps: - uses: actions/checkout@v4 - uses: Swatinem/rust-cache@v2 @@ -160,7 +160,7 @@ jobs: workspaces: rust - name: Select new xcode # Default XCode right now is 15.0.1, which contains a bug that causes - # backtraces to not show properly. See: + # backtraces to not show properly. See: # https://github.com/rust-lang/rust/issues/113783 run: sudo xcode-select -s /Applications/Xcode_15.4.app - name: Install dependencies @@ -171,7 +171,7 @@ jobs: rustup component add rustfmt - name: Run tests # Check all benches, even though we aren't going to run them. 
- run: | + run: | cargo build --tests --benches --all-features --workspace cargo test --all-features windows-build: diff --git a/Cargo.toml b/Cargo.toml index 7aa46d5a92..efdce06ec4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ exclude = ["python"] resolver = "2" [workspace.package] -version = "0.17.1" +version = "0.18.1" edition = "2021" authors = ["Lance Devs "] license = "Apache-2.0" @@ -44,20 +44,21 @@ categories = [ rust-version = "1.78" [workspace.dependencies] -lance = { version = "=0.17.1", path = "./rust/lance" } -lance-arrow = { version = "=0.17.1", path = "./rust/lance-arrow" } -lance-core = { version = "=0.17.1", path = "./rust/lance-core" } -lance-datafusion = { version = "=0.17.1", path = "./rust/lance-datafusion" } -lance-datagen = { version = "=0.17.1", path = "./rust/lance-datagen" } -lance-encoding = { version = "=0.17.1", path = "./rust/lance-encoding" } -lance-encoding-datafusion = { version = "=0.17.1", path = "./rust/lance-encoding-datafusion" } -lance-file = { version = "=0.17.1", path = "./rust/lance-file" } -lance-index = { version = "=0.17.1", path = "./rust/lance-index" } -lance-io = { version = "=0.17.1", path = "./rust/lance-io" } -lance-linalg = { version = "=0.17.1", path = "./rust/lance-linalg" } -lance-table = { version = "=0.17.1", path = "./rust/lance-table" } -lance-test-macros = { version = "=0.17.1", path = "./rust/lance-test-macros" } -lance-testing = { version = "=0.17.1", path = "./rust/lance-testing" } +lance = { version = "=0.18.1", path = "./rust/lance" } +lance-arrow = { version = "=0.18.1", path = "./rust/lance-arrow" } +lance-core = { version = "=0.18.1", path = "./rust/lance-core" } +lance-datafusion = { version = "=0.18.1", path = "./rust/lance-datafusion" } +lance-datagen = { version = "=0.18.1", path = "./rust/lance-datagen" } +lance-encoding = { version = "=0.18.1", path = "./rust/lance-encoding" } +lance-encoding-datafusion = { version = "=0.18.1", path = "./rust/lance-encoding-datafusion" } +lance-file = { version = "=0.18.1", path = "./rust/lance-file" } +lance-index = { version = "=0.18.1", path = "./rust/lance-index" } +lance-io = { version = "=0.18.1", path = "./rust/lance-io" } +lance-jni = { version = "=0.18.1", path = "./java/core/lance-jni" } +lance-linalg = { version = "=0.18.1", path = "./rust/lance-linalg" } +lance-table = { version = "=0.18.1", path = "./rust/lance-table" } +lance-test-macros = { version = "=0.18.1", path = "./rust/lance-test-macros" } +lance-testing = { version = "=0.18.1", path = "./rust/lance-testing" } approx = "0.5.1" # Note that this one does not include pyarrow arrow = { version = "52.2", optional = false, features = ["prettyprint"] } @@ -110,7 +111,7 @@ datafusion-physical-expr = { version = "40.0", features = [ ] } deepsize = "0.2.0" either = "1.0" -fsst = { version = "=0.17.1", path = "./rust/lance-encoding/compression-algo/fsst" } +fsst = { version = "=0.18.1", path = "./rust/lance-encoding/compression-algo/fsst" } futures = "0.3" http = "0.2.9" hyperloglogplus = { version = "0.4.1", features = ["const-loop"] } diff --git a/ci/check_versions.py b/ci/check_versions.py index a16246d3b8..d42062a255 100644 --- a/ci/check_versions.py +++ b/ci/check_versions.py @@ -47,8 +47,16 @@ def parse_version(version: str) -> tuple[int, int, int]: # Check for a breaking-change label in the PRs between the last release and the current commit. 
commits = repo.compare(latest_release.tag_name, os.environ["GITHUB_SHA"]).commits prs = (pr for commit in commits for pr in commit.get_pulls()) - pr_labels = (label.name for pr in prs for label in pr.labels) - has_breaking_changes = any(label == "breaking-change" for label in pr_labels) + has_breaking_changes = False + for pr in prs: + pr_labels = (label.name for label in pr.labels) + if any(label == "breaking-change" for label in pr_labels): + has_breaking_changes = True + print(f"Found breaking change in PR #{pr.number}: {pr.title}") + print(f" {pr.html_url}") + break + else: + print("No breaking changes found.") if os.environ.get("PR_NUMBER"): # If we're running on a PR, we should validate that the version has been diff --git a/docs/read_and_write.rst b/docs/read_and_write.rst index 9ced48b83d..85eca5668b 100644 --- a/docs/read_and_write.rst +++ b/docs/read_and_write.rst @@ -700,6 +700,10 @@ These options apply to all object stores. - Description * - ``allow_http`` - Allow non-TLS, i.e. non-HTTPS connections. Default, ``False``. + * - ``download_retry_count`` + - Number of times to retry a download. Default, ``3``. This limit is applied when + the HTTP request succeeds but the response is not fully downloaded, typically due + to a violation of ``request_timeout``. * - ``allow_invalid_certificates`` - Skip certificate validation on https connections. Default, ``False``. Warning: This is insecure and should only be used for testing. diff --git a/java/core/lance-jni/Cargo.toml b/java/core/lance-jni/Cargo.toml index d3266084a1..6016fdae8e 100644 --- a/java/core/lance-jni/Cargo.toml +++ b/java/core/lance-jni/Cargo.toml @@ -14,9 +14,9 @@ crate-type = ["cdylib"] [dependencies] lance = { workspace = true, features = ["substrait"] } -lance-encoding = { path = "../../../rust/lance-encoding" } -lance-linalg = { path = "../../../rust/lance-linalg" } -lance-index = { path = "../../../rust/lance-index" } +lance-encoding = { workspace = true } +lance-linalg = { workspace = true } +lance-index = { workspace = true } lance-io.workspace = true arrow = { workspace = true, features = ["ffi"] } arrow-schema.workspace = true diff --git a/java/core/pom.xml b/java/core/pom.xml index 4ff96010f6..2a5bae08d2 100644 --- a/java/core/pom.xml +++ b/java/core/pom.xml @@ -8,7 +8,7 @@ com.lancedb lance-parent - 0.0.4 + 0.18.1 ../pom.xml diff --git a/java/pom.xml b/java/pom.xml index 648f7f3dab..78729c8e11 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -6,7 +6,7 @@ com.lancedb lance-parent - 0.0.4 + 0.18.1 pom Lance Parent diff --git a/java/spark/README.md b/java/spark/README.md index 95b6e94681..47edf02099 100644 --- a/java/spark/README.md +++ b/java/spark/README.md @@ -1,6 +1,6 @@ # Spark-Lance Connector -The Spark-Lance Connector allows Apache Spark to efficiently read tables stored in Lance format. +The Spark-Lance Connector allows Apache Spark to efficiently read datasets stored in Lance format. Lance is a modern columnar data format optimized for machine learning workflows and datasets, supporting distributed, parallel scans, and optimizations such as column and filter pushdown to improve performance. Additionally, Lance provides high-performance random access that is 100 times faster than Parquet without sacrificing scan performance. @@ -8,8 +8,8 @@ By using the Spark-Lance Connector, you can leverage Spark's powerful data proce ## Features -* Query Lance Tables: Seamlessly query tables stored in the Lance format using Spark. 
-* Distributed, Parallel Scans: Leverage Spark's distributed computing capabilities to perform parallel scans on Lance tables.
+* Query Lance Datasets: Seamlessly query datasets stored in the Lance format using Spark.
+* Distributed, Parallel Scans: Leverage Spark's distributed computing capabilities to perform parallel scans on Lance datasets.
 * Column and Filter Pushdown: Optimize query performance by pushing down column selections and filters to the data source.
 
 ## Installation
@@ -49,7 +49,7 @@ SparkSession spark = SparkSession.builder()
 Dataset<Row> data = spark.read().format("lance")
         .option("db", "/path/to/example_db")
-        .option("table", "lance_example_table")
+        .option("dataset", "lance_example_dataset")
         .load();
 data.show(100)
diff --git a/java/spark/pom.xml b/java/spark/pom.xml
index b1f4c98d80..9a0c16360a 100644
--- a/java/spark/pom.xml
+++ b/java/spark/pom.xml
@@ -8,7 +8,7 @@
   <parent>
     <groupId>com.lancedb</groupId>
     <artifactId>lance-parent</artifactId>
-    <version>0.0.4</version>
+    <version>0.18.1</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
@@ -40,7 +40,7 @@
     <dependency>
      <groupId>com.lancedb</groupId>
      <artifactId>lance-core</artifactId>
-      <version>0.0.4</version>
+      <version>0.18.1</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceCatalog.java b/java/spark/src/main/java/com/lancedb/lance/spark/LanceCatalog.java
new file mode 100644
index 0000000000..0b2693e1f1
--- /dev/null
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/LanceCatalog.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package com.lancedb.lance.spark; + +import com.lancedb.lance.spark.internal.LanceDatasetAdapter; +import com.lancedb.lance.spark.utils.Optional; +import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.catalog.TableChange; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import scala.Some; + +import java.util.Map; + +public class LanceCatalog implements TableCatalog { + @Override + public Identifier[] listTables(String[] namespace) throws NoSuchNamespaceException { + throw new UnsupportedOperationException("Please use lancedb catalog for dataset listing"); + } + + @Override + public Table loadTable(Identifier ident) throws NoSuchTableException { + LanceConfig config = LanceConfig.from(ident.name()); + Optional schema = LanceDatasetAdapter.getSchema(ident.name()); + if (schema.isEmpty()) { + throw new NoSuchTableException(config.getDbPath(), config.getDatasetName()); + } + return new LanceDataset(LanceConfig.from(ident.name()), schema.get()); + } + + @Override + public Table createTable(Identifier ident, StructType schema, Transform[] partitions, + Map properties) throws TableAlreadyExistsException, NoSuchNamespaceException { + try { + LanceDatasetAdapter.createDataset(ident.name(), schema); + } catch (IllegalArgumentException e) { + throw new TableAlreadyExistsException(ident.name(), new Some<>(e)); + } + return new LanceDataset(LanceConfig.from(properties, ident.name()), schema); + } + + @Override + public Table alterTable(Identifier ident, TableChange... changes) throws NoSuchTableException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean dropTable(Identifier ident) { + throw new UnsupportedOperationException(); + } + + @Override + public void renameTable(Identifier oldIdent, Identifier newIdent) + throws NoSuchTableException, TableAlreadyExistsException { + throw new UnsupportedOperationException(); + } + + @Override + public void initialize(String name, CaseInsensitiveStringMap options) { + // Do nothing here + } + + @Override + public String name() { + return "lance"; + } +} diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceConfig.java b/java/spark/src/main/java/com/lancedb/lance/spark/LanceConfig.java new file mode 100644 index 0000000000..a0ea017b71 --- /dev/null +++ b/java/spark/src/main/java/com/lancedb/lance/spark/LanceConfig.java @@ -0,0 +1,110 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.lancedb.lance.spark; + +import java.io.Serializable; +import java.util.Map; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * Lance Configuration. + */ +public class LanceConfig implements Serializable { + private static final long serialVersionUID = 827364827364823764L; + public static final String CONFIG_DATASET_URI = "path"; // Path is default spark option key + public static final String CONFIG_PUSH_DOWN_FILTERS = "pushDownFilters"; + public static final String LANCE_FILE_SUFFIX = ".lance"; + + private static final boolean DEFAULT_PUSH_DOWN_FILTERS = true; + + private final String dbPath; + private final String datasetName; + private final String datasetUri; + private final boolean pushDownFilters; + + private LanceConfig(String dbPath, String datasetName, + String datasetUri, boolean pushDownFilters) { + this.dbPath = dbPath; + this.datasetName = datasetName; + this.datasetUri = datasetUri; + this.pushDownFilters = pushDownFilters; + } + + public static LanceConfig from(Map properties) { + return from(new CaseInsensitiveStringMap(properties)); + } + + public static LanceConfig from(CaseInsensitiveStringMap options) { + if (!options.containsKey(CONFIG_DATASET_URI)) { + throw new IllegalArgumentException("Missing required option " + CONFIG_DATASET_URI); + } + return from(options, options.get(CONFIG_DATASET_URI)); + } + + public static LanceConfig from(Map properties, String datasetUri) { + return from(new CaseInsensitiveStringMap(properties), datasetUri); + } + + public static LanceConfig from(String datasetUri) { + return from(CaseInsensitiveStringMap.empty(), datasetUri); + } + + public static LanceConfig from(CaseInsensitiveStringMap options, String datasetUri) { + boolean pushDownFilters = options.getBoolean(CONFIG_PUSH_DOWN_FILTERS, + DEFAULT_PUSH_DOWN_FILTERS); + String[] paths = extractDbPathAndDatasetName(datasetUri); + return new LanceConfig(paths[0], paths[1], datasetUri, pushDownFilters); + } + + public static String getDatasetUri(String dbPath, String datasetUri) { + StringBuilder sb = new StringBuilder().append(dbPath); + if (!dbPath.endsWith("/")) { + sb.append("/"); + } + return sb.append(datasetUri).append(LANCE_FILE_SUFFIX).toString(); + } + + private static String[] extractDbPathAndDatasetName(String datasetUri) { + if (datasetUri == null || !datasetUri.endsWith(LANCE_FILE_SUFFIX)) { + throw new IllegalArgumentException("Invalid dataset uri: " + datasetUri); + } + + int lastSlashIndex = datasetUri.lastIndexOf('/'); + if (lastSlashIndex == -1) { + throw new IllegalArgumentException("Invalid dataset uri: " + datasetUri); + } + + String datasetNameWithSuffix = datasetUri.substring(lastSlashIndex + 1); + return new String[]{datasetUri.substring(0, lastSlashIndex + 1), + datasetNameWithSuffix.substring(0, + datasetNameWithSuffix.length() - LANCE_FILE_SUFFIX.length())}; + } + + public String getDbPath() { + return dbPath; + } + + public String getDatasetName() { + return datasetName; + } + + public String getDatasetUri() { + return datasetUri; + } + + public boolean isPushDownFilters() { + return pushDownFilters; + } +} diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataSource.java b/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataSource.java index e00bbfd26d..0bc5fcbbdd 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataSource.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataSource.java @@ -14,10 +14,11 @@ package com.lancedb.lance.spark; -import 
com.lancedb.lance.spark.internal.LanceConfig; -import com.lancedb.lance.spark.internal.LanceReader; +import com.lancedb.lance.spark.internal.LanceDatasetAdapter; +import com.lancedb.lance.spark.utils.Optional; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.SupportsCatalogOptions; import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.sources.DataSourceRegister; import org.apache.spark.sql.types.StructType; @@ -25,29 +26,33 @@ import java.util.Map; -public class LanceDataSource implements TableProvider, DataSourceRegister { - private static final String name = "lance"; +public class LanceDataSource implements SupportsCatalogOptions, DataSourceRegister { + public static final String name = "lance"; @Override public StructType inferSchema(CaseInsensitiveStringMap options) { - // Given options help identify a table, no schema filter is passed in - return LanceReader.getSchema(LanceConfig.from(options)); + Optional schema = LanceDatasetAdapter.getSchema(LanceConfig.from(options)); + return schema.isPresent() ? schema.get() : null; } @Override public Table getTable(StructType schema, Transform[] partitioning, Map properties) { - LanceConfig config = LanceConfig.from(properties); - return new LanceTable(config, LanceReader.getSchema(config)); + return new LanceDataset(LanceConfig.from(properties), schema); } @Override - public boolean supportsExternalMetadata() { - return TableProvider.super.supportsExternalMetadata(); + public String shortName() { + return name; } @Override - public String shortName() { - return name; + public Identifier extractIdentifier(CaseInsensitiveStringMap options) { + return new LanceIdentifier(LanceConfig.from(options).getDatasetUri()); + } + + @Override + public String extractCatalog(CaseInsensitiveStringMap options) { + return "lance"; } } diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataset.java b/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataset.java new file mode 100644 index 0000000000..702b3bdf42 --- /dev/null +++ b/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataset.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.lancedb.lance.spark; + +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + +import com.lancedb.lance.spark.read.LanceScanBuilder; +import com.lancedb.lance.spark.write.SparkWrite; +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.SupportsWrite; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * Lance Spark Dataset. + */ +public class LanceDataset implements SupportsRead, SupportsWrite { + private static final Set CAPABILITIES = + ImmutableSet.of(TableCapability.BATCH_READ, TableCapability.BATCH_WRITE); + + LanceConfig options; + private final StructType sparkSchema; + + /** + * Creates a Lance dataset. + * + * @param config read config + * @param sparkSchema spark struct type + */ + public LanceDataset(LanceConfig config, StructType sparkSchema) { + this.options = config; + this.sparkSchema = sparkSchema; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap caseInsensitiveStringMap) { + return new LanceScanBuilder(sparkSchema, options); + } + + @Override + public String name() { + return this.options.getDatasetName(); + } + + @Override + public StructType schema() { + return sparkSchema; + } + + @Override + public Set capabilities() { + return CAPABILITIES; + } + + @Override + public WriteBuilder newWriteBuilder(LogicalWriteInfo logicalWriteInfo) { + return new SparkWrite.SparkWriteBuilder(sparkSchema, options); + } +} diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceIdentifier.java b/java/spark/src/main/java/com/lancedb/lance/spark/LanceIdentifier.java new file mode 100644 index 0000000000..4c872721ee --- /dev/null +++ b/java/spark/src/main/java/com/lancedb/lance/spark/LanceIdentifier.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.lancedb.lance.spark; + +import org.apache.spark.sql.connector.catalog.Identifier; + +public class LanceIdentifier implements Identifier { + private final String[] namespace = new String[]{"default"}; + private final String datasetUri; + + public LanceIdentifier(String datasetUri) { + this.datasetUri = datasetUri; + } + + @Override + public String[] namespace() { + return this.namespace; + } + + @Override + public String name() { + return datasetUri; + } +} \ No newline at end of file diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceTable.java b/java/spark/src/main/java/com/lancedb/lance/spark/LanceTable.java deleted file mode 100644 index e461226fb2..0000000000 --- a/java/spark/src/main/java/com/lancedb/lance/spark/LanceTable.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.lancedb.lance.spark; - -import com.google.common.collect.ImmutableSet; -import com.lancedb.lance.spark.internal.LanceConfig; -import java.util.Set; -import org.apache.spark.sql.connector.catalog.SupportsRead; -import org.apache.spark.sql.connector.catalog.TableCapability; -import org.apache.spark.sql.connector.read.ScanBuilder; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.util.CaseInsensitiveStringMap; - -/** - * Lance Spark Table. - */ -public class LanceTable implements SupportsRead { - private static final Set CAPABILITIES = - ImmutableSet.of( - TableCapability.BATCH_READ); - - LanceConfig options; - private final StructType sparkSchema; - - /** - * Creates a spark table. - * - * @param config read config - * @param sparkSchema spark struct type - */ - public LanceTable(LanceConfig config, StructType sparkSchema) { - this.options = config; - this.sparkSchema = sparkSchema; - } - - @Override - public ScanBuilder newScanBuilder(CaseInsensitiveStringMap caseInsensitiveStringMap) { - return new LanceScanBuilder(sparkSchema, options); - } - - @Override - public String name() { - return this.options.getTableName(); - } - - @Override - public StructType schema() { - return this.sparkSchema; - } - - @Override - public Set capabilities() { - return CAPABILITIES; - } -} diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceConfig.java b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceConfig.java deleted file mode 100644 index 7d72b6b0cc..0000000000 --- a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceConfig.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.lancedb.lance.spark.internal; - -import java.io.Serializable; -import java.util.Map; -import org.apache.spark.sql.util.CaseInsensitiveStringMap; - -/** - * Lance Configuration. - */ -public class LanceConfig implements Serializable { - private static final long serialVersionUID = 827364827364823764L; - public static final String CONFIG_DB_PATH = "db"; - public static final String CONFIG_TABLE_NAME = "table"; - public static final String CONFIG_PUSH_DOWN_FILTERS = "pushDownFilters"; - private static final String LANCE_FILE_SUFFIX = ".lance"; - - private static final boolean DEFAULT_PUSH_DOWN_FILTERS = true; - - private final String dbPath; - private final String tableName; - private final String tablePath; - private final boolean pushDownFilters; - - private LanceConfig(String dbPath, String tableName, String tablePath, boolean pushDownFilters) { - this.dbPath = dbPath; - this.tableName = tableName; - this.tablePath = tablePath; - this.pushDownFilters = pushDownFilters; - } - - public static LanceConfig from(Map properties) { - return from(new CaseInsensitiveStringMap(properties)); - } - - public static LanceConfig from(CaseInsensitiveStringMap options) { - if (!options.containsKey(CONFIG_DB_PATH) || !options.containsKey(CONFIG_TABLE_NAME)) { - throw new IllegalArgumentException("Missing required options"); - } - - String dbPath = options.get(CONFIG_DB_PATH); - String tableName = options.get(CONFIG_TABLE_NAME); - boolean pushDownFilters = options.getBoolean(CONFIG_PUSH_DOWN_FILTERS, - DEFAULT_PUSH_DOWN_FILTERS); - - String tablePath = calculateTablePath(dbPath, tableName); - - return new LanceConfig(dbPath, tableName, tablePath, pushDownFilters); - } - - private static String calculateTablePath(String dbPath, String tableName) { - StringBuilder sb = new StringBuilder().append(dbPath); - if (!dbPath.endsWith("/")) { - sb.append("/"); - } - return sb.append(tableName).append(LANCE_FILE_SUFFIX).toString(); - } - - public String getDbPath() { - return dbPath; - } - - public String getTableName() { - return tableName; - } - - public String getTablePath() { - return tablePath; - } - - public boolean isPushDownFilters() { - return pushDownFilters; - } -} diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java new file mode 100644 index 0000000000..16412f1026 --- /dev/null +++ b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.lancedb.lance.spark.internal; + +import com.lancedb.lance.Dataset; +import com.lancedb.lance.DatasetFragment; +import com.lancedb.lance.Fragment; +import com.lancedb.lance.FragmentMetadata; +import com.lancedb.lance.FragmentOperation; +import com.lancedb.lance.WriteParams; +import com.lancedb.lance.spark.LanceConfig; +import com.lancedb.lance.spark.read.LanceInputPartition; +import com.lancedb.lance.spark.utils.Optional; +import com.lancedb.lance.spark.write.LanceArrowWriter; +import org.apache.arrow.c.ArrowArrayStream; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.ArrowUtils; + +import java.time.ZoneId; +import java.util.List; +import java.util.stream.Collectors; + +public class LanceDatasetAdapter { + private static final BufferAllocator allocator = new RootAllocator( + RootAllocator.configBuilder().from(RootAllocator.defaultConfig()) + .maxAllocation(4 * 1024 * 1024).build()); + + public static Optional getSchema(LanceConfig config) { + return getSchema(config.getDatasetUri()); + } + + public static Optional getSchema(String datasetUri) { + try (Dataset dataset = Dataset.open(datasetUri, allocator)) { + return Optional.of(ArrowUtils.fromArrowSchema(dataset.getSchema())); + } catch (IllegalArgumentException e) { + // dataset not found + return Optional.empty(); + } + } + + public static List getFragmentIds(LanceConfig config) { + try (Dataset dataset = Dataset.open(config.getDatasetUri(), allocator)) { + return dataset.getFragments().stream() + .map(DatasetFragment::getId).collect(Collectors.toList()); + } + } + + public static LanceFragmentScanner getFragmentScanner(int fragmentId, + LanceInputPartition inputPartition) { + return LanceFragmentScanner.create(fragmentId, inputPartition, allocator); + } + + public static void appendFragments(LanceConfig config, List fragments) { + FragmentOperation.Append appendOp = new FragmentOperation.Append(fragments); + try (Dataset datasetRead = Dataset.open(config.getDatasetUri(), allocator)) { + Dataset.commit(allocator, config.getDatasetUri(), + appendOp, java.util.Optional.of(datasetRead.version())).close(); + } + } + + public static LanceArrowWriter getArrowWriter(StructType sparkSchema, int batchSize) { + return new LanceArrowWriter(allocator, + ArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false), batchSize); + } + + public static FragmentMetadata createFragment(String datasetUri, ArrowReader reader) { + try (ArrowArrayStream arrowStream = ArrowArrayStream.allocateNew(allocator)) { + Data.exportArrayStream(allocator, reader, arrowStream); + return Fragment.create(datasetUri, arrowStream, + java.util.Optional.empty(), new WriteParams.Builder().build()); + } + } + + public static void createDataset(String datasetUri, StructType sparkSchema) { + Dataset.create(allocator, datasetUri, + ArrowUtils.toArrowSchema(sparkSchema, ZoneId.systemDefault().getId(), true, false), + new WriteParams.Builder().build()).close(); + } +} diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentColumnarBatchScanner.java b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentColumnarBatchScanner.java index fa0b165047..660ec55770 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentColumnarBatchScanner.java +++ 
b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentColumnarBatchScanner.java @@ -14,7 +14,7 @@ package com.lancedb.lance.spark.internal; -import com.lancedb.lance.spark.LanceInputPartition; +import com.lancedb.lance.spark.read.LanceInputPartition; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.spark.sql.vectorized.ArrowColumnVector; @@ -35,7 +35,7 @@ public LanceFragmentColumnarBatchScanner(LanceFragmentScanner fragmentScanner, public static LanceFragmentColumnarBatchScanner create( int fragmentId, LanceInputPartition inputPartition) { - LanceFragmentScanner fragmentScanner = LanceReader + LanceFragmentScanner fragmentScanner = LanceDatasetAdapter .getFragmentScanner(fragmentId, inputPartition); return new LanceFragmentColumnarBatchScanner(fragmentScanner, fragmentScanner.getArrowReader()); } diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentScanner.java b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentScanner.java index f8b5ce7432..55ba61e797 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentScanner.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentScanner.java @@ -18,7 +18,7 @@ import com.lancedb.lance.DatasetFragment; import com.lancedb.lance.ipc.LanceScanner; import com.lancedb.lance.ipc.ScanOptions; -import com.lancedb.lance.spark.LanceInputPartition; +import com.lancedb.lance.spark.read.LanceInputPartition; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.spark.sql.types.StructField; @@ -46,7 +46,7 @@ public static LanceFragmentScanner create(int fragmentId, DatasetFragment fragment = null; LanceScanner scanner = null; try { - dataset = Dataset.open(inputPartition.getConfig().getTablePath(), allocator); + dataset = Dataset.open(inputPartition.getConfig().getDatasetUri(), allocator); fragment = dataset.getFragments().get(fragmentId); ScanOptions.Builder scanOptions = new ScanOptions.Builder(); scanOptions.columns(getColumnNames(inputPartition.getSchema())); diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceReader.java b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceReader.java deleted file mode 100644 index 1f7b9f7926..0000000000 --- a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceReader.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.lancedb.lance.spark.internal; - -import com.lancedb.lance.Dataset; -import com.lancedb.lance.DatasetFragment; -import com.lancedb.lance.spark.LanceInputPartition; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.util.ArrowUtils; - -import java.util.List; -import java.util.stream.Collectors; - -public class LanceReader { - private static final BufferAllocator allocator = new RootAllocator( - RootAllocator.configBuilder().from(RootAllocator.defaultConfig()) - .maxAllocation(4 * 1024 * 1024).build()); - - public static StructType getSchema(LanceConfig options) - { - try (Dataset dataset = Dataset.open(options.getTablePath(), allocator)) { - return ArrowUtils.fromArrowSchema(dataset.getSchema()); - } - } - - public static List getFragmentIds(LanceConfig config) { - try (Dataset dataset = Dataset.open(config.getTablePath(), allocator)) { - return dataset.getFragments().stream() - .map(DatasetFragment::getId).collect(Collectors.toList()); - } - } - - public static LanceFragmentScanner getFragmentScanner(int fragmentId, - LanceInputPartition inputPartition) { - return LanceFragmentScanner.create(fragmentId, inputPartition, allocator); - } -} - diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/query/FilterPushDown.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/FilterPushDown.java similarity index 99% rename from java/spark/src/main/java/com/lancedb/lance/spark/query/FilterPushDown.java rename to java/spark/src/main/java/com/lancedb/lance/spark/read/FilterPushDown.java index 59403f2dc7..7cc30dc74a 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/query/FilterPushDown.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/FilterPushDown.java @@ -12,7 +12,7 @@ * limitations under the License. */ -package com.lancedb.lance.spark.query; +package com.lancedb.lance.spark.read; import com.lancedb.lance.spark.utils.Optional; import org.apache.spark.sql.sources.And; diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceColumnarPartitionReader.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceColumnarPartitionReader.java similarity index 98% rename from java/spark/src/main/java/com/lancedb/lance/spark/LanceColumnarPartitionReader.java rename to java/spark/src/main/java/com/lancedb/lance/spark/read/LanceColumnarPartitionReader.java index 973cc6763c..5745709823 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/LanceColumnarPartitionReader.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceColumnarPartitionReader.java @@ -12,7 +12,7 @@ * limitations under the License. 
*/ -package com.lancedb.lance.spark; +package com.lancedb.lance.spark.read; import com.lancedb.lance.spark.internal.LanceFragmentColumnarBatchScanner; import org.apache.spark.sql.connector.read.PartitionReader; diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceInputPartition.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceInputPartition.java similarity index 95% rename from java/spark/src/main/java/com/lancedb/lance/spark/LanceInputPartition.java rename to java/spark/src/main/java/com/lancedb/lance/spark/read/LanceInputPartition.java index a4ffbe91a8..3525502a63 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/LanceInputPartition.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceInputPartition.java @@ -12,9 +12,9 @@ * limitations under the License. */ -package com.lancedb.lance.spark; +package com.lancedb.lance.spark.read; -import com.lancedb.lance.spark.internal.LanceConfig; +import com.lancedb.lance.spark.LanceConfig; import com.lancedb.lance.spark.utils.Optional; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.types.StructType; diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceRowPartitionReader.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceRowPartitionReader.java similarity index 98% rename from java/spark/src/main/java/com/lancedb/lance/spark/LanceRowPartitionReader.java rename to java/spark/src/main/java/com/lancedb/lance/spark/read/LanceRowPartitionReader.java index 8f8e66f87d..88c105e7c7 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/LanceRowPartitionReader.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceRowPartitionReader.java @@ -12,7 +12,7 @@ * limitations under the License. */ -package com.lancedb.lance.spark; +package com.lancedb.lance.spark.read; import org.apache.spark.sql.connector.read.PartitionReader; import org.apache.spark.sql.catalyst.InternalRow; diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceScan.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceScan.java similarity index 97% rename from java/spark/src/main/java/com/lancedb/lance/spark/LanceScan.java rename to java/spark/src/main/java/com/lancedb/lance/spark/read/LanceScan.java index 2df28ebd89..382cae20d3 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/LanceScan.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceScan.java @@ -12,9 +12,9 @@ * limitations under the License. */ -package com.lancedb.lance.spark; +package com.lancedb.lance.spark.read; -import com.lancedb.lance.spark.internal.LanceConfig; +import com.lancedb.lance.spark.LanceConfig; import com.lancedb.lance.spark.utils.Optional; import org.apache.arrow.util.Preconditions; import org.apache.spark.sql.catalyst.InternalRow; diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceScanBuilder.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceScanBuilder.java similarity index 93% rename from java/spark/src/main/java/com/lancedb/lance/spark/LanceScanBuilder.java rename to java/spark/src/main/java/com/lancedb/lance/spark/read/LanceScanBuilder.java index 69167b44e5..9fba4601c3 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/LanceScanBuilder.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceScanBuilder.java @@ -12,10 +12,9 @@ * limitations under the License. 
*/ -package com.lancedb.lance.spark; +package com.lancedb.lance.spark.read; -import com.lancedb.lance.spark.internal.LanceConfig; -import com.lancedb.lance.spark.query.FilterPushDown; +import com.lancedb.lance.spark.LanceConfig; import com.lancedb.lance.spark.utils.Optional; import org.apache.spark.sql.connector.read.Scan; import org.apache.spark.sql.connector.read.SupportsPushDownFilters; diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceSplit.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceSplit.java similarity index 85% rename from java/spark/src/main/java/com/lancedb/lance/spark/LanceSplit.java rename to java/spark/src/main/java/com/lancedb/lance/spark/read/LanceSplit.java index 8081101ade..4e46b464df 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/LanceSplit.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceSplit.java @@ -12,10 +12,10 @@ * limitations under the License. */ -package com.lancedb.lance.spark; +package com.lancedb.lance.spark.read; -import com.lancedb.lance.spark.internal.LanceConfig; -import com.lancedb.lance.spark.internal.LanceReader; +import com.lancedb.lance.spark.LanceConfig; +import com.lancedb.lance.spark.internal.LanceDatasetAdapter; import java.io.Serializable; import java.util.Collections; @@ -36,7 +36,7 @@ public List getFragments() { } public static List generateLanceSplits(LanceConfig config) { - return LanceReader.getFragmentIds(config).stream() + return LanceDatasetAdapter.getFragmentIds(config).stream() .map(id -> new LanceSplit(Collections.singletonList(id))) .collect(Collectors.toList()); } diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/write/BatchAppend.java b/java/spark/src/main/java/com/lancedb/lance/spark/write/BatchAppend.java new file mode 100644 index 0000000000..bf41dcefe8 --- /dev/null +++ b/java/spark/src/main/java/com/lancedb/lance/spark/write/BatchAppend.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.lancedb.lance.spark.write; + +import com.lancedb.lance.FragmentMetadata; +import com.lancedb.lance.spark.LanceConfig; +import com.lancedb.lance.spark.internal.LanceDatasetAdapter; +import org.apache.spark.sql.connector.write.BatchWrite; +import org.apache.spark.sql.connector.write.DataWriterFactory; +import org.apache.spark.sql.connector.write.PhysicalWriteInfo; +import org.apache.spark.sql.connector.write.WriterCommitMessage; +import org.apache.spark.sql.types.StructType; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +public class BatchAppend implements BatchWrite { + private final StructType schema; + private final LanceConfig config; + + public BatchAppend(StructType schema, LanceConfig config) { + this.schema = schema; + this.config = config; + } + + @Override + public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) { + return new LanceDataWriter.WriterFactory(schema, config); + } + + @Override + public boolean useCommitCoordinator() { + return false; + } + + @Override + public void commit(WriterCommitMessage[] messages) { + List fragments = Arrays.stream(messages) + .map(m -> (TaskCommit) m) + .map(TaskCommit::getFragments) + .flatMap(List::stream) + .collect(Collectors.toList()); + LanceDatasetAdapter.appendFragments(config, fragments); + } + + @Override + public void abort(WriterCommitMessage[] messages) { + throw new UnsupportedOperationException(); + } + + @Override + public String toString() { + return String.format("LanceBatchWrite(datasetUri=%s)", config.getDatasetUri()); + } + + public static class TaskCommit implements WriterCommitMessage { + private final List fragments; + + TaskCommit(List fragments) { + this.fragments = fragments; + } + + List getFragments() { + return fragments; + } + } +} diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceArrowWriter.java b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceArrowWriter.java new file mode 100644 index 0000000000..341aa11ab3 --- /dev/null +++ b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceArrowWriter.java @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.lancedb.lance.spark.write; + +import com.google.common.base.Preconditions; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.arrow.ArrowWriter; + +import javax.annotation.concurrent.GuardedBy; +import java.io.IOException; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicLong; + +/** + * A custom arrow reader that supports writes Spark internal rows while reading data in batches. 
+ */ +public class LanceArrowWriter extends ArrowReader { + private final Schema schema; + private final int batchSize; + private final Object monitor = new Object(); + @GuardedBy("monitor") + private final Queue rowQueue = new ConcurrentLinkedQueue<>(); + @GuardedBy("monitor") + private volatile boolean finished; + + private final AtomicLong totalBytesRead = new AtomicLong(); + private ArrowWriter arrowWriter = null; + + public LanceArrowWriter(BufferAllocator allocator, Schema schema, int batchSize) { + super(allocator); + Preconditions.checkNotNull(schema); + Preconditions.checkArgument(batchSize > 0); + this.schema = schema; + // TODO(lu) batch size as config? + this.batchSize = batchSize; + } + + void write(InternalRow row) { + Preconditions.checkNotNull(row); + synchronized (monitor) { + // TODO(lu) wait if too much elements in rowQueue + rowQueue.offer(row); + monitor.notify(); + } + } + + void setFinished() { + synchronized (monitor) { + finished = true; + monitor.notify(); + } + } + + @Override + protected void prepareLoadNextBatch() throws IOException { + super.prepareLoadNextBatch(); + // Do not use ArrowWriter.reset since it does not work well with Arrow JNI + arrowWriter = ArrowWriter.create(this.getVectorSchemaRoot()); + } + + @Override + public boolean loadNextBatch() throws IOException { + prepareLoadNextBatch(); + int rowCount = 0; + synchronized (monitor) { + while (rowCount < batchSize) { + while (rowQueue.isEmpty() && !finished) { + try { + monitor.wait(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException("Interrupted while waiting for data", e); + } + } + if (rowQueue.isEmpty() && finished) { + break; + } + InternalRow row = rowQueue.poll(); + if (row != null) { + arrowWriter.write(row); + rowCount++; + } + } + } + if (rowCount == 0) { + return false; + } + arrowWriter.finish(); + return true; + } + + @Override + public long bytesRead() { + throw new UnsupportedOperationException(); + } + + @Override + protected synchronized void closeReadSource() throws IOException { + // Implement if needed + } + + @Override + protected Schema readSchema() { + return this.schema; + } +} diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java new file mode 100644 index 0000000000..044c5d2e24 --- /dev/null +++ b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java @@ -0,0 +1,102 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.lancedb.lance.spark.write; + +import com.lancedb.lance.FragmentMetadata; +import com.lancedb.lance.spark.LanceConfig; +import com.lancedb.lance.spark.internal.LanceDatasetAdapter; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.write.DataWriter; +import org.apache.spark.sql.connector.write.DataWriterFactory; +import org.apache.spark.sql.connector.write.WriterCommitMessage; +import org.apache.spark.sql.types.StructType; + +import java.io.IOException; +import java.util.Arrays; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.FutureTask; + +public class LanceDataWriter implements DataWriter { + private LanceArrowWriter arrowWriter; + private FutureTask fragmentCreationTask; + private Thread fragmentCreationThread; + + private LanceDataWriter(LanceArrowWriter arrowWriter, + FutureTask fragmentCreationTask, Thread fragmentCreationThread) { + // TODO support write to multiple fragments + this.arrowWriter = arrowWriter; + this.fragmentCreationThread = fragmentCreationThread; + this.fragmentCreationTask = fragmentCreationTask; + } + + @Override + public void write(InternalRow record) throws IOException { + arrowWriter.write(record); + } + + @Override + public WriterCommitMessage commit() throws IOException { + arrowWriter.setFinished(); + try { + FragmentMetadata fragmentMetadata = fragmentCreationTask.get(); + return new BatchAppend.TaskCommit(Arrays.asList(fragmentMetadata)); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException("Interrupted while waiting for reader thread to finish", e); + } catch (ExecutionException e) { + throw new IOException("Exception in reader thread", e); + } + } + + @Override + public void abort() throws IOException { + fragmentCreationThread.interrupt(); + try { + fragmentCreationTask.get(); + } catch (InterruptedException | ExecutionException e) { + throw new IOException("Failed to abort the reader thread", e); + } + close(); + } + + @Override + public void close() throws IOException { + arrowWriter.close(); + } + + public static class WriterFactory implements DataWriterFactory { + private final LanceConfig config; + private final StructType schema; + + protected WriterFactory(StructType schema, LanceConfig config) { + // Everything passed to writer factory should be serializable + this.schema = schema; + this.config = config; + } + + @Override + public DataWriter createWriter(int partitionId, long taskId) { + LanceArrowWriter arrowWriter = LanceDatasetAdapter.getArrowWriter(schema, 1024); + Callable fragmentCreator + = () -> LanceDatasetAdapter.createFragment(config.getDatasetUri(), arrowWriter); + FutureTask fragmentCreationTask = new FutureTask<>(fragmentCreator); + Thread fragmentCreationThread = new Thread(fragmentCreationTask); + fragmentCreationThread.start(); + + return new LanceDataWriter(arrowWriter, fragmentCreationTask, fragmentCreationThread); + } + } +} \ No newline at end of file diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/write/SparkWrite.java b/java/spark/src/main/java/com/lancedb/lance/spark/write/SparkWrite.java new file mode 100644 index 0000000000..857387d018 --- /dev/null +++ b/java/spark/src/main/java/com/lancedb/lance/spark/write/SparkWrite.java @@ -0,0 +1,62 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.lancedb.lance.spark.write; + +import com.lancedb.lance.spark.LanceConfig; +import org.apache.spark.sql.connector.write.BatchWrite; +import org.apache.spark.sql.connector.write.Write; +import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.connector.write.streaming.StreamingWrite; +import org.apache.spark.sql.types.StructType; + +/** + * Spark write builder. + */ +public class SparkWrite implements Write { + private final LanceConfig config; + private final StructType schema; + + SparkWrite(StructType schema, LanceConfig config) { + this.schema = schema; + this.config = config; + } + + @Override + public BatchWrite toBatch() { + return new BatchAppend(schema, config); + } + + @Override + public StreamingWrite toStreaming() { + throw new UnsupportedOperationException(); + } + + /** Task commit. */ + + public static class SparkWriteBuilder implements WriteBuilder { + private final LanceConfig options; + private final StructType schema; + + public SparkWriteBuilder(StructType schema, LanceConfig options) { + this.schema = schema; + this.options = options; + } + + @Override + public Write build() { + return new SparkWrite(schema, options); + } + } +} \ No newline at end of file diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/LanceConfigTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/LanceConfigTest.java index 068caf1e05..56713f5c39 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/LanceConfigTest.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/LanceConfigTest.java @@ -14,7 +14,6 @@ package com.lancedb.lance.spark; -import com.lancedb.lance.spark.internal.LanceConfig; import org.apache.spark.sql.util.CaseInsensitiveStringMap; import org.junit.jupiter.api.Test; @@ -26,63 +25,63 @@ public class LanceConfigTest { @Test public void testLanceConfigFromCaseInsensitiveStringMap() { - String dbPath = "file://path/to/db"; - String tableName = "testTableName"; + String dbPath = "file://path/to/db/"; + String datasetName = "testDatasetName"; + String datasetUri = LanceConfig.getDatasetUri(dbPath, datasetName); CaseInsensitiveStringMap options = new CaseInsensitiveStringMap(new HashMap() {{ - put(LanceConfig.CONFIG_DB_PATH, dbPath); - put(LanceConfig.CONFIG_TABLE_NAME, tableName); + put(LanceConfig.CONFIG_DATASET_URI, datasetUri); }}); LanceConfig config = LanceConfig.from(options); assertEquals(dbPath, config.getDbPath()); - assertEquals(tableName, config.getTableName()); - assertEquals(dbPath + "/" + tableName + ".lance", config.getTablePath()); + assertEquals(datasetName, config.getDatasetName()); + assertEquals(datasetUri, config.getDatasetUri()); } @Test public void testLanceConfigFromCaseInsensitiveStringMap2() { String dbPath = "s3://bucket/folder/"; - String tableName = "testTableName"; + String datasetName = "testDatasetName"; + String datasetUri = LanceConfig.getDatasetUri(dbPath, datasetName); CaseInsensitiveStringMap options = new CaseInsensitiveStringMap(new HashMap() {{ - put(LanceConfig.CONFIG_DB_PATH, dbPath); - put(LanceConfig.CONFIG_TABLE_NAME, tableName); + 
put(LanceConfig.CONFIG_DATASET_URI, datasetUri); }}); LanceConfig config = LanceConfig.from(options); assertEquals(dbPath, config.getDbPath()); - assertEquals(tableName, config.getTableName()); - assertEquals(dbPath + tableName + ".lance", config.getTablePath()); + assertEquals(datasetName, config.getDatasetName()); + assertEquals(datasetUri, config.getDatasetUri()); } @Test public void testLanceConfigFromMap() { - String dbPath = "file://path/to/db"; - String tableName = "testTableName"; + String dbPath = "file://path/to/db/"; + String datasetName = "testDatasetName"; + String datasetUri = LanceConfig.getDatasetUri(dbPath, datasetName); Map properties = new HashMap<>(); - properties.put(LanceConfig.CONFIG_DB_PATH, dbPath); - properties.put(LanceConfig.CONFIG_TABLE_NAME, tableName); + properties.put(LanceConfig.CONFIG_DATASET_URI, datasetUri); LanceConfig config = LanceConfig.from(properties); assertEquals(dbPath, config.getDbPath()); - assertEquals(tableName, config.getTableName()); - assertEquals(dbPath + "/" + tableName + ".lance", config.getTablePath()); + assertEquals(datasetName, config.getDatasetName()); + assertEquals(datasetUri, config.getDatasetUri()); } @Test public void testLanceConfigFromMap2() { String dbPath = "s3://bucket/folder/"; - String tableName = "testTableName"; + String datasetName = "testDatasetName"; + String datasetUri = LanceConfig.getDatasetUri(dbPath, datasetName); Map properties = new HashMap<>(); - properties.put(LanceConfig.CONFIG_DB_PATH, dbPath); - properties.put(LanceConfig.CONFIG_TABLE_NAME, tableName); + properties.put(LanceConfig.CONFIG_DATASET_URI, datasetUri); LanceConfig config = LanceConfig.from(properties); assertEquals(dbPath, config.getDbPath()); - assertEquals(tableName, config.getTableName()); - assertEquals(dbPath + tableName + ".lance", config.getTablePath()); + assertEquals(datasetName, config.getDatasetName()); + assertEquals(datasetUri, config.getDatasetUri()); } } diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/TestUtils.java b/java/spark/src/test/java/com/lancedb/lance/spark/TestUtils.java index e2002d1207..fb89f16656 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/TestUtils.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/TestUtils.java @@ -14,7 +14,8 @@ package com.lancedb.lance.spark; -import com.lancedb.lance.spark.internal.LanceConfig; +import com.lancedb.lance.spark.read.LanceInputPartition; +import com.lancedb.lance.spark.read.LanceSplit; import com.lancedb.lance.spark.utils.Optional; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; @@ -22,29 +23,13 @@ import java.net.URL; import java.util.Arrays; -import java.util.HashMap; import java.util.List; -import java.util.Map; public class TestUtils { - /** - * Converts dbPath and tableName to a LanceConfig instance using a Map. - * - * @param dbPath The database path. - * @param tableName The table name. - * @return A LanceConfig instance. 
- */ - public static LanceConfig createLanceConfig(String dbPath, String tableName) { - Map properties = new HashMap<>(); - properties.put(LanceConfig.CONFIG_DB_PATH, dbPath); - properties.put(LanceConfig.CONFIG_TABLE_NAME, tableName); - return LanceConfig.from(properties); - } - public static class TestTable1Config { public static final String dbPath; - public static final String tableName = "test_table1"; - public static final String tablePath; + public static final String datasetName = "test_dataset1"; + public static final String datasetUri; public static final List> expectedValues = Arrays.asList( Arrays.asList(0L, 0L, 0L, 0L), Arrays.asList(1L, 2L, 3L, -1L), @@ -69,8 +54,8 @@ public static class TestTable1Config { } else { throw new IllegalArgumentException("example_db not found in resources directory"); } - tablePath = String.format("%s/%s.lance", dbPath, tableName); - lanceConfig = createLanceConfig(dbPath, tableName); + datasetUri = LanceConfig.getDatasetUri(dbPath, datasetName); + lanceConfig = LanceConfig.from(datasetUri); inputPartition = new LanceInputPartition(schema, 0, new LanceSplit(Arrays.asList(0, 1)), lanceConfig, Optional.empty()); } } diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/FilterPushDownTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/read/FilterPushDownTest.java similarity index 97% rename from java/spark/src/test/java/com/lancedb/lance/spark/FilterPushDownTest.java rename to java/spark/src/test/java/com/lancedb/lance/spark/read/FilterPushDownTest.java index 786403fc95..5376ba0b7c 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/FilterPushDownTest.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/read/FilterPushDownTest.java @@ -12,9 +12,8 @@ * limitations under the License. */ -package com.lancedb.lance.spark; +package com.lancedb.lance.spark.read; -import com.lancedb.lance.spark.query.FilterPushDown; import com.lancedb.lance.spark.utils.Optional; import org.apache.spark.sql.sources.*; import org.junit.jupiter.api.Test; @@ -22,7 +21,6 @@ import static org.junit.jupiter.api.Assertions.*; public class FilterPushDownTest { - @Test public void testCompileFiltersToSqlWhereClause() { // Test case 1: GreaterThan, LessThanOrEqual, IsNotNull diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/LanceColumnarPartitionReaderTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/read/LanceColumnarPartitionReaderTest.java similarity index 96% rename from java/spark/src/test/java/com/lancedb/lance/spark/LanceColumnarPartitionReaderTest.java rename to java/spark/src/test/java/com/lancedb/lance/spark/read/LanceColumnarPartitionReaderTest.java index b60d861a3d..23bfd233fc 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/LanceColumnarPartitionReaderTest.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/read/LanceColumnarPartitionReaderTest.java @@ -12,8 +12,9 @@ * limitations under the License. 
*/ -package com.lancedb.lance.spark; +package com.lancedb.lance.spark.read; +import com.lancedb.lance.spark.TestUtils; import com.lancedb.lance.spark.utils.Optional; import org.apache.spark.sql.vectorized.ColumnarBatch; import org.junit.jupiter.api.Test; diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/LanceReaderTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/read/LanceDatasetReadTest.java similarity index 82% rename from java/spark/src/test/java/com/lancedb/lance/spark/LanceReaderTest.java rename to java/spark/src/test/java/com/lancedb/lance/spark/read/LanceDatasetReadTest.java index 80379f8173..6423a13ce0 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/LanceReaderTest.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/read/LanceDatasetReadTest.java @@ -12,10 +12,11 @@ * limitations under the License. */ -package com.lancedb.lance.spark; +package com.lancedb.lance.spark.read; +import com.lancedb.lance.spark.TestUtils; import com.lancedb.lance.spark.internal.LanceFragmentScanner; -import com.lancedb.lance.spark.internal.LanceReader; +import com.lancedb.lance.spark.internal.LanceDatasetAdapter; import com.lancedb.lance.spark.utils.Optional; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ipc.ArrowReader; @@ -30,25 +31,25 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; -public class LanceReaderTest { - +public class LanceDatasetReadTest { @Test public void testSchema() { StructType expectedSchema = TestUtils.TestTable1Config.schema; - StructType schema = LanceReader.getSchema(TestUtils.TestTable1Config.lanceConfig); - assertNotNull(schema); - assertEquals(expectedSchema, schema); + Optional schema = LanceDatasetAdapter.getSchema(TestUtils.TestTable1Config.lanceConfig); + assertTrue(schema.isPresent()); + assertEquals(expectedSchema, schema.get()); } - + @Test public void testFragmentIds() { - List fragments = LanceReader.getFragmentIds(TestUtils.TestTable1Config.lanceConfig); + List fragments = LanceDatasetAdapter.getFragmentIds(TestUtils.TestTable1Config.lanceConfig); assertEquals(2, fragments.size()); assertEquals(0, fragments.get(0)); assertEquals(1, fragments.get(1)); } - + @Test public void getFragmentScanner() throws IOException { List> expectedValues = Arrays.asList( @@ -80,7 +81,7 @@ public void getFragmentScanner() throws IOException { } public void validateFragment(List> expectedValues, int fragment, StructType schema) throws IOException { - try (LanceFragmentScanner scanner = LanceReader.getFragmentScanner(fragment, + try (LanceFragmentScanner scanner = LanceDatasetAdapter.getFragmentScanner(fragment, new LanceInputPartition(schema, 0, new LanceSplit(Arrays.asList(fragment)), TestUtils.TestTable1Config.lanceConfig, Optional.empty()))) { try (ArrowReader reader = scanner.getArrowReader()) { @@ -98,5 +99,5 @@ public void validateFragment(List> expectedValues, int fragment, St } } - // TODO test_table4 [UNSUPPORTED_ARROWTYPE] Unsupported arrow type FixedSizeList(128). + // TODO test_dataset4 [UNSUPPORTED_ARROWTYPE] Unsupported arrow type FixedSizeList(128). 
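+  // The failure above is likely because Spark's ArrowUtils has no Spark SQL mapping for the
+  // Arrow FixedSizeList type, which Lance uses for fixed-size vector (embedding) columns.
+  // A possible follow-up (not implemented here) is to map such columns to Spark's ArrayType
+  // or to skip them during the scan.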
} diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/LanceFragmentColumnarBatchScannerTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/read/LanceFragmentColumnarBatchScannerTest.java similarity index 96% rename from java/spark/src/test/java/com/lancedb/lance/spark/LanceFragmentColumnarBatchScannerTest.java rename to java/spark/src/test/java/com/lancedb/lance/spark/read/LanceFragmentColumnarBatchScannerTest.java index aae11b2926..cda163db71 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/LanceFragmentColumnarBatchScannerTest.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/read/LanceFragmentColumnarBatchScannerTest.java @@ -12,8 +12,9 @@ * limitations under the License. */ -package com.lancedb.lance.spark; +package com.lancedb.lance.spark.read; +import com.lancedb.lance.spark.TestUtils; import com.lancedb.lance.spark.internal.LanceFragmentColumnarBatchScanner; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.vectorized.ColumnarBatch; diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/LargeLanceDatasetSparkConnectorTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorLineItemTest.java similarity index 93% rename from java/spark/src/test/java/com/lancedb/lance/spark/LargeLanceDatasetSparkConnectorTest.java rename to java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorLineItemTest.java index 8335de4745..2aa779ae75 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/LargeLanceDatasetSparkConnectorTest.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorLineItemTest.java @@ -12,8 +12,10 @@ * limitations under the License. */ -package com.lancedb.lance.spark; +package com.lancedb.lance.spark.read; +import com.lancedb.lance.spark.LanceConfig; +import com.lancedb.lance.spark.LanceDataSource; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; @@ -27,7 +29,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assumptions.assumeTrue; -public class LargeLanceDatasetSparkConnectorTest { +public class SparkConnectorLineItemTest { private static SparkSession spark; private static String dbPath; private static String parquetPath; @@ -44,10 +46,10 @@ static void setup() { spark = SparkSession.builder() .appName("spark-lance-connector-test") .master("local") + .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog") .getOrCreate(); - lanceData = spark.read().format("lance") - .option("db", dbPath) - .option("table", "lineitem_10") + lanceData = spark.read().format(LanceDataSource.name) + .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath, "lineitem_10")) .load(); lanceData.createOrReplaceTempView("lance_dataset"); parquetData = spark.read().parquet(parquetPath); diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/SparkLanceConnectorTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadTest.java similarity index 88% rename from java/spark/src/test/java/com/lancedb/lance/spark/SparkLanceConnectorTest.java rename to java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadTest.java index 04914db4c8..fe5a82a642 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/SparkLanceConnectorTest.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadTest.java @@ -12,8 +12,11 @@ * limitations under the License. 
*/ -package com.lancedb.lance.spark; +package com.lancedb.lance.spark.read; +import com.lancedb.lance.spark.LanceConfig; +import com.lancedb.lance.spark.LanceDataSource; +import com.lancedb.lance.spark.TestUtils; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; @@ -27,7 +30,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; -public class SparkLanceConnectorTest { +public class SparkConnectorReadTest { private static SparkSession spark; private static String dbPath; private static Dataset data; @@ -37,11 +40,11 @@ static void setup() { spark = SparkSession.builder() .appName("spark-lance-connector-test") .master("local") + .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog") .getOrCreate(); dbPath = TestUtils.TestTable1Config.dbPath; - data = spark.read().format("lance") - .option("db", dbPath) - .option("table", TestUtils.TestTable1Config.tableName) + data = spark.read().format(LanceDataSource.name) + .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath, TestUtils.TestTable1Config.datasetName)) .load(); } @@ -125,4 +128,7 @@ public void filterSelect() { .filter(row -> row.get(0) > 3) .collect(Collectors.toList())); } + + // TODO(lu) support spark.read().format("lance") + // .load(dbPath.resolve(datasetName).toString()); } diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/write/BatchAppendTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/write/BatchAppendTest.java new file mode 100644 index 0000000000..45c04de4a3 --- /dev/null +++ b/java/spark/src/test/java/com/lancedb/lance/spark/write/BatchAppendTest.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.lancedb.lance.spark.write;
+
+import com.lancedb.lance.Dataset;
+import com.lancedb.lance.WriteParams;
+import com.lancedb.lance.spark.LanceConfig;
+import org.apache.arrow.dataset.scanner.Scanner;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ArrowReader;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
+import org.apache.spark.sql.connector.write.DataWriter;
+import org.apache.spark.sql.connector.write.DataWriterFactory;
+import org.apache.spark.sql.connector.write.PhysicalWriteInfo;
+import org.apache.spark.sql.connector.write.WriterCommitMessage;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.util.ArrowUtils;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInfo;
+import org.junit.jupiter.api.io.TempDir;
+
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.Collections;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class BatchAppendTest {
+  @TempDir
+  static Path tempDir;
+
+  @Test
+  public void testLanceDataWriter(TestInfo testInfo) throws Exception {
+    String datasetName = testInfo.getTestMethod().get().getName();
+    String datasetUri = LanceConfig.getDatasetUri(tempDir.toString(), datasetName);
+    try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
+      // Create lance dataset
+      Field field = new Field("column1", FieldType.nullable(new ArrowType.Int(32, true)), null);
+      Schema schema = new Schema(Collections.singletonList(field));
+      Dataset.create(allocator, datasetUri, schema, new WriteParams.Builder().build()).close();
+
+      // Append data to lance dataset
+      LanceConfig config = LanceConfig.from(datasetUri);
+      StructType sparkSchema = ArrowUtils.fromArrowSchema(schema);
+      BatchAppend batchAppend = new BatchAppend(sparkSchema, config);
+      DataWriterFactory factory = batchAppend.createBatchWriterFactory(() -> 1);
+
+      int rows = 132;
+      WriterCommitMessage message;
+      try (DataWriter<InternalRow> writer = factory.createWriter(0, 0)) {
+        for (int i = 0; i < rows; i++) {
+          InternalRow row = new GenericInternalRow(new Object[]{i});
+          writer.write(row);
+        }
+        message = writer.commit();
+      }
+      batchAppend.commit(new WriterCommitMessage[]{message});
+
+      // Validate lance dataset data
+      try (Dataset dataset = Dataset.open(datasetUri, allocator)) {
+        try (Scanner scanner = dataset.newScan()) {
+          try (ArrowReader reader = scanner.scanBatches()) {
+            VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot();
+            int totalRowsRead = 0;
+            while (reader.loadNextBatch()) {
+              int batchRows = readerRoot.getRowCount();
+              for (int i = 0; i < batchRows; i++) {
+                int value = (int) readerRoot.getVector("column1").getObject(i);
+                assertEquals(totalRowsRead + i, value);
+              }
+              totalRowsRead += batchRows;
+            }
+            assertEquals(rows, totalRowsRead);
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceArrowWriterTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceArrowWriterTest.java
new file mode 100644
index 0000000000..1dbc63ca60
--- /dev/null
+++ b/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceArrowWriterTest.java
@@ -0,0 +1,92 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.lancedb.lance.spark.write; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class LanceArrowWriterTest { + @Test + public void test() throws Exception { + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + Field field = new Field("column1", FieldType.nullable(org.apache.arrow.vector.types.Types.MinorType.INT.getType()), null); + Schema schema = new Schema(Collections.singletonList(field)); + + final int totalRows = 125; + final int batchSize = 34; + final LanceArrowWriter arrowWriter = new LanceArrowWriter(allocator, schema, batchSize); + + AtomicInteger rowsWritten = new AtomicInteger(0); + AtomicInteger rowsRead = new AtomicInteger(0); + AtomicLong expectedBytesRead = new AtomicLong(0); + + Thread writerThread = new Thread(() -> { + try { + for (int i = 0; i < totalRows; i++) { + InternalRow row = new GenericInternalRow(new Object[]{rowsWritten.incrementAndGet()}); + arrowWriter.write(row); + } + arrowWriter.setFinished(); + } catch (Exception e) { + e.printStackTrace(); + throw e; + } + }); + + Thread readerThread = new Thread(() -> { + try { + while (arrowWriter.loadNextBatch()) { + VectorSchemaRoot root = arrowWriter.getVectorSchemaRoot(); + int rowCount = root.getRowCount(); + rowsRead.addAndGet(rowCount); + try (ArrowRecordBatch recordBatch = new VectorUnloader(root).getRecordBatch()) { + expectedBytesRead.addAndGet(recordBatch.computeBodyLength()); + } + for (int i = 0; i < rowCount; i++) { + int value = (int) root.getVector("column1").getObject(i); + assertEquals(value, rowsRead.get() - rowCount + i + 1); + } + } + } catch (Exception e) { + e.printStackTrace(); + } + }); + + writerThread.start(); + readerThread.start(); + + writerThread.join(); + readerThread.join(); + assertEquals(totalRows, rowsWritten.get()); + assertEquals(totalRows, rowsRead.get()); + arrowWriter.close(); + } + } +} diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceDataWriterTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceDataWriterTest.java new file mode 100644 index 0000000000..8ea2c47cd6 --- /dev/null +++ b/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceDataWriterTest.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.lancedb.lance.spark.write; + +import com.lancedb.lance.FragmentMetadata; +import com.lancedb.lance.spark.LanceConfig; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.ArrowUtils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.Collections; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class LanceDataWriterTest { + @TempDir + static Path tempDir; + + @Test + public void testLanceDataWriter(TestInfo testInfo) throws IOException { + String datasetName = testInfo.getTestMethod().get().getName(); + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + Field field = new Field("column1", FieldType.nullable(new ArrowType.Int(32, true)), null); + Schema schema = new Schema(Collections.singletonList(field)); + LanceConfig config = LanceConfig.from(tempDir.resolve(datasetName + LanceConfig.LANCE_FILE_SUFFIX).toString()); + StructType sparkSchema = ArrowUtils.fromArrowSchema(schema); + LanceDataWriter.WriterFactory writerFactory = new LanceDataWriter.WriterFactory(sparkSchema, config); + LanceDataWriter dataWriter = (LanceDataWriter) writerFactory.createWriter(0, 0); + + int rows = 132; + for (int i = 0; i < rows; i++) { + InternalRow row = new GenericInternalRow(new Object[]{i}); + dataWriter.write(row); + } + + BatchAppend.TaskCommit commitMessage = (BatchAppend.TaskCommit) dataWriter.commit(); + dataWriter.close(); + List fragments = commitMessage.getFragments(); + assertEquals(1, fragments.size()); + assertEquals(rows, fragments.get(0).getPhysicalRows()); + } + } +} diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/write/SparkWriteTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/write/SparkWriteTest.java new file mode 100644 index 0000000000..78c5f9cb12 --- /dev/null +++ b/java/spark/src/test/java/com/lancedb/lance/spark/write/SparkWriteTest.java @@ -0,0 +1,167 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.lancedb.lance.spark.write; + +import com.lancedb.lance.spark.LanceConfig; +import com.lancedb.lance.spark.LanceDataSource; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; +import org.junit.jupiter.api.io.TempDir; + +import java.nio.file.Path; +import java.util.Arrays; +import java.util.List; + +import static org.apache.spark.sql.functions.col; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class SparkWriteTest { + private static SparkSession spark; + private static Dataset testData; + @TempDir + static Path dbPath; + + @BeforeAll + static void setup() { + spark = SparkSession.builder() + .appName("spark-lance-connector-test") + .master("local") + .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog") + .getOrCreate(); + StructType schema = new StructType(new StructField[]{ + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("name", DataTypes.StringType, false) + }); + + Row row1 = RowFactory.create(1, "Alice"); + Row row2 = RowFactory.create(2, "Bob"); + List data = Arrays.asList(row1, row2); + + testData = spark.createDataFrame(data, schema); + } + + @AfterAll + static void tearDown() { + if (spark != null) { + spark.stop(); + } + } + + @Test + public void defaultWrite(TestInfo testInfo) { + String datasetName = testInfo.getTestMethod().get().getName(); + testData.write().format(LanceDataSource.name) + .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + .save(); + + validateData(datasetName, 1); + } + + @Test + public void errorIfExists(TestInfo testInfo) { + String datasetName = testInfo.getTestMethod().get().getName(); + testData.write().format(LanceDataSource.name) + .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + .save(); + + assertThrows(TableAlreadyExistsException.class, () -> { + testData.write().format(LanceDataSource.name) + .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + .save(); + }); + } + + @Test + public void append(TestInfo testInfo) { + String datasetName = testInfo.getTestMethod().get().getName(); + testData.write().format(LanceDataSource.name) + .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + .save(); + testData.write().format(LanceDataSource.name) + .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + .mode("append") + .save(); + validateData(datasetName, 2); + } + + @Test + public void appendErrorIfNotExist(TestInfo testInfo) { + String datasetName = testInfo.getTestMethod().get().getName(); + assertThrows(NoSuchTableException.class, () -> { + 
testData.write().format(LanceDataSource.name) + .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + .mode("append") + .save(); + }); + } + + @Test + public void saveToPath(TestInfo testInfo) { + String datasetName = testInfo.getTestMethod().get().getName(); + testData.write().format(LanceDataSource.name) + .save(LanceConfig.getDatasetUri(dbPath.toString(), datasetName)); + + validateData(datasetName, 1); + } + + @Disabled("Do not support overwrite") + @Test + public void overwrite(TestInfo testInfo) { + String datasetName = testInfo.getTestMethod().get().getName(); + testData.write().format(LanceDataSource.name) + .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + .save(); + testData.write().format(LanceDataSource.name) + .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + .mode("overwrite") + .save(); + + validateData(datasetName, 1); + } + + private void validateData(String datasetName, int iteration) { + Dataset data = spark.read().format("lance") + .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + .load(); + + assertEquals(2 * iteration, data.count()); + assertEquals(iteration, data.filter(col("id").equalTo(1)).count()); + assertEquals(iteration, data.filter(col("id").equalTo(2)).count()); + + Dataset data1 = data.filter(col("id").equalTo(1)).select("name"); + Dataset data2 = data.filter(col("id").equalTo(2)).select("name"); + + for (Row row : data1.collectAsList()) { + assertEquals("Alice", row.getString(0)); + } + + for (Row row : data2.collectAsList()) { + assertEquals("Bob", row.getString(0)); + } + } +} \ No newline at end of file diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/_latest.manifest b/java/spark/src/test/resources/example_db/test_dataset1.lance/_latest.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table1.lance/_latest.manifest rename to java/spark/src/test/resources/example_db/test_dataset1.lance/_latest.manifest diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/_transactions/0-4daea2b4-b38b-4542-af0c-5a839ceab54a.txn b/java/spark/src/test/resources/example_db/test_dataset1.lance/_transactions/0-4daea2b4-b38b-4542-af0c-5a839ceab54a.txn similarity index 100% rename from java/spark/src/test/resources/example_db/test_table1.lance/_transactions/0-4daea2b4-b38b-4542-af0c-5a839ceab54a.txn rename to java/spark/src/test/resources/example_db/test_dataset1.lance/_transactions/0-4daea2b4-b38b-4542-af0c-5a839ceab54a.txn diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/_transactions/1-99519b7f-c80f-4961-bacc-d556df5ae798.txn b/java/spark/src/test/resources/example_db/test_dataset1.lance/_transactions/1-99519b7f-c80f-4961-bacc-d556df5ae798.txn similarity index 100% rename from java/spark/src/test/resources/example_db/test_table1.lance/_transactions/1-99519b7f-c80f-4961-bacc-d556df5ae798.txn rename to java/spark/src/test/resources/example_db/test_dataset1.lance/_transactions/1-99519b7f-c80f-4961-bacc-d556df5ae798.txn diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/_transactions/2-b9f7655d-01e1-4fa7-8ca2-ddc646564fb8.txn b/java/spark/src/test/resources/example_db/test_dataset1.lance/_transactions/2-b9f7655d-01e1-4fa7-8ca2-ddc646564fb8.txn similarity index 100% rename from 
java/spark/src/test/resources/example_db/test_table1.lance/_transactions/2-b9f7655d-01e1-4fa7-8ca2-ddc646564fb8.txn rename to java/spark/src/test/resources/example_db/test_dataset1.lance/_transactions/2-b9f7655d-01e1-4fa7-8ca2-ddc646564fb8.txn diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/_transactions/3-90bc5dd5-204d-42ba-b39a-65f2abce1602.txn b/java/spark/src/test/resources/example_db/test_dataset1.lance/_transactions/3-90bc5dd5-204d-42ba-b39a-65f2abce1602.txn similarity index 100% rename from java/spark/src/test/resources/example_db/test_table1.lance/_transactions/3-90bc5dd5-204d-42ba-b39a-65f2abce1602.txn rename to java/spark/src/test/resources/example_db/test_dataset1.lance/_transactions/3-90bc5dd5-204d-42ba-b39a-65f2abce1602.txn diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/_transactions/4-dffa23f0-c357-4935-a9c1-e286099b5533.txn b/java/spark/src/test/resources/example_db/test_dataset1.lance/_transactions/4-dffa23f0-c357-4935-a9c1-e286099b5533.txn similarity index 100% rename from java/spark/src/test/resources/example_db/test_table1.lance/_transactions/4-dffa23f0-c357-4935-a9c1-e286099b5533.txn rename to java/spark/src/test/resources/example_db/test_dataset1.lance/_transactions/4-dffa23f0-c357-4935-a9c1-e286099b5533.txn diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/_transactions/5-8bfb238d-4a29-4582-ab7d-8c53e2253e47.txn b/java/spark/src/test/resources/example_db/test_dataset1.lance/_transactions/5-8bfb238d-4a29-4582-ab7d-8c53e2253e47.txn similarity index 100% rename from java/spark/src/test/resources/example_db/test_table1.lance/_transactions/5-8bfb238d-4a29-4582-ab7d-8c53e2253e47.txn rename to java/spark/src/test/resources/example_db/test_dataset1.lance/_transactions/5-8bfb238d-4a29-4582-ab7d-8c53e2253e47.txn diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/_versions/1.manifest b/java/spark/src/test/resources/example_db/test_dataset1.lance/_versions/1.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table1.lance/_versions/1.manifest rename to java/spark/src/test/resources/example_db/test_dataset1.lance/_versions/1.manifest diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/_versions/2.manifest b/java/spark/src/test/resources/example_db/test_dataset1.lance/_versions/2.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table1.lance/_versions/2.manifest rename to java/spark/src/test/resources/example_db/test_dataset1.lance/_versions/2.manifest diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/_versions/3.manifest b/java/spark/src/test/resources/example_db/test_dataset1.lance/_versions/3.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table1.lance/_versions/3.manifest rename to java/spark/src/test/resources/example_db/test_dataset1.lance/_versions/3.manifest diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/_versions/4.manifest b/java/spark/src/test/resources/example_db/test_dataset1.lance/_versions/4.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table1.lance/_versions/4.manifest rename to java/spark/src/test/resources/example_db/test_dataset1.lance/_versions/4.manifest diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/_versions/5.manifest b/java/spark/src/test/resources/example_db/test_dataset1.lance/_versions/5.manifest similarity index 
100% rename from java/spark/src/test/resources/example_db/test_table1.lance/_versions/5.manifest rename to java/spark/src/test/resources/example_db/test_dataset1.lance/_versions/5.manifest diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/_versions/6.manifest b/java/spark/src/test/resources/example_db/test_dataset1.lance/_versions/6.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table1.lance/_versions/6.manifest rename to java/spark/src/test/resources/example_db/test_dataset1.lance/_versions/6.manifest diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/data/083d1c7c-b0d2-4ff3-b7ff-4237ea586491.lance b/java/spark/src/test/resources/example_db/test_dataset1.lance/data/083d1c7c-b0d2-4ff3-b7ff-4237ea586491.lance similarity index 100% rename from java/spark/src/test/resources/example_db/test_table1.lance/data/083d1c7c-b0d2-4ff3-b7ff-4237ea586491.lance rename to java/spark/src/test/resources/example_db/test_dataset1.lance/data/083d1c7c-b0d2-4ff3-b7ff-4237ea586491.lance diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/data/25c37abd-c753-419b-b420-4847ce2de5a1.lance b/java/spark/src/test/resources/example_db/test_dataset1.lance/data/25c37abd-c753-419b-b420-4847ce2de5a1.lance similarity index 100% rename from java/spark/src/test/resources/example_db/test_table1.lance/data/25c37abd-c753-419b-b420-4847ce2de5a1.lance rename to java/spark/src/test/resources/example_db/test_dataset1.lance/data/25c37abd-c753-419b-b420-4847ce2de5a1.lance diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/data/2c8a0da6-1ace-4b1c-baf0-ed48b04996dc.lance b/java/spark/src/test/resources/example_db/test_dataset1.lance/data/2c8a0da6-1ace-4b1c-baf0-ed48b04996dc.lance similarity index 100% rename from java/spark/src/test/resources/example_db/test_table1.lance/data/2c8a0da6-1ace-4b1c-baf0-ed48b04996dc.lance rename to java/spark/src/test/resources/example_db/test_dataset1.lance/data/2c8a0da6-1ace-4b1c-baf0-ed48b04996dc.lance diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/data/ac0bf34e-0e0d-4e3b-ae7e-ab247cae5f77.lance b/java/spark/src/test/resources/example_db/test_dataset1.lance/data/ac0bf34e-0e0d-4e3b-ae7e-ab247cae5f77.lance similarity index 100% rename from java/spark/src/test/resources/example_db/test_table1.lance/data/ac0bf34e-0e0d-4e3b-ae7e-ab247cae5f77.lance rename to java/spark/src/test/resources/example_db/test_dataset1.lance/data/ac0bf34e-0e0d-4e3b-ae7e-ab247cae5f77.lance diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/data/c888f970-b7b3-4efb-9293-d7c6dc4996d2.lance b/java/spark/src/test/resources/example_db/test_dataset1.lance/data/c888f970-b7b3-4efb-9293-d7c6dc4996d2.lance similarity index 100% rename from java/spark/src/test/resources/example_db/test_table1.lance/data/c888f970-b7b3-4efb-9293-d7c6dc4996d2.lance rename to java/spark/src/test/resources/example_db/test_dataset1.lance/data/c888f970-b7b3-4efb-9293-d7c6dc4996d2.lance diff --git a/java/spark/src/test/resources/example_db/test_table1.lance/data/cbe16da7-b812-43a1-87f1-521470dfed32.lance b/java/spark/src/test/resources/example_db/test_dataset1.lance/data/cbe16da7-b812-43a1-87f1-521470dfed32.lance similarity index 100% rename from java/spark/src/test/resources/example_db/test_table1.lance/data/cbe16da7-b812-43a1-87f1-521470dfed32.lance rename to java/spark/src/test/resources/example_db/test_dataset1.lance/data/cbe16da7-b812-43a1-87f1-521470dfed32.lance diff --git 
a/java/spark/src/test/resources/example_db/test_table2.lance/_deletions/0-1-8958018423523767581.arrow b/java/spark/src/test/resources/example_db/test_dataset2.lance/_deletions/0-1-8958018423523767581.arrow similarity index 100% rename from java/spark/src/test/resources/example_db/test_table2.lance/_deletions/0-1-8958018423523767581.arrow rename to java/spark/src/test/resources/example_db/test_dataset2.lance/_deletions/0-1-8958018423523767581.arrow diff --git a/java/spark/src/test/resources/example_db/test_table2.lance/_latest.manifest b/java/spark/src/test/resources/example_db/test_dataset2.lance/_latest.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table2.lance/_latest.manifest rename to java/spark/src/test/resources/example_db/test_dataset2.lance/_latest.manifest diff --git a/java/spark/src/test/resources/example_db/test_table2.lance/_transactions/0-304ab2ef-f7bc-47b8-aeb6-9110ec67bf98.txn b/java/spark/src/test/resources/example_db/test_dataset2.lance/_transactions/0-304ab2ef-f7bc-47b8-aeb6-9110ec67bf98.txn similarity index 100% rename from java/spark/src/test/resources/example_db/test_table2.lance/_transactions/0-304ab2ef-f7bc-47b8-aeb6-9110ec67bf98.txn rename to java/spark/src/test/resources/example_db/test_dataset2.lance/_transactions/0-304ab2ef-f7bc-47b8-aeb6-9110ec67bf98.txn diff --git a/java/spark/src/test/resources/example_db/test_table2.lance/_transactions/1-1baf3405-66ab-4668-9578-5c333acd0440.txn b/java/spark/src/test/resources/example_db/test_dataset2.lance/_transactions/1-1baf3405-66ab-4668-9578-5c333acd0440.txn similarity index 100% rename from java/spark/src/test/resources/example_db/test_table2.lance/_transactions/1-1baf3405-66ab-4668-9578-5c333acd0440.txn rename to java/spark/src/test/resources/example_db/test_dataset2.lance/_transactions/1-1baf3405-66ab-4668-9578-5c333acd0440.txn diff --git a/java/spark/src/test/resources/example_db/test_table2.lance/_versions/1.manifest b/java/spark/src/test/resources/example_db/test_dataset2.lance/_versions/1.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table2.lance/_versions/1.manifest rename to java/spark/src/test/resources/example_db/test_dataset2.lance/_versions/1.manifest diff --git a/java/spark/src/test/resources/example_db/test_table2.lance/_versions/2.manifest b/java/spark/src/test/resources/example_db/test_dataset2.lance/_versions/2.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table2.lance/_versions/2.manifest rename to java/spark/src/test/resources/example_db/test_dataset2.lance/_versions/2.manifest diff --git a/java/spark/src/test/resources/example_db/test_table2.lance/data/016c15dc-2c94-4382-b7a4-2c7def9c3897.lance b/java/spark/src/test/resources/example_db/test_dataset2.lance/data/016c15dc-2c94-4382-b7a4-2c7def9c3897.lance similarity index 100% rename from java/spark/src/test/resources/example_db/test_table2.lance/data/016c15dc-2c94-4382-b7a4-2c7def9c3897.lance rename to java/spark/src/test/resources/example_db/test_dataset2.lance/data/016c15dc-2c94-4382-b7a4-2c7def9c3897.lance diff --git a/java/spark/src/test/resources/example_db/test_table3.lance/_deletions/0-1-8958018423523767581.arrow b/java/spark/src/test/resources/example_db/test_dataset3.lance/_deletions/0-1-8958018423523767581.arrow similarity index 100% rename from java/spark/src/test/resources/example_db/test_table3.lance/_deletions/0-1-8958018423523767581.arrow rename to 
java/spark/src/test/resources/example_db/test_dataset3.lance/_deletions/0-1-8958018423523767581.arrow diff --git a/java/spark/src/test/resources/example_db/test_table3.lance/_latest.manifest b/java/spark/src/test/resources/example_db/test_dataset3.lance/_latest.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table3.lance/_latest.manifest rename to java/spark/src/test/resources/example_db/test_dataset3.lance/_latest.manifest diff --git a/java/spark/src/test/resources/example_db/test_table3.lance/_transactions/0-304ab2ef-f7bc-47b8-aeb6-9110ec67bf98.txn b/java/spark/src/test/resources/example_db/test_dataset3.lance/_transactions/0-304ab2ef-f7bc-47b8-aeb6-9110ec67bf98.txn similarity index 100% rename from java/spark/src/test/resources/example_db/test_table3.lance/_transactions/0-304ab2ef-f7bc-47b8-aeb6-9110ec67bf98.txn rename to java/spark/src/test/resources/example_db/test_dataset3.lance/_transactions/0-304ab2ef-f7bc-47b8-aeb6-9110ec67bf98.txn diff --git a/java/spark/src/test/resources/example_db/test_table3.lance/_transactions/1-1baf3405-66ab-4668-9578-5c333acd0440.txn b/java/spark/src/test/resources/example_db/test_dataset3.lance/_transactions/1-1baf3405-66ab-4668-9578-5c333acd0440.txn similarity index 100% rename from java/spark/src/test/resources/example_db/test_table3.lance/_transactions/1-1baf3405-66ab-4668-9578-5c333acd0440.txn rename to java/spark/src/test/resources/example_db/test_dataset3.lance/_transactions/1-1baf3405-66ab-4668-9578-5c333acd0440.txn diff --git a/java/spark/src/test/resources/example_db/test_table3.lance/_transactions/2-8e340735-2a60-438b-9cf0-ec662fb25f1a.txn b/java/spark/src/test/resources/example_db/test_dataset3.lance/_transactions/2-8e340735-2a60-438b-9cf0-ec662fb25f1a.txn similarity index 100% rename from java/spark/src/test/resources/example_db/test_table3.lance/_transactions/2-8e340735-2a60-438b-9cf0-ec662fb25f1a.txn rename to java/spark/src/test/resources/example_db/test_dataset3.lance/_transactions/2-8e340735-2a60-438b-9cf0-ec662fb25f1a.txn diff --git a/java/spark/src/test/resources/example_db/test_table3.lance/_versions/1.manifest b/java/spark/src/test/resources/example_db/test_dataset3.lance/_versions/1.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table3.lance/_versions/1.manifest rename to java/spark/src/test/resources/example_db/test_dataset3.lance/_versions/1.manifest diff --git a/java/spark/src/test/resources/example_db/test_table3.lance/_versions/2.manifest b/java/spark/src/test/resources/example_db/test_dataset3.lance/_versions/2.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table3.lance/_versions/2.manifest rename to java/spark/src/test/resources/example_db/test_dataset3.lance/_versions/2.manifest diff --git a/java/spark/src/test/resources/example_db/test_table3.lance/_versions/3.manifest b/java/spark/src/test/resources/example_db/test_dataset3.lance/_versions/3.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table3.lance/_versions/3.manifest rename to java/spark/src/test/resources/example_db/test_dataset3.lance/_versions/3.manifest diff --git a/java/spark/src/test/resources/example_db/test_table3.lance/data/016c15dc-2c94-4382-b7a4-2c7def9c3897.lance b/java/spark/src/test/resources/example_db/test_dataset3.lance/data/016c15dc-2c94-4382-b7a4-2c7def9c3897.lance similarity index 100% rename from 
java/spark/src/test/resources/example_db/test_table3.lance/data/016c15dc-2c94-4382-b7a4-2c7def9c3897.lance rename to java/spark/src/test/resources/example_db/test_dataset3.lance/data/016c15dc-2c94-4382-b7a4-2c7def9c3897.lance diff --git a/java/spark/src/test/resources/example_db/test_table3.lance/data/e6574672-b3cb-4bc7-92a8-db8754dac368.lance b/java/spark/src/test/resources/example_db/test_dataset3.lance/data/e6574672-b3cb-4bc7-92a8-db8754dac368.lance similarity index 100% rename from java/spark/src/test/resources/example_db/test_table3.lance/data/e6574672-b3cb-4bc7-92a8-db8754dac368.lance rename to java/spark/src/test/resources/example_db/test_dataset3.lance/data/e6574672-b3cb-4bc7-92a8-db8754dac368.lance diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/_indices/d32dac97-985b-4628-b1b4-e4b64947e115/index.idx b/java/spark/src/test/resources/example_db/test_dataset4.lance/_indices/d32dac97-985b-4628-b1b4-e4b64947e115/index.idx similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/_indices/d32dac97-985b-4628-b1b4-e4b64947e115/index.idx rename to java/spark/src/test/resources/example_db/test_dataset4.lance/_indices/d32dac97-985b-4628-b1b4-e4b64947e115/index.idx diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/_indices/f358a219-95e8-4956-be35-0835f2bed10f/index.idx b/java/spark/src/test/resources/example_db/test_dataset4.lance/_indices/f358a219-95e8-4956-be35-0835f2bed10f/index.idx similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/_indices/f358a219-95e8-4956-be35-0835f2bed10f/index.idx rename to java/spark/src/test/resources/example_db/test_dataset4.lance/_indices/f358a219-95e8-4956-be35-0835f2bed10f/index.idx diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/_latest.manifest b/java/spark/src/test/resources/example_db/test_dataset4.lance/_latest.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/_latest.manifest rename to java/spark/src/test/resources/example_db/test_dataset4.lance/_latest.manifest diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/_transactions/0-c4ece134-3d52-41a8-b2ec-0fb9fff76c35.txn b/java/spark/src/test/resources/example_db/test_dataset4.lance/_transactions/0-c4ece134-3d52-41a8-b2ec-0fb9fff76c35.txn similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/_transactions/0-c4ece134-3d52-41a8-b2ec-0fb9fff76c35.txn rename to java/spark/src/test/resources/example_db/test_dataset4.lance/_transactions/0-c4ece134-3d52-41a8-b2ec-0fb9fff76c35.txn diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/_transactions/1-cac38053-d1b8-4ff5-b34c-0e47b41c1b56.txn b/java/spark/src/test/resources/example_db/test_dataset4.lance/_transactions/1-cac38053-d1b8-4ff5-b34c-0e47b41c1b56.txn similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/_transactions/1-cac38053-d1b8-4ff5-b34c-0e47b41c1b56.txn rename to java/spark/src/test/resources/example_db/test_dataset4.lance/_transactions/1-cac38053-d1b8-4ff5-b34c-0e47b41c1b56.txn diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/_transactions/2-f3ac6254-2471-4c8a-8183-e529af6d2603.txn b/java/spark/src/test/resources/example_db/test_dataset4.lance/_transactions/2-f3ac6254-2471-4c8a-8183-e529af6d2603.txn similarity index 100% rename from 
java/spark/src/test/resources/example_db/test_table4.lance/_transactions/2-f3ac6254-2471-4c8a-8183-e529af6d2603.txn rename to java/spark/src/test/resources/example_db/test_dataset4.lance/_transactions/2-f3ac6254-2471-4c8a-8183-e529af6d2603.txn diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/_transactions/3-e08f185e-5734-4533-bee5-325567f2221a.txn b/java/spark/src/test/resources/example_db/test_dataset4.lance/_transactions/3-e08f185e-5734-4533-bee5-325567f2221a.txn similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/_transactions/3-e08f185e-5734-4533-bee5-325567f2221a.txn rename to java/spark/src/test/resources/example_db/test_dataset4.lance/_transactions/3-e08f185e-5734-4533-bee5-325567f2221a.txn diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/_transactions/4-2536db77-3757-414b-a525-f8f3288e9d80.txn b/java/spark/src/test/resources/example_db/test_dataset4.lance/_transactions/4-2536db77-3757-414b-a525-f8f3288e9d80.txn similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/_transactions/4-2536db77-3757-414b-a525-f8f3288e9d80.txn rename to java/spark/src/test/resources/example_db/test_dataset4.lance/_transactions/4-2536db77-3757-414b-a525-f8f3288e9d80.txn diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/_versions/1.manifest b/java/spark/src/test/resources/example_db/test_dataset4.lance/_versions/1.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/_versions/1.manifest rename to java/spark/src/test/resources/example_db/test_dataset4.lance/_versions/1.manifest diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/_versions/2.manifest b/java/spark/src/test/resources/example_db/test_dataset4.lance/_versions/2.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/_versions/2.manifest rename to java/spark/src/test/resources/example_db/test_dataset4.lance/_versions/2.manifest diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/_versions/3.manifest b/java/spark/src/test/resources/example_db/test_dataset4.lance/_versions/3.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/_versions/3.manifest rename to java/spark/src/test/resources/example_db/test_dataset4.lance/_versions/3.manifest diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/_versions/4.manifest b/java/spark/src/test/resources/example_db/test_dataset4.lance/_versions/4.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/_versions/4.manifest rename to java/spark/src/test/resources/example_db/test_dataset4.lance/_versions/4.manifest diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/_versions/5.manifest b/java/spark/src/test/resources/example_db/test_dataset4.lance/_versions/5.manifest similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/_versions/5.manifest rename to java/spark/src/test/resources/example_db/test_dataset4.lance/_versions/5.manifest diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/data/03c1a82b-a745-4bfe-8413-9441e4ed216e.lance b/java/spark/src/test/resources/example_db/test_dataset4.lance/data/03c1a82b-a745-4bfe-8413-9441e4ed216e.lance similarity index 100% rename from 
java/spark/src/test/resources/example_db/test_table4.lance/data/03c1a82b-a745-4bfe-8413-9441e4ed216e.lance rename to java/spark/src/test/resources/example_db/test_dataset4.lance/data/03c1a82b-a745-4bfe-8413-9441e4ed216e.lance diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/data/2f786e97-1d4c-43e5-bc32-6f7a444396f1.lance b/java/spark/src/test/resources/example_db/test_dataset4.lance/data/2f786e97-1d4c-43e5-bc32-6f7a444396f1.lance similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/data/2f786e97-1d4c-43e5-bc32-6f7a444396f1.lance rename to java/spark/src/test/resources/example_db/test_dataset4.lance/data/2f786e97-1d4c-43e5-bc32-6f7a444396f1.lance diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/data/34199dea-ca38-460b-af71-a816b0f093a1.lance b/java/spark/src/test/resources/example_db/test_dataset4.lance/data/34199dea-ca38-460b-af71-a816b0f093a1.lance similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/data/34199dea-ca38-460b-af71-a816b0f093a1.lance rename to java/spark/src/test/resources/example_db/test_dataset4.lance/data/34199dea-ca38-460b-af71-a816b0f093a1.lance diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/data/37ff0067-df64-4ba7-8c50-2086eb2b8127.lance b/java/spark/src/test/resources/example_db/test_dataset4.lance/data/37ff0067-df64-4ba7-8c50-2086eb2b8127.lance similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/data/37ff0067-df64-4ba7-8c50-2086eb2b8127.lance rename to java/spark/src/test/resources/example_db/test_dataset4.lance/data/37ff0067-df64-4ba7-8c50-2086eb2b8127.lance diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/data/4062824b-36bd-42e6-9283-22e9f29dc5ed.lance b/java/spark/src/test/resources/example_db/test_dataset4.lance/data/4062824b-36bd-42e6-9283-22e9f29dc5ed.lance similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/data/4062824b-36bd-42e6-9283-22e9f29dc5ed.lance rename to java/spark/src/test/resources/example_db/test_dataset4.lance/data/4062824b-36bd-42e6-9283-22e9f29dc5ed.lance diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/data/4d41cc61-800b-46b0-a548-893a35201cf1.lance b/java/spark/src/test/resources/example_db/test_dataset4.lance/data/4d41cc61-800b-46b0-a548-893a35201cf1.lance similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/data/4d41cc61-800b-46b0-a548-893a35201cf1.lance rename to java/spark/src/test/resources/example_db/test_dataset4.lance/data/4d41cc61-800b-46b0-a548-893a35201cf1.lance diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/data/66c4453b-7e80-411d-8508-e7f6dfeb693e.lance b/java/spark/src/test/resources/example_db/test_dataset4.lance/data/66c4453b-7e80-411d-8508-e7f6dfeb693e.lance similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/data/66c4453b-7e80-411d-8508-e7f6dfeb693e.lance rename to java/spark/src/test/resources/example_db/test_dataset4.lance/data/66c4453b-7e80-411d-8508-e7f6dfeb693e.lance diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/data/7ac6d965-4d35-4e2b-825b-f4a4a8be9024.lance b/java/spark/src/test/resources/example_db/test_dataset4.lance/data/7ac6d965-4d35-4e2b-825b-f4a4a8be9024.lance similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/data/7ac6d965-4d35-4e2b-825b-f4a4a8be9024.lance rename to 
java/spark/src/test/resources/example_db/test_dataset4.lance/data/7ac6d965-4d35-4e2b-825b-f4a4a8be9024.lance diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/data/86d11ae4-4a8f-48bc-b1b8-3c850a67c871.lance b/java/spark/src/test/resources/example_db/test_dataset4.lance/data/86d11ae4-4a8f-48bc-b1b8-3c850a67c871.lance similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/data/86d11ae4-4a8f-48bc-b1b8-3c850a67c871.lance rename to java/spark/src/test/resources/example_db/test_dataset4.lance/data/86d11ae4-4a8f-48bc-b1b8-3c850a67c871.lance diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/data/cd32d611-941e-4aa9-88c4-72193c618255.lance b/java/spark/src/test/resources/example_db/test_dataset4.lance/data/cd32d611-941e-4aa9-88c4-72193c618255.lance similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/data/cd32d611-941e-4aa9-88c4-72193c618255.lance rename to java/spark/src/test/resources/example_db/test_dataset4.lance/data/cd32d611-941e-4aa9-88c4-72193c618255.lance diff --git a/java/spark/src/test/resources/example_db/test_table4.lance/data/ec05a2ea-2387-45a0-a146-1208997c4f12.lance b/java/spark/src/test/resources/example_db/test_dataset4.lance/data/ec05a2ea-2387-45a0-a146-1208997c4f12.lance similarity index 100% rename from java/spark/src/test/resources/example_db/test_table4.lance/data/ec05a2ea-2387-45a0-a146-1208997c4f12.lance rename to java/spark/src/test/resources/example_db/test_dataset4.lance/data/ec05a2ea-2387-45a0-a146-1208997c4f12.lance diff --git a/python/Cargo.toml b/python/Cargo.toml index 2ffdb459bf..b81415810f 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pylance" -version = "0.17.1" +version = "0.18.1" edition = "2021" authors = ["Lance Devs "] rust-version = "1.65" diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 5e42628590..bafda27057 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -24,6 +24,7 @@ Literal, NamedTuple, Optional, + Set, TypedDict, Union, ) @@ -1061,7 +1062,7 @@ def update( self, updates: Dict[str, str], where: Optional[str] = None, - ): + ) -> Dict[str, Any]: """ Update column values for rows matching where. @@ -1072,13 +1073,19 @@ def update( where : str, optional A SQL predicate indicating which rows should be updated. + Returns + ------- + updates : dict + A dictionary containing the number of rows updated. + Examples -------- >>> import lance >>> import pyarrow as pa >>> table = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}) >>> dataset = lance.write_dataset(table, "example") - >>> dataset.update(dict(a = 'a + 2'), where="b != 'a'") + >>> update_stats = dataset.update(dict(a = 'a + 2'), where="b != 'a'") + >>> update_stats["num_updated_rows"] = 2 >>> dataset.to_table().to_pandas() a b 0 1 a @@ -1087,7 +1094,7 @@ def update( """ if isinstance(where, pa.compute.Expression): where = str(where) - self._ds.update(updates, where) + return self._ds.update(updates, where) def versions(self): """ @@ -2227,6 +2234,27 @@ def _to_inner(self): rewritten_indices = [index._to_inner() for index in self.rewritten_indices] return _Operation.rewrite(groups, rewritten_indices) + @dataclass + class CreateIndex(BaseOperation): + """ + Operation that creates an index on the dataset. 
+ """ + + uuid: str + name: str + fields: List[int] + dataset_version: int + fragment_ids: Set[int] + + def _to_inner(self): + return _Operation.create_index( + self.uuid, + self.name, + self.fields, + self.dataset_version, + self.fragment_ids, + ) + class ScannerBuilder: def __init__(self, ds: LanceDataset): diff --git a/python/python/lance/fragment.py b/python/python/lance/fragment.py index 61685863e0..5659786fb8 100644 --- a/python/python/lance/fragment.py +++ b/python/python/lance/fragment.py @@ -364,6 +364,7 @@ def merge_columns( self, value_func: Callable[[pa.RecordBatch], pa.RecordBatch], columns: Optional[list[str]] = None, + batch_size: Optional[int] = None, ) -> Tuple[FragmentMetadata, LanceSchema]: """Add columns to this Fragment. @@ -390,7 +391,7 @@ def merge_columns( Tuple[FragmentMetadata, LanceSchema] A new fragment with the added column(s) and the final schema. """ - updater = self._fragment.updater(columns) + updater = self._fragment.updater(columns, batch_size) while True: batch = updater.next() diff --git a/python/python/tests/test_commit_index.py b/python/python/tests/test_commit_index.py new file mode 100644 index 0000000000..fa2eeacabc --- /dev/null +++ b/python/python/tests/test_commit_index.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +import random +import shutil +import string +from pathlib import Path + +import lance +import numpy as np +import pyarrow as pa +import pytest + + +@pytest.fixture() +def test_table(): + num_rows = 1000 + price = np.random.rand(num_rows) * 100 + + def gen_str(n, split="", char_set=string.ascii_letters + string.digits): + return "".join(random.choices(char_set, k=n)) + + meta = np.array([gen_str(100) for _ in range(num_rows)]) + doc = [gen_str(10, " ", string.ascii_letters) for _ in range(num_rows)] + tbl = pa.Table.from_arrays( + [ + pa.array(price), + pa.array(meta), + pa.array(doc, pa.large_string()), + pa.array(range(num_rows)), + ], + names=["price", "meta", "doc", "id"], + ) + return tbl + + +@pytest.fixture() +def dataset_with_index(test_table, tmp_path): + dataset = lance.write_dataset(test_table, tmp_path) + dataset.create_scalar_index("meta", index_type="BTREE") + return dataset + + +def test_commit_index(dataset_with_index, test_table, tmp_path): + index_id = dataset_with_index.list_indices()[0]["uuid"] + + # Create a new dataset without index + dataset_without_index = lance.write_dataset( + test_table, tmp_path / "dataset_without_index" + ) + + # Copy the index from dataset_with_index to dataset_without_index + src_index_dir = Path(dataset_with_index.uri) / "_indices" / index_id + dest_index_dir = Path(dataset_without_index.uri) / "_indices" / index_id + shutil.copytree(src_index_dir, dest_index_dir) + + # Commit the index to dataset_without_index + field_idx = dataset_without_index.schema.get_field_index("meta") + create_index_op = lance.LanceOperation.CreateIndex( + index_id, + "meta_idx", + [field_idx], + dataset_without_index.version, + set([f.fragment_id for f in dataset_without_index.get_fragments()]), + ) + dataset_without_index = lance.LanceDataset.commit( + dataset_without_index.uri, + create_index_op, + read_version=dataset_without_index.version, + ) + + # Verify that both datasets have the index + assert len(dataset_with_index.list_indices()) == 1 + assert len(dataset_without_index.list_indices()) == 1 + + assert ( + dataset_without_index.list_indices()[0] == dataset_with_index.list_indices()[0] + ) + + # Check if the index is used in scans + 
for dataset in [dataset_with_index, dataset_without_index]: + scanner = dataset.scanner( + fast_search=True, prefilter=True, filter="meta = 'hello'" + ) + plan = scanner.explain_plan() + assert "MaterializeIndex" in plan diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index f18d4114da..5c865195ce 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -449,6 +449,22 @@ def test_limit_offset(tmp_path: Path, data_storage_version: str): with pytest.raises(ValueError, match="Limit must be non-negative"): assert dataset.to_table(offset=10, limit=-1) == table.slice(50, 50) + full_ds_version = dataset.version + dataset.delete("a % 2 = 0") + filt_table = table.filter((pa.compute.bit_wise_and(pa.compute.field("a"), 1)) != 0) + assert ( + dataset.to_table(offset=10).combine_chunks() + == filt_table.slice(10).combine_chunks() + ) + + dataset = dataset.checkout_version(full_ds_version) + dataset.restore() + dataset.delete("a > 2 AND a < 7") + print(dataset.to_table(offset=3, limit=1)) + filt_table = table.slice(7, 1) + + assert dataset.to_table(offset=3, limit=1) == filt_table + def test_relative_paths(tmp_path: Path): # relative paths get coerced to the full absolute path @@ -896,6 +912,33 @@ def test_merge_with_commit(tmp_path: Path): assert tbl == expected +def test_merge_batch_size(tmp_path: Path): + # Create dataset with 10 fragments with 100 rows each + table = pa.table({"a": range(1000)}) + for batch_size in [1, 10, 100, 1000]: + ds_path = str(tmp_path / str(batch_size)) + dataset = lance.write_dataset(table, ds_path, max_rows_per_file=100) + fragments = [] + + def mutate(batch): + assert batch.num_rows <= batch_size + return pa.RecordBatch.from_pydict({"b": batch.column("a")}) + + for frag in dataset.get_fragments(): + merged, schema = frag.merge_columns(mutate, batch_size=batch_size) + fragments.append(merged) + + merge = lance.LanceOperation.Merge(fragments, schema) + dataset = lance.LanceDataset.commit( + ds_path, merge, read_version=dataset.version + ) + + dataset.validate() + tbl = dataset.to_table() + expected = pa.table({"a": range(1000), "b": range(1000)}) + assert tbl == expected + + def test_merge_with_schema_holes(tmp_path: Path): # Create table with 3 cols table = pa.table({"a": range(10)}) @@ -1501,6 +1544,10 @@ def test_merge_insert_vector_column(tmp_path: Path): check_merge_stats(merge_dict, (1, 1, 0)) +def check_update_stats(update_dict, expected): + assert (update_dict["num_rows_updated"],) == expected + + def test_update_dataset(tmp_path: Path): nrows = 100 vecs = pa.FixedSizeListArray.from_arrays( @@ -1511,11 +1558,12 @@ def test_update_dataset(tmp_path: Path): dataset = lance.dataset(tmp_path / "dataset") - dataset.update(dict(b="b + 1")) + update_dict = dataset.update(dict(b="b + 1")) expected = pa.table({"a": range(100), "b": range(1, 101)}) assert dataset.to_table(columns=["a", "b"]) == expected + check_update_stats(update_dict, (100,)) - dataset.update(dict(a="a * 2"), where="a < 50") + update_dict = dataset.update(dict(a="a * 2"), where="a < 50") expected = pa.table( { "a": [x * 2 if x < 50 else x for x in range(100)], @@ -1523,8 +1571,9 @@ def test_update_dataset(tmp_path: Path): } ) assert dataset.to_table(columns=["a", "b"]).sort_by("b") == expected + check_update_stats(update_dict, (50,)) - dataset.update(dict(vec="[42.0, 43.0]")) + update_dict = dataset.update(dict(vec="[42.0, 43.0]")) expected = pa.table( { "b": range(1, 101), @@ -1534,6 +1583,7 @@ def test_update_dataset(tmp_path: 
Path): } ) assert dataset.to_table(columns=["b", "vec"]).sort_by("b") == expected + check_update_stats(update_dict, (100,)) def test_update_dataset_all_types(tmp_path: Path): @@ -1558,7 +1608,7 @@ def test_update_dataset_all_types(tmp_path: Path): dataset = lance.write_dataset(table, tmp_path) # One update with all matching types - dataset.update( + update_dict = dataset.update( dict( int32="2", int64="2", @@ -1593,6 +1643,7 @@ def test_update_dataset_all_types(tmp_path: Path): } ) assert dataset.to_table() == expected + check_update_stats(update_dict, (1,)) def test_update_with_binary_field(tmp_path: Path): @@ -1607,12 +1658,13 @@ def test_update_with_binary_field(tmp_path: Path): dataset = lance.write_dataset(table, tmp_path) # Update binary field - dataset.update({"b": "X'616263'"}, where="c < 2") + update_dict = dataset.update({"b": "X'616263'"}, where="c < 2") ds = lance.dataset(tmp_path) assert ds.scanner(filter="c < 2").to_table().column( "b" ).combine_chunks() == pa.array([b"abc", b"abc"]) + check_update_stats(update_dict, (2,)) def test_create_update_empty_dataset(tmp_path: Path, provide_pandas: bool): diff --git a/python/python/tests/test_scalar_index.py b/python/python/tests/test_scalar_index.py index c3096c166e..06385cb6c3 100644 --- a/python/python/tests/test_scalar_index.py +++ b/python/python/tests/test_scalar_index.py @@ -223,11 +223,27 @@ def test_full_text_search(dataset, with_position): columns=["doc"], full_text_query=query, ).to_table() + assert results.num_rows > 0 results = results.column(0) for row in results: assert query in row.as_py() +def test_filter_with_fts_index(dataset): + dataset.create_scalar_index("doc", index_type="INVERTED", with_position=False) + row = dataset.take(indices=[0], columns=["doc"]) + query = row.column(0)[0].as_py() + query = query.split(" ")[0] + results = dataset.scanner( + filter=f"doc = '{query}'", + prefilter=True, + ).to_table() + assert results.num_rows > 0 + results = results["doc"] + for row in results: + assert query == row.as_py() + + def test_bitmap_index(tmp_path: Path): """Test create bitmap index""" tbl = pa.Table.from_arrays( diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 8e843b318d..baef5ee732 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -56,6 +56,7 @@ use lance_index::{ use lance_io::object_store::ObjectStoreParams; use lance_linalg::distance::MetricType; use lance_table::format::Fragment; +use lance_table::format::Index; use lance_table::io::commit::CommitHandler; use object_store::path::Path; use pyo3::exceptions::{PyStopIteration, PyTypeError}; @@ -320,6 +321,32 @@ impl Operation { }; Ok(Self(op)) } + + #[staticmethod] + fn create_index( + uuid: String, + name: String, + fields: Vec, + dataset_version: u64, + fragment_ids: &PySet, + ) -> PyResult { + let fragment_ids: Vec = fragment_ids + .iter() + .map(|item| item.extract::()) + .collect::>>()?; + let new_indices = vec![Index { + uuid: Uuid::parse_str(&uuid).map_err(|e| PyValueError::new_err(e.to_string()))?, + name, + fields, + dataset_version, + fragment_bitmap: Some(fragment_ids.into_iter().collect()), + }]; + let op = LanceOperation::CreateIndex { + new_indices, + removed_indices: vec![], + }; + Ok(Self(op)) + } } /// Lance Dataset that will be wrapped by another class in Python @@ -916,7 +943,7 @@ impl Dataset { Ok(()) } - fn update(&mut self, updates: &PyDict, predicate: Option<&str>) -> PyResult<()> { + fn update(&mut self, updates: &PyDict, predicate: Option<&str>) -> PyResult { let mut builder = 
UpdateBuilder::new(self.ds.clone()); if let Some(predicate) = predicate { builder = builder @@ -941,9 +968,11 @@ impl Dataset { .block_on(None, operation.execute())? .map_err(|err| PyIOError::new_err(err.to_string()))?; - self.ds = new_self; - - Ok(()) + self.ds = new_self.new_dataset; + let update_dict = PyDict::new(updates.py()); + let num_rows_updated = new_self.rows_updated; + update_dict.set_item("num_rows_updated", num_rows_updated)?; + Ok(update_dict.into()) } fn count_deleted_rows(&self) -> PyResult { diff --git a/python/src/fragment.rs b/python/src/fragment.rs index e384630476..1e69d3b354 100644 --- a/python/src/fragment.rs +++ b/python/src/fragment.rs @@ -210,10 +210,12 @@ impl FileFragment { Ok(Scanner::new(scn)) } - fn updater(&self, columns: Option>) -> PyResult { + fn updater(&self, columns: Option>, batch_size: Option) -> PyResult { let cols = columns.as_deref(); let inner = RT - .block_on(None, async { self.fragment.updater(cols, None).await })? + .block_on(None, async { + self.fragment.updater(cols, None, batch_size).await + })? .map_err(|err| PyIOError::new_err(err.to_string()))?; Ok(Updater::new(inner)) } diff --git a/rust/lance-core/src/cache.rs b/rust/lance-core/src/cache.rs index 0ebbacfd5a..289b48d358 100644 --- a/rust/lance-core/src/cache.rs +++ b/rust/lance-core/src/cache.rs @@ -67,6 +67,10 @@ impl FileMetadataCache { } } + pub fn size(&self) -> usize { + self.cache.entry_count() as usize + } + pub fn get(&self, path: &Path) -> Option> { self.cache .get(&(path.to_owned(), TypeId::of::())) diff --git a/rust/lance-core/src/utils/deletion.rs b/rust/lance-core/src/utils/deletion.rs index b847df0079..1735f90b8c 100644 --- a/rust/lance-core/src/utils/deletion.rs +++ b/rust/lance-core/src/utils/deletion.rs @@ -83,6 +83,17 @@ impl DeletionVector { } } + /// Create an iterator that iterates over the values in the deletion vector in sorted order. + pub fn to_sorted_iter<'a>(&'a self) -> Box + Send + 'a> { + match self { + Self::NoDeletions => Box::new(std::iter::empty()), + // We have to make a clone when we're using a set + // but sets should be relatively small. + Self::Set(_) => self.clone().into_sorted_iter(), + Self::Bitmap(bitmap) => Box::new(bitmap.iter()), + } + } + // Note: deletion vectors are based on 32-bit offsets. However, this function works // even when given 64-bit row addresses. That is because `id as u32` returns the lower // 32 bits (the row offset) and the upper 32 bits are ignored. 
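Note (editor's sketch, not part of the diff): the two Python-facing changes above (update() in python/src/dataset.rs now returns a stats dict; merge_columns()/updater() accept an optional batch size) can be exercised roughly as follows. It assumes the "num_rows_updated" key set by the Rust binding above and a local temporary path; the final Merge commit step is elided (see test_merge_batch_size above for the full flow).

import tempfile

import lance
import pyarrow as pa

table = pa.table({"a": range(100), "b": range(100)})
dataset = lance.write_dataset(table, tempfile.mkdtemp())

# update() now returns a dict rather than None.
stats = dataset.update({"b": "b + 1"}, where="a < 10")
assert stats["num_rows_updated"] == 10

# merge_columns() forwards batch_size to the fragment updater, capping the
# size of the batches handed to the value function.
def add_col(batch: pa.RecordBatch) -> pa.RecordBatch:
    assert batch.num_rows <= 16
    return pa.RecordBatch.from_pydict({"c": batch.column("a")})

fragment = dataset.get_fragments()[0]
new_fragment, new_schema = fragment.merge_columns(add_col, batch_size=16)
# new_fragment/new_schema would then be committed via LanceOperation.Merge.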
diff --git a/rust/lance-datafusion/src/exec.rs b/rust/lance-datafusion/src/exec.rs index b94b98ab77..0ec78dfc29 100644 --- a/rust/lance-datafusion/src/exec.rs +++ b/rust/lance-datafusion/src/exec.rs @@ -18,8 +18,9 @@ use datafusion::{ TaskContext, }, physical_plan::{ - stream::RecordBatchStreamAdapter, streaming::PartitionStream, DisplayAs, DisplayFormatType, - ExecutionPlan, PlanProperties, SendableRecordBatchStream, + display::DisplayableExecutionPlan, stream::RecordBatchStreamAdapter, + streaming::PartitionStream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, + SendableRecordBatchStream, }, }; use datafusion_common::{DataFusionError, Statistics}; @@ -29,7 +30,7 @@ use lazy_static::lazy_static; use futures::stream; use lance_arrow::SchemaExt; use lance_core::Result; -use log::{info, warn}; +use log::{debug, info, warn}; /// An source execution node created from an existing stream /// @@ -241,6 +242,11 @@ pub fn execute_plan( plan: Arc, options: LanceExecutionOptions, ) -> Result { + debug!( + "Executing plan:\n{}", + DisplayableExecutionPlan::new(plan.as_ref()).indent(true) + ); + let session_ctx = get_session_context(options); // NOTE: we are only executing the first partition here. Therefore, if diff --git a/rust/lance-datafusion/src/expr.rs b/rust/lance-datafusion/src/expr.rs index 00ebf1e86b..dbc450b654 100644 --- a/rust/lance-datafusion/src/expr.rs +++ b/rust/lance-datafusion/src/expr.rs @@ -219,6 +219,11 @@ pub fn safe_coerce_scalar(value: &ScalarValue, ty: &DataType) -> Option Some(ScalarValue::LargeUtf8(val.clone())), _ => None, }, + ScalarValue::LargeUtf8(val) => match ty { + DataType::Utf8 => Some(ScalarValue::Utf8(val.clone())), + DataType::LargeUtf8 => Some(value.clone()), + _ => None, + }, ScalarValue::Boolean(_) => match ty { DataType::Boolean => Some(value.clone()), _ => None, diff --git a/rust/lance-encoding/src/encodings/physical/value.rs b/rust/lance-encoding/src/encodings/physical/value.rs index 32ba61bf88..8767008c01 100644 --- a/rust/lance-encoding/src/encodings/physical/value.rs +++ b/rust/lance-encoding/src/encodings/physical/value.rs @@ -174,7 +174,7 @@ impl ValuePageDecoder { (None, 0) => { // The entire request is contained in one buffer so we can maybe zero-copy // if the slice is aligned properly - return LanceBuffer::from_bytes(slice, dbg!(self.bytes_per_value)); + return LanceBuffer::from_bytes(slice, self.bytes_per_value); } (None, _) => { dest.replace(Vec::with_capacity(bytes_to_take as usize)); diff --git a/rust/lance-index/src/scalar/inverted/builder.rs b/rust/lance-index/src/scalar/inverted/builder.rs index 459c8e73a2..782752e155 100644 --- a/rust/lance-index/src/scalar/inverted/builder.rs +++ b/rust/lance-index/src/scalar/inverted/builder.rs @@ -496,12 +496,17 @@ mod tests { use crate::scalar::lance_format::LanceIndexStore; use crate::scalar::{FullTextSearchQuery, SargableQuery, ScalarIndex}; - async fn test_inverted_index() { + use super::InvertedIndex; + + async fn create_index( + with_position: bool, + ) -> Arc { let tempdir = tempfile::tempdir().unwrap(); let index_dir = Path::from_filesystem_path(tempdir.path()).unwrap(); let store = LanceIndexStore::new(ObjectStore::local(), index_dir, None); - let mut invert_index = super::InvertedIndexBuilder::default(); + let params = super::InvertedIndexParams::default().with_position(with_position); + let mut invert_index = super::InvertedIndexBuilder::new(params); let doc_col = GenericStringArray::::from(vec![ "lance database the search", "lance database", @@ -531,7 +536,11 @@ mod tests { 
.await .expect("failed to update invert index"); - let invert_index = super::InvertedIndex::load(Arc::new(store)).await.unwrap(); + super::InvertedIndex::load(Arc::new(store)).await.unwrap() + } + + async fn test_inverted_index() { + let invert_index = create_index::(false).await; let row_ids = invert_index .search(&SargableQuery::FullTextSearch( FullTextSearchQuery::new("lance".to_owned()).limit(Some(3)), @@ -554,9 +563,34 @@ mod tests { assert!(row_ids.contains(1)); assert!(row_ids.contains(3)); + let row_ids = invert_index + .search(&SargableQuery::FullTextSearch( + FullTextSearchQuery::new("unknown null".to_owned()).limit(Some(3)), + )) + .await + .unwrap(); + assert_eq!(row_ids.len(), Some(0)); + // test phrase query // for non-phrasal query, the order of the tokens doesn't matter // so there should be 4 documents that contain "database" or "lance" + + // we built the index without position, so the phrase query will not work + let results = invert_index + .search(&SargableQuery::FullTextSearch( + FullTextSearchQuery::new("\"unknown null\"".to_owned()).limit(Some(3)), + )) + .await; + assert!(results.unwrap_err().to_string().contains("position is not found but required for phrase queries, try recreating the index with position")); + let results = invert_index + .search(&SargableQuery::FullTextSearch( + FullTextSearchQuery::new("\"lance database\"".to_owned()).limit(Some(10)), + )) + .await; + assert!(results.unwrap_err().to_string().contains("position is not found but required for phrase queries, try recreating the index with position")); + + // recreate the index with position + let invert_index = create_index::(true).await; let row_ids = invert_index .search(&SargableQuery::FullTextSearch( FullTextSearchQuery::new("lance database".to_owned()).limit(Some(10)), @@ -594,6 +628,14 @@ mod tests { .await .unwrap(); assert_eq!(row_ids.len(), Some(0)); + + let row_ids = invert_index + .search(&SargableQuery::FullTextSearch( + FullTextSearchQuery::new("\"unknown null\"".to_owned()).limit(Some(3)), + )) + .await + .unwrap(); + assert_eq!(row_ids.len(), Some(0)); } #[tokio::test] diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index 8aa1349ca9..a59b83f01e 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -108,6 +108,9 @@ impl InvertedIndex { let token_ids = if !is_phrase_query(&query.query) { token_ids.sorted_unstable().dedup().collect() } else { + if !self.inverted_list.has_positions() { + return Err(Error::Index { message: "position is not found but required for phrase queries, try recreating the index with position".to_owned(), location: location!() }); + } let token_ids = token_ids.collect::>(); // for phrase query, all tokens must be present if token_ids.len() != tokens.len() { @@ -377,6 +380,8 @@ struct InvertedListReader { offsets: Vec, max_scores: Option>, + has_position: bool, + // cache posting_cache: Cache, position_cache: Cache, @@ -413,6 +418,8 @@ impl InvertedListReader { None => None, }; + let has_position = reader.schema().field(POSITION_COL).is_some(); + let posting_cache = Cache::builder() .max_capacity(*CACHE_SIZE as u64) .weigher(|_, posting: &PostingList| posting.deep_size_of() as u32) @@ -425,11 +432,16 @@ impl InvertedListReader { reader, offsets, max_scores, + has_position, posting_cache, position_cache, }) } + pub(crate) fn has_positions(&self) -> bool { + self.has_position + } + pub(crate) fn posting_len(&self, token_id: u32) -> usize { let 
token_id = token_id as usize; let next_offset = self @@ -489,7 +501,7 @@ impl InvertedListReader { .await?; Result::Ok(batch .column_by_name(POSITION_COL) - .ok_or(Error::Index { message: "the index was built with old version which doesn't support phrase query, please re-create the index".to_owned(), location: location!() })? + .ok_or(Error::Index { message: "position is not found but required for phrase queries, try recreating the index with position".to_owned(), location: location!() })? .as_list::() .clone()) }).await.map_err(|e| Error::io(e.to_string(), location!())) diff --git a/rust/lance-index/src/scalar/label_list.rs b/rust/lance-index/src/scalar/label_list.rs index 2b7a33b5fe..5f02b749c3 100644 --- a/rust/lance-index/src/scalar/label_list.rs +++ b/rust/lance-index/src/scalar/label_list.rs @@ -180,7 +180,10 @@ fn extract_flatten_indices(list_arr: &dyn Array) -> UInt64Array { } UInt64Array::from(indices) } else { - unreachable!("Should verify that the first column is a list earlier") + unreachable!( + "Should verify that the first column is a list earlier. Got array of type: {}", + list_arr.data_type() + ) } } @@ -189,14 +192,19 @@ fn unnest_schema(schema: &Schema) -> SchemaRef { let key_field = fields_iter.next().unwrap(); let remaining_fields = fields_iter.collect::>(); - let new_key_field = if let DataType::List(item_field) = key_field.data_type() { - Field::new( + let new_key_field = match key_field.data_type() { + DataType::List(item_field) | DataType::LargeList(item_field) => Field::new( key_field.name(), item_field.data_type().clone(), item_field.is_nullable() || key_field.is_nullable(), - ) - } else { - unreachable!("Should verify that the first column is a list earlier") + ), + other_type => { + unreachable!( + "The first field in the schema must be a List or LargeList type. \ + Found: {}. This should have been verified earlier in the code.", + other_type + ) + } }; let all_fields = vec![Arc::new(new_key_field)] diff --git a/rust/lance-io/src/object_reader.rs b/rust/lance-io/src/object_reader.rs index e943d791a6..58a6385096 100644 --- a/rust/lance-io/src/object_reader.rs +++ b/rust/lance-io/src/object_reader.rs @@ -9,7 +9,7 @@ use bytes::Bytes; use deepsize::DeepSizeOf; use futures::future::BoxFuture; use lance_core::Result; -use object_store::{path::Path, ObjectStore}; +use object_store::{path::Path, GetOptions, ObjectStore}; use tokio::sync::OnceCell; use tracing::instrument; @@ -28,6 +28,7 @@ pub struct CloudObjectReader { size: OnceCell, block_size: usize, + download_retry_count: usize, } impl DeepSizeOf for CloudObjectReader { @@ -44,12 +45,14 @@ impl CloudObjectReader { path: Path, block_size: usize, known_size: Option, + download_retry_count: usize, ) -> Result { Ok(Self { object_store, path, size: OnceCell::new_with(known_size), block_size, + download_retry_count, }) } @@ -104,7 +107,40 @@ impl Reader for CloudObjectReader { #[instrument(level = "debug", skip(self))] async fn get_range(&self, range: Range) -> object_store::Result { - self.do_with_retry(|| self.object_store.get_range(&self.path, range.clone())) - .await + // We have a separate retry loop here. This is because object_store does not + // attempt retries on downloads that fail during streaming of the response body. + // + // However, this failure is pretty common (e.g. timeout) and we want to retry in these + // situations. In addition, we provide additional logging information in these + // failures cases. 
+ let mut retries = self.download_retry_count; + loop { + let get_result = self + .do_with_retry(|| { + let options = GetOptions { + range: Some(range.clone().into()), + ..Default::default() + }; + self.object_store.get_opts(&self.path, options) + }) + .await?; + match get_result.bytes().await { + Ok(bytes) => return Ok(bytes), + Err(err) => { + if retries == 0 { + log::warn!("Failed to download range {:?} from {} after {} attempts. This may indicate that cloud storage is overloaded or your timeout settings are too restrictive. Error details: {:?}", range, self.path, self.download_retry_count, err); + return Err(err); + } + log::debug!( + "Retrying range {:?} from {} (remaining retries: {}). Error details: {:?}", + range, + self.path, + retries, + err + ); + retries -= 1; + } + } + } } } diff --git a/rust/lance-io/src/object_store.rs b/rust/lance-io/src/object_store.rs index 67e82df8ba..dc50e2dc12 100644 --- a/rust/lance-io/src/object_store.rs +++ b/rust/lance-io/src/object_store.rs @@ -46,6 +46,8 @@ pub const DEFAULT_LOCAL_IO_PARALLELISM: usize = 8; // Cloud disks often need many many threads to saturate the network pub const DEFAULT_CLOUD_IO_PARALLELISM: usize = 64; +pub const DEFAULT_DOWNLOAD_RETRY_COUNT: usize = 3; + #[async_trait] pub trait ObjectStoreExt { /// Returns true if the file exists. @@ -100,6 +102,8 @@ pub struct ObjectStore { /// is true for object stores, but not for local filesystems. pub list_is_lexically_ordered: bool, io_parallelism: usize, + /// Number of times to retry a failed download + download_retry_count: usize, } impl DeepSizeOf for ObjectStore { @@ -440,6 +444,7 @@ impl ObjectStore { use_constant_size_upload_parts: false, list_is_lexically_ordered: false, io_parallelism: DEFAULT_LOCAL_IO_PARALLELISM, + download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT, }, Path::from_absolute_path(expanded_path.as_path())?, )) @@ -466,6 +471,7 @@ impl ObjectStore { use_constant_size_upload_parts: false, list_is_lexically_ordered: false, io_parallelism: DEFAULT_LOCAL_IO_PARALLELISM, + download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT, } } @@ -478,6 +484,7 @@ impl ObjectStore { use_constant_size_upload_parts: false, list_is_lexically_ordered: true, io_parallelism: get_num_compute_intensive_cpus(), + download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT, } } @@ -516,6 +523,7 @@ impl ObjectStore { path.clone(), self.block_size, None, + self.download_retry_count, )?)), } } @@ -533,6 +541,7 @@ impl ObjectStore { path.clone(), self.block_size, Some(known_size), + self.download_retry_count, )?)), } } @@ -641,6 +650,28 @@ impl ObjectStore { Ok(self.inner.head(path).await?.size) } } + +/// Options that can be set for multiple object stores +#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy)] +pub enum LanceConfigKey { + /// Number of times to retry a download that fails + DownloadRetryCount, +} + +impl FromStr for LanceConfigKey { + type Err = Error; + + fn from_str(s: &str) -> std::result::Result { + match s.to_ascii_lowercase().as_str() { + "download_retry_count" => Ok(Self::DownloadRetryCount), + _ => Err(Error::InvalidInput { + source: format!("Invalid LanceConfigKey: {}", s).into(), + location: location!(), + }), + } + } +} + #[derive(Clone, Debug, Default)] pub struct StorageOptions(pub HashMap); @@ -709,6 +740,15 @@ impl StorageOptions { }) } + /// Number of times to retry a download that fails + pub fn download_retry_count(&self) -> usize { + self.0 + .iter() + .find(|(key, _)| key.to_ascii_lowercase() == "download_retry_count") + .map(|(_, value)| value.parse::().unwrap_or(3)) + 
.unwrap_or(3) + } + /// Subset of options relevant for azure storage pub fn as_azure_options(&self) -> HashMap { self.0 @@ -755,6 +795,7 @@ async fn configure_store( options: ObjectStoreParams, ) -> Result { let mut storage_options = StorageOptions(options.storage_options.clone().unwrap_or_default()); + let download_retry_count = storage_options.download_retry_count(); let mut url = ensure_table_uri(url)?; // Block size: On local file systems, we use 4KB block size. On cloud // object stores, we use 64KB block size. This is generally the largest @@ -813,6 +854,7 @@ async fn configure_store( use_constant_size_upload_parts, list_is_lexically_ordered: true, io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM, + download_retry_count, }) } "gs" => { @@ -831,6 +873,7 @@ async fn configure_store( use_constant_size_upload_parts: false, list_is_lexically_ordered: true, io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM, + download_retry_count, }) } "az" => { @@ -845,6 +888,7 @@ async fn configure_store( use_constant_size_upload_parts: false, list_is_lexically_ordered: true, io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM, + download_retry_count, }) } // we have a bypass logic to use `tokio::fs` directly to lower overhead @@ -862,6 +906,7 @@ async fn configure_store( use_constant_size_upload_parts: false, list_is_lexically_ordered: true, io_parallelism: get_num_compute_intensive_cpus(), + download_retry_count, }), unknown_scheme => { if let Some(provider) = registry.providers.get(unknown_scheme) { @@ -878,6 +923,7 @@ async fn configure_store( } impl ObjectStore { + #[allow(clippy::too_many_arguments)] pub fn new( store: Arc, location: Url, @@ -886,6 +932,7 @@ impl ObjectStore { use_constant_size_upload_parts: bool, list_is_lexically_ordered: bool, io_parallelism: usize, + download_retry_count: usize, ) -> Self { let scheme = location.scheme(); let block_size = block_size.unwrap_or_else(|| infer_block_size(scheme)); @@ -902,6 +949,7 @@ impl ObjectStore { use_constant_size_upload_parts, list_is_lexically_ordered, io_parallelism, + download_retry_count, } } } diff --git a/rust/lance-io/src/scheduler.rs b/rust/lance-io/src/scheduler.rs index 747e684e54..206383fd2a 100644 --- a/rust/lance-io/src/scheduler.rs +++ b/rust/lance-io/src/scheduler.rs @@ -660,7 +660,7 @@ mod tests { use tokio::{runtime::Handle, time::timeout}; use url::Url; - use crate::testing::MockObjectStore; + use crate::{object_store::DEFAULT_DOWNLOAD_RETRY_COUNT, testing::MockObjectStore}; use super::*; @@ -743,6 +743,7 @@ mod tests { false, false, 1, + DEFAULT_DOWNLOAD_RETRY_COUNT, )); let config = SchedulerConfig { @@ -831,6 +832,7 @@ mod tests { false, false, 1, + DEFAULT_DOWNLOAD_RETRY_COUNT, )); let config = SchedulerConfig { diff --git a/rust/lance-io/src/utils.rs b/rust/lance-io/src/utils.rs index 2c4cb0900a..37253339a5 100644 --- a/rust/lance-io/src/utils.rs +++ b/rust/lance-io/src/utils.rs @@ -183,7 +183,7 @@ mod tests { use crate::{ object_reader::CloudObjectReader, - object_store::ObjectStore, + object_store::{ObjectStore, DEFAULT_DOWNLOAD_RETRY_COUNT}, object_writer::ObjectWriter, traits::{ProtoStruct, WriteExt, Writer}, utils::read_struct, @@ -226,7 +226,9 @@ mod tests { assert_eq!(pos, 0); object_writer.shutdown().await.unwrap(); - let object_reader = CloudObjectReader::new(store.inner, path, 1024, None).unwrap(); + let object_reader = + CloudObjectReader::new(store.inner, path, 1024, None, DEFAULT_DOWNLOAD_RETRY_COUNT) + .unwrap(); let actual: BytesWrapper = read_struct(&object_reader, pos).await.unwrap(); assert_eq!(some_message, actual); 
} diff --git a/rust/lance-linalg/build.rs b/rust/lance-linalg/build.rs index 88c7223271..2eadc0a65d 100644 --- a/rust/lance-linalg/build.rs +++ b/rust/lance-linalg/build.rs @@ -89,7 +89,7 @@ fn build_f16_with_flags(suffix: &str, flags: &[&str]) -> Result<(), cc::Error> { // Pedantic will complain about _Float16 in some versions of GCC // .flag("-Wpedantic") // We pass in the suffix to make sure the symbol names are unique - .flag(format!("-DSUFFIX=_{}", suffix)); + .flag(format!("-DSUFFIX=_{}", suffix).as_str()); for flag in flags { builder.flag(flag); diff --git a/rust/lance/src/arrow/json.rs b/rust/lance/src/arrow/json.rs index 75ad082157..5e11ae412f 100644 --- a/rust/lance/src/arrow/json.rs +++ b/rust/lance/src/arrow/json.rs @@ -48,6 +48,8 @@ impl TryFrom<&DataType> for JsonDataType { | DataType::Float16 | DataType::Float32 | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) | DataType::Utf8 | DataType::Binary | DataType::LargeUtf8 @@ -63,7 +65,6 @@ impl TryFrom<&DataType> for JsonDataType { let logical_type: LogicalType = dt.try_into()?; (logical_type.to_string(), None) } - DataType::List(f) => { let fields = vec![JsonField::try_from(f.as_ref())?]; ("list".to_string(), Some(fields)) @@ -119,7 +120,8 @@ impl TryFrom<&JsonDataType> for DataType { || dt.starts_with("time64:") || dt.starts_with("timestamp:") || dt.starts_with("duration:") - || dt.starts_with("dict:") => + || dt.starts_with("dict:") + || dt.starts_with("decimal:") => { let logical_type: LogicalType = dt.into(); (&logical_type).try_into() @@ -312,6 +314,10 @@ mod test { assert_primitive_types(DataType::Date32, "date32:day"); assert_primitive_types(DataType::Date64, "date64:ms"); assert_primitive_types(DataType::Time32(TimeUnit::Second), "time32:s"); + assert_primitive_types(DataType::Decimal128(38, 10), "decimal:128:38:10"); + assert_primitive_types(DataType::Decimal256(76, 20), "decimal:256:76:20"); + assert_primitive_types(DataType::Decimal128(18, 6), "decimal:128:18:6"); + assert_primitive_types(DataType::Decimal256(50, 15), "decimal:256:50:15"); } #[test] diff --git a/rust/lance/src/dataset/builder.rs b/rust/lance/src/dataset/builder.rs index 5d7037b601..005ea89372 100644 --- a/rust/lance/src/dataset/builder.rs +++ b/rust/lance/src/dataset/builder.rs @@ -4,7 +4,8 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; use lance_file::datatypes::populate_schema_dictionary; use lance_io::object_store::{ - ObjectStore, ObjectStoreParams, ObjectStoreRegistry, DEFAULT_CLOUD_IO_PARALLELISM, + ObjectStore, ObjectStoreParams, ObjectStoreRegistry, StorageOptions, + DEFAULT_CLOUD_IO_PARALLELISM, }; use lance_table::{ format::Manifest, @@ -220,6 +221,14 @@ impl DatasetBuilder { None => commit_handler_from_url(&self.table_uri, &Some(self.options.clone())).await, }?; + let storage_options = self + .options + .storage_options + .clone() + .map(StorageOptions::new) + .unwrap_or_default(); + let download_retry_count = storage_options.download_retry_count(); + match &self.options.object_store { Some(store) => Ok(( ObjectStore::new( @@ -232,6 +241,7 @@ impl DatasetBuilder { // If user supplied an object store then we just assume it's probably // cloud-like DEFAULT_CLOUD_IO_PARALLELISM, + download_retry_count, ), Path::from(store.1.path()), commit_handler, diff --git a/rust/lance/src/dataset/fragment.rs b/rust/lance/src/dataset/fragment.rs index 731365a8a6..22acebdd67 100644 --- a/rust/lance/src/dataset/fragment.rs +++ b/rust/lance/src/dataset/fragment.rs @@ -1132,6 +1132,7 @@ impl FileFragment { &self, 
columns: Option<&[T]>, schemas: Option<(Schema, Schema)>, + batch_size: Option, ) -> Result { let mut schema = self.dataset.schema().clone(); @@ -1160,11 +1161,11 @@ impl FileFragment { let reader = reader?; let deletion_vector = deletion_vector?.unwrap_or_default(); - Updater::try_new(self.clone(), reader, deletion_vector, schemas) + Updater::try_new(self.clone(), reader, deletion_vector, schemas, batch_size) } pub(crate) async fn merge(mut self, join_column: &str, joiner: &HashJoiner) -> Result { - let mut updater = self.updater(Some(&[join_column]), None).await?; + let mut updater = self.updater(Some(&[join_column]), None, None).await?; while let Some(batch) = updater.next().await? { let batch = joiner.collect(batch[join_column].clone()).await?; @@ -1740,7 +1741,33 @@ impl FragmentReader { ) } - pub fn read_range(&self, range: Range, batch_size: u32) -> Result { + fn patch_range_for_deletions(&self, range: Range, dv: &DeletionVector) -> Range { + let mut start = range.start; + let mut end = range.end; + for val in dv.to_sorted_iter() { + if val <= start { + start += 1; + end += 1; + } else if val < end { + end += 1; + } else { + break; + } + } + start..end + } + + fn do_read_range( + &self, + mut range: Range, + batch_size: u32, + skip_deleted_rows: bool, + ) -> Result { + if skip_deleted_rows { + if let Some(deletion_vector) = self.deletion_vec.as_ref() { + range = self.patch_range_for_deletions(range, deletion_vector.as_ref()); + } + } self.new_read_impl( ReadBatchParams::Range(range.start as usize..range.end as usize), batch_size, @@ -1754,6 +1781,22 @@ impl FragmentReader { ) } + /// Reads a range of rows from the fragment + /// + /// This function interprets the request as the Xth to the Nth row of the fragment (after deletions) + /// and will always return range.len().min(self.num_rows()) rows. + pub fn read_range(&self, range: Range, batch_size: u32) -> Result { + self.do_read_range(range, batch_size, true) + } + + /// Takes a range of rows from the fragment + /// + /// Unlike [`Self::read_range`], this function will NOT skip deleted rows. If rows are deleted they will + /// be filtered or set to null. This function may return less than range.len() rows as a result. + pub fn take_range(&self, range: Range, batch_size: u32) -> Result { + self.do_read_range(range, batch_size, false) + } + pub fn read_all(&self, batch_size: u32) -> Result { self.new_read_impl(ReadBatchParams::RangeFull, batch_size, move |reader| { reader.read_all_tasks(batch_size, reader.projection().clone()) @@ -1766,7 +1809,7 @@ impl FragmentReader { // TODO: Move away from this by changing callers to support consuming a stream pub async fn legacy_read_range_as_batch(&self, range: Range) -> Result { let batches = self - .read_range( + .take_range( range.start as u32..range.end as u32, DEFAULT_BATCH_READ_SIZE, )? 
@@ -1988,6 +2031,7 @@ mod tests { let fragment = &dataset.get_fragments()[0]; assert_eq!(fragment.metadata.num_rows().unwrap(), 20); + // Test with take_range (all rows addressible) for with_row_id in [false, true] { let reader = fragment .open(fragment.schema(), with_row_id, false, None) @@ -1995,7 +2039,7 @@ mod tests { .unwrap(); for valid_range in [0..40, 20..40] { reader - .read_range(valid_range, 100) + .take_range(valid_range, 100) .unwrap() .buffered(1) .try_collect::>() @@ -2003,6 +2047,26 @@ mod tests { .unwrap(); } for invalid_range in [0..41, 41..42] { + assert!(reader.take_range(invalid_range, 100).is_err()); + } + } + + // Test with read_range (only non-deleted rows addressible) + for with_row_id in [false, true] { + let reader = fragment + .open(fragment.schema(), with_row_id, false, None) + .await + .unwrap(); + for valid_range in [0..20, 0..10, 10..20] { + reader + .read_range(valid_range, 100) + .unwrap() + .buffered(1) + .try_collect::>() + .await + .unwrap(); + } + for invalid_range in [0..21, 21..22] { assert!(reader.read_range(invalid_range, 100).is_err()); } } @@ -2010,7 +2074,7 @@ mod tests { #[rstest] #[tokio::test] - async fn test_fragment_scan_deletions( + async fn test_fragment_take_range_deletions( #[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)] data_storage_version: LanceFileVersion, ) { @@ -2058,7 +2122,7 @@ mod tests { let to_batches = |range: Range| { let batch_size = range.len() as u32; reader - .read_range(range, batch_size) + .take_range(range, batch_size) .unwrap() .buffered(1) .try_collect::>() @@ -2090,6 +2154,66 @@ mod tests { } } + #[rstest] + #[tokio::test] + async fn test_range_scan_deletions( + #[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)] + data_storage_version: LanceFileVersion, + ) { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let dataset = create_dataset(test_uri, data_storage_version).await; + + let version = dataset.version().version; + + let check = |cond: &'static str, range: Range, expected: Vec| async { + let mut dataset = dataset.checkout_version(version).await.unwrap(); + dataset.restore().await.unwrap(); + dataset.delete(cond).await.unwrap(); + + let fragment = &dataset.get_fragments()[0]; + let reader = fragment + .open(dataset.schema(), true, false, None) + .await + .unwrap(); + + // Using batch_size=20 here. 
If we use batch_size=range.len() we get + // multiple batches because we might have to read from a larger range + // to satisfy the request + let mut stream = reader.read_range(range, 20).unwrap(); + let mut batches = Vec::new(); + while let Some(next) = stream.next().await { + batches.push(next.await.unwrap()); + } + let schema = Arc::new(dataset.schema().into()); + let batch = arrow_select::concat::concat_batches(&schema, batches.iter()).unwrap(); + + assert_eq!(batch.num_rows(), expected.len()); + assert_eq!( + batch.column_by_name("i").unwrap().as_ref(), + &Int32Array::from(expected) + ); + }; + // Deleting from the start + check("i < 5", 0..2, vec![5, 6]).await; + check("i < 5", 0..15, (5..20).collect()).await; + // Deleting from the middle + check("i >= 5 and i < 15", 7..9, vec![17, 18]).await; + check("i >= 5 and i < 15", 3..5, vec![3, 4]).await; + check("i >= 5 and i < 15", 3..6, vec![3, 4, 15]).await; + check("i >= 5 and i < 15", 5..6, vec![15]).await; + check("i >= 5 and i < 15", 5..10, vec![15, 16, 17, 18, 19]).await; + check( + "i >= 5 and i < 15", + 0..10, + vec![0, 1, 2, 3, 4, 15, 16, 17, 18, 19], + ) + .await; + // Deleting from the end + check("i >= 15", 10..15, vec![10, 11, 12, 13, 14]).await; + check("i >= 15", 0..15, (0..15).collect()).await; + } + #[rstest] #[tokio::test] async fn test_fragment_take_indices( @@ -2310,7 +2434,7 @@ mod tests { } let fragment = &mut dataset.get_fragment(0).unwrap(); - let mut updater = fragment.updater(Some(&["i"]), None).await.unwrap(); + let mut updater = fragment.updater(Some(&["i"]), None, None).await.unwrap(); let new_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( "double_i", DataType::Int32, @@ -2554,7 +2678,7 @@ mod tests { let fragment = dataset.get_fragments().pop().unwrap(); // Write batch_s using add_columns - let mut updater = fragment.updater(Some(&["i"]), None).await?; + let mut updater = fragment.updater(Some(&["i"]), None, None).await?; updater.next().await?; updater.update(batch_s.clone()).await?; let frag = updater.finish().await?; diff --git a/rust/lance/src/dataset/rowids.rs b/rust/lance/src/dataset/rowids.rs index 808bdae377..3f93ca5a74 100644 --- a/rust/lance/src/dataset/rowids.rs +++ b/rust/lance/src/dataset/rowids.rs @@ -410,7 +410,7 @@ mod test { assert_eq!(dataset.manifest().next_row_id, num_rows); - let dataset = UpdateBuilder::new(Arc::new(dataset)) + let update_result = UpdateBuilder::new(Arc::new(dataset)) .update_where("id = 3") .unwrap() .set("id", "100") @@ -421,6 +421,7 @@ mod test { .await .unwrap(); + let dataset = update_result.new_dataset; let index = get_row_id_index(&dataset).await.unwrap().unwrap(); assert!(index.get(0).is_some()); // Old address is still there. 
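Note (editor's sketch, not part of the diff): the retry loop added to CloudObjectReader::get_range earlier in this diff is configured through the new download_retry_count storage option, which StorageOptions parses and DatasetBuilder plumbs into the ObjectStore. A hedged example of setting it from Python, assuming pylance forwards storage_options to the Rust object-store layer; the S3 URI is hypothetical.

import lance

# Allow up to 5 retries when a download's response body fails mid-stream,
# instead of DEFAULT_DOWNLOAD_RETRY_COUNT (3).
ds = lance.dataset(
    "s3://example-bucket/table.lance",  # hypothetical URI
    storage_options={"download_retry_count": "5"},
)
print(ds.count_rows())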
diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 5eef527144..e64b319096 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -38,7 +38,6 @@ use lance_core::utils::tokio::get_num_compute_intensive_cpus; use lance_core::{ROW_ADDR, ROW_ADDR_FIELD, ROW_ID, ROW_ID_FIELD}; use lance_datafusion::exec::{execute_plan, LanceExecutionOptions}; use lance_datafusion::projection::ProjectionPlan; -use lance_index::scalar::expression::IndexInformationProvider; use lance_index::scalar::expression::PlannerIndexExt; use lance_index::scalar::inverted::SCORE_COL; use lance_index::scalar::FullTextSearchQuery; @@ -47,7 +46,6 @@ use lance_index::{scalar::expression::ScalarIndexExpr, DatasetIndexExt}; use lance_io::stream::RecordBatchStream; use lance_linalg::distance::MetricType; use lance_table::format::{Fragment, Index}; -use log::debug; use roaring::RoaringBitmap; use tracing::{info_span, instrument, Span}; @@ -67,7 +65,14 @@ use snafu::{location, Location}; #[cfg(feature = "substrait")] use lance_datafusion::substrait::parse_substrait; -pub const DEFAULT_BATCH_SIZE: usize = 8192; +const BATCH_SIZE_FALLBACK: usize = 8192; +// For backwards compatibility / historical reasons we re-calculate the default batch size +// on each call +pub fn get_default_batch_size() -> Option { + std::env::var("LANCE_DEFAULT_BATCH_SIZE") + .map(|val| Some(val.parse().unwrap())) + .unwrap_or(None) +} pub const LEGACY_DEFAULT_FRAGMENT_READAHEAD: usize = 4; lazy_static::lazy_static! { @@ -263,23 +268,14 @@ impl Scanner { // 64KB, this is 16K rows. For local file systems, the default block size // is just 4K, which would mean only 1K rows, which might be a little small. // So we use a default minimum of 8K rows. - std::env::var("LANCE_DEFAULT_BATCH_SIZE") - .map(|bs| { - bs.parse().unwrap_or_else(|_| { - panic!( - "The value of LANCE_DEFAULT_BATCH_SIZE ({}) is not a valid batch size", - bs - ) - }) - }) - .unwrap_or_else(|_| { - self.batch_size.unwrap_or_else(|| { - std::cmp::max( - self.dataset.object_store().block_size() / 4, - DEFAULT_BATCH_SIZE, - ) - }) + get_default_batch_size().unwrap_or_else(|| { + self.batch_size.unwrap_or_else(|| { + std::cmp::max( + self.dataset.object_store().block_size() / 4, + BATCH_SIZE_FALLBACK, + ) }) + }) } fn ensure_not_fragment_scan(&self) -> Result<()> { @@ -1115,8 +1111,6 @@ impl Scanner { plan = rule.optimize(plan, &options)?; } - debug!("Execution plan:\n{:?}", plan); - Ok(plan) } @@ -1127,19 +1121,23 @@ impl Scanner { query: &FullTextSearchQuery, ) -> Result> { let columns = if query.columns.is_empty() { - let index_info = self.dataset.scalar_index_info().await?; - self.dataset - .schema() - .fields - .iter() - .filter_map(|f| { - if f.data_type() == DataType::Utf8 || f.data_type() == DataType::LargeUtf8 { - index_info.get_index(&f.name).map(|_| f.name.clone()) - } else { - None - } - }) - .collect() + let string_columns = self.dataset.schema().fields.iter().filter_map(|f| { + if f.data_type() == DataType::Utf8 || f.data_type() == DataType::LargeUtf8 { + Some(&f.name) + } else { + None + } + }); + + let mut indexed_columns = Vec::new(); + for column in string_columns { + let index = self.dataset.load_scalar_index_for_column(column).await?; + if index.is_some() { + indexed_columns.push(column.clone()); + } + } + + indexed_columns } else { query.columns.clone() }; diff --git a/rust/lance/src/dataset/schema_evolution.rs b/rust/lance/src/dataset/schema_evolution.rs index 9a5865159e..521a22290e 100644 --- 
a/rust/lance/src/dataset/schema_evolution.rs +++ b/rust/lance/src/dataset/schema_evolution.rs @@ -261,7 +261,7 @@ async fn add_columns_impl( } let mut updater = fragment - .updater(read_columns_ref, schemas_ref.clone()) + .updater(read_columns_ref, schemas_ref.clone(), None) .await?; let mut batch_index = 0; diff --git a/rust/lance/src/dataset/updater.rs b/rust/lance/src/dataset/updater.rs index 1ba9ce0565..d8a6ae96bb 100644 --- a/rust/lance/src/dataset/updater.rs +++ b/rust/lance/src/dataset/updater.rs @@ -10,6 +10,7 @@ use lance_table::utils::stream::ReadBatchFutStream; use snafu::{location, Location}; use super::fragment::FragmentReader; +use super::scanner::get_default_batch_size; use super::write::{open_writer, GenericWriter}; use super::Dataset; use crate::dataset::FileFragment; @@ -59,6 +60,7 @@ impl Updater { reader: FragmentReader, deletion_vector: DeletionVector, schemas: Option<(Schema, Schema)>, + batch_size: Option, ) -> Result { let (write_schema, final_schema) = if let Some((write_schema, final_schema)) = schemas { (Some(write_schema), Some(final_schema)) @@ -66,9 +68,18 @@ impl Updater { (None, None) }; - let batch_size = reader.legacy_num_rows_in_batch(0); + let legacy_batch_size = reader.legacy_num_rows_in_batch(0); - let input_stream = reader.read_all(1024)?; + let batch_size = match (&legacy_batch_size, batch_size) { + // If this is a v1 dataset we must use the row group size of the file + (Some(num_rows), _) => *num_rows, + // If this is a v2 dataset, let the user pick the batch size + (None, Some(legacy_batch_size)) => legacy_batch_size, + // Otherwise, default to 1024 if the user didn't specify anything + (None, None) => get_default_batch_size().unwrap_or(1024) as u32, + }; + + let input_stream = reader.read_all(batch_size)?; Ok(Self { fragment, @@ -78,7 +89,7 @@ impl Updater { write_schema, final_schema, finished: false, - deletion_restorer: DeletionRestorer::new(deletion_vector, batch_size), + deletion_restorer: DeletionRestorer::new(deletion_vector, legacy_batch_size), }) } @@ -226,7 +237,7 @@ struct DeletionRestorer { current_row_id: u32, /// Number of rows in each batch, only used in legacy files for validation - batch_size: Option, + legacy_batch_size: Option, deletion_vector_iter: Option + Send>>, @@ -234,10 +245,10 @@ struct DeletionRestorer { } impl DeletionRestorer { - fn new(deletion_vector: DeletionVector, batch_size: Option) -> Self { + fn new(deletion_vector: DeletionVector, legacy_batch_size: Option) -> Self { Self { current_row_id: 0, - batch_size, + legacy_batch_size, deletion_vector_iter: Some(deletion_vector.into_sorted_iter()), last_deleted_row_id: None, } @@ -248,12 +259,12 @@ impl DeletionRestorer { } fn is_full(batch_size: Option, num_rows: u32) -> bool { - if let Some(batch_size) = batch_size { + if let Some(legacy_batch_size) = batch_size { // We should never encounter the case that `batch_size < num_rows` because // that would mean we have a v1 writer and it generated a batch with more rows // than expected - debug_assert!(batch_size >= num_rows); - batch_size == num_rows + debug_assert!(legacy_batch_size >= num_rows); + legacy_batch_size == num_rows } else { false } @@ -295,7 +306,8 @@ impl DeletionRestorer { loop { if let Some(next_deleted_id) = next_deleted_id { if next_deleted_id > last_row_id - || (next_deleted_id == last_row_id && Self::is_full(self.batch_size, num_rows)) + || (next_deleted_id == last_row_id + && Self::is_full(self.legacy_batch_size, num_rows)) { // Either the next deleted id is out of range or it is the next row but 
// we are full. Either way, stash it and return @@ -322,7 +334,7 @@ impl DeletionRestorer { let deleted_batch_offsets = self.deleted_batch_offsets_in_range(batch.num_rows() as u32); let batch = add_blanks(batch, &deleted_batch_offsets)?; - if let Some(batch_size) = self.batch_size { + if let Some(batch_size) = self.legacy_batch_size { // validation just in case, when the input has a fixed batch size then the // output should have the same fixed batch size (except the last batch) let is_last = self.is_exhausted(); diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index 9f38611171..01d5d66138 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -701,6 +701,7 @@ impl MergeInsertJob { .updater( Some(&read_columns), Some((write_schema, dataset.schema().clone())), + None, ) .await?; diff --git a/rust/lance/src/dataset/write/update.rs b/rust/lance/src/dataset/write/update.rs index 12aea2a4c3..6d413081bb 100644 --- a/rust/lance/src/dataset/write/update.rs +++ b/rust/lance/src/dataset/write/update.rs @@ -185,6 +185,12 @@ impl UpdateBuilder { // TODO: support distributed operation. +#[derive(Debug, Clone)] +pub struct UpdateResult { + pub new_dataset: Arc, + pub rows_updated: u64, +} + #[derive(Debug, Clone)] pub struct UpdateJob { dataset: Arc, @@ -193,7 +199,7 @@ pub struct UpdateJob { } impl UpdateJob { - pub async fn execute(self) -> Result> { + pub async fn execute(self) -> Result { let mut scanner = self.dataset.scan(); scanner.with_row_id(); @@ -246,7 +252,6 @@ impl UpdateJob { WriteParams::with_storage_version(version), ) .await?; - // Apply deletions let removed_row_ids = Arc::into_inner(removed_row_ids) .unwrap() @@ -254,9 +259,18 @@ impl UpdateJob { .unwrap(); let (old_fragments, removed_fragment_ids) = self.apply_deletions(&removed_row_ids).await?; + let num_updated_rows = new_fragments + .iter() + .map(|f| f.physical_rows.unwrap() as u64) + .sum::(); // Commit updated and new fragments - self.commit(removed_fragment_ids, old_fragments, new_fragments) - .await + let new_dataset = self + .commit(removed_fragment_ids, old_fragments, new_fragments) + .await?; + Ok(UpdateResult { + new_dataset, + rows_updated: num_updated_rows, + }) } fn apply_updates( @@ -446,7 +460,7 @@ mod tests { ) { let (dataset, _test_dir) = make_test_dataset(version).await; - let dataset = UpdateBuilder::new(dataset) + let update_result = UpdateBuilder::new(dataset) .set("name", "'bar' || cast(id as string)") .unwrap() .build() @@ -455,6 +469,7 @@ mod tests { .await .unwrap(); + let dataset = update_result.new_dataset; let actual_batches = dataset .scan() .try_into_stream() @@ -490,7 +505,7 @@ mod tests { let original_fragments = dataset.get_fragments(); - let dataset = UpdateBuilder::new(dataset) + let update_result = UpdateBuilder::new(dataset) .update_where("id >= 15") .unwrap() .set("name", "'bar' || cast(id as string)") @@ -501,6 +516,7 @@ mod tests { .await .unwrap(); + let dataset = update_result.new_dataset; let actual_batches = dataset .scan() .try_into_stream() diff --git a/rust/lance/src/index.rs b/rust/lance/src/index.rs index 969a719376..fd803e0cc5 100644 --- a/rust/lance/src/index.rs +++ b/rust/lance/src/index.rs @@ -19,7 +19,7 @@ use lance_index::scalar::expression::{ IndexInformationProvider, LabelListQueryParser, SargableQueryParser, ScalarQueryParser, }; use lance_index::scalar::lance_format::LanceIndexStore; -use lance_index::scalar::{InvertedIndexParams, ScalarIndex}; +use 
lance_index::scalar::{InvertedIndexParams, ScalarIndex, ScalarIndexType}; use lance_index::vector::flat::index::{FlatIndex, FlatQuantizer}; use lance_index::vector::hnsw::HNSW; use lance_index::vector::pq::ProductQuantizer; @@ -41,7 +41,7 @@ use lance_table::format::Index as IndexMetadata; use lance_table::format::{Fragment, SelfDescribingFileReader}; use lance_table::io::manifest::read_manifest_indexes; use roaring::RoaringBitmap; -use scalar::build_inverted_index; +use scalar::{build_inverted_index, detect_scalar_index_type}; use serde_json::json; use snafu::{location, Location}; use tracing::instrument; @@ -766,22 +766,34 @@ impl DatasetIndexInternalExt for Dataset { async fn scalar_index_info(&self) -> Result { let indices = self.load_indices().await?; let schema = self.schema(); - let indexed_fields = indices - .iter() - .filter(|idx| idx.fields.len() == 1) - .map(|idx| { - let field = idx.fields[0]; - let field = schema.field_by_id(field).ok_or_else(|| Error::Internal { message: format!("Index referenced a field with id {field} which did not exist in the schema"), location: location!() }); - field.map(|field| { - let query_parser = if let DataType::List(_) = field.data_type() { + let mut indexed_fields = Vec::new(); + for index in indices.iter().filter(|idx| idx.fields.len() == 1) { + let field = index.fields[0]; + let field = schema.field_by_id(field).ok_or_else(|| Error::Internal { + message: format!( + "Index referenced a field with id {field} which did not exist in the schema" + ), + location: location!(), + })?; + + let query_parser = match field.data_type() { + DataType::List(_) => { Box::::default() as Box - } else { + } + DataType::Utf8 | DataType::LargeUtf8 => { + let uuid = index.uuid.to_string(); + let index_type = detect_scalar_index_type(self, &field.name, &uuid).await?; + // Inverted index can't be used for filtering + if matches!(index_type, ScalarIndexType::Inverted) { + continue; + } Box::::default() as Box - }; - (field.name.clone(), (field.data_type(), query_parser)) - }) - }) - .collect::>>()?; + } + _ => Box::::default() as Box, + }; + + indexed_fields.push((field.name.clone(), (field.data_type(), query_parser))); + } let index_info_map = HashMap::from_iter(indexed_fields); Ok(ScalarIndexInfo { indexed_columns: index_info_map, diff --git a/rust/lance/src/index/scalar.rs b/rust/lance/src/index/scalar.rs index 380c95f247..b7d5e6cbb1 100644 --- a/rust/lance/src/index/scalar.rs +++ b/rust/lance/src/index/scalar.rs @@ -95,6 +95,25 @@ pub(super) async fn build_scalar_index( source: format!("No column with name {}", column).into(), location: location!(), })?; + + // Check if LabelList index is being created on a non-List or non-LargeList type + if matches!(params.force_index_type, Some(ScalarIndexType::LabelList)) + && !matches!( + field.data_type(), + DataType::List(_) | DataType::LargeList(_) + ) + { + return Err(Error::InvalidInput { + source: format!( + "LabelList index can only be created on List or LargeList type columns. Column '{}' has type {:?}", + column, + field.data_type() + ) + .into(), + location: location!(), + }); + } + // In theory it should be possible to create a btree/bitmap index on a nested field but // performance would be poor and I'm not sure we want to allow that unless there is a need. 
if !matches!(params.force_index_type, Some(ScalarIndexType::LabelList)) @@ -151,9 +170,33 @@ pub async fn open_scalar_index( uuid: &str, ) -> Result> { let index_store = Arc::new(LanceIndexStore::from_dataset(dataset, uuid)); - let index_dir = dataset.indices_dir().child(uuid); - // This works at the moment, since we only have a few index types, may need to introduce better - // detection method in the future. + let index_type = detect_scalar_index_type(dataset, column, uuid).await?; + match index_type { + ScalarIndexType::Bitmap => { + let bitmap_index = BitmapIndex::load(index_store).await?; + Ok(bitmap_index as Arc) + } + ScalarIndexType::LabelList => { + let tag_index = LabelListIndex::load(index_store).await?; + Ok(tag_index as Arc) + } + ScalarIndexType::Inverted => { + let inverted_index = InvertedIndex::load(index_store).await?; + Ok(inverted_index as Arc) + } + ScalarIndexType::BTree => { + let btree_index = BTreeIndex::load(index_store).await?; + Ok(btree_index as Arc) + } + } +} + +pub async fn detect_scalar_index_type( + dataset: &Dataset, + column: &str, + index_uuid: &str, +) -> Result { + let index_dir = dataset.indices_dir().child(index_uuid); let col = dataset.schema().field(column).ok_or(Error::Internal { message: format!( "Index refers to column {} which does not exist in dataset schema", @@ -161,19 +204,18 @@ pub async fn open_scalar_index( ), location: location!(), })?; + let bitmap_page_lookup = index_dir.child(BITMAP_LOOKUP_NAME); let inverted_list_lookup = index_dir.child(INVERT_LIST_FILE); - if let DataType::List(_) = col.data_type() { - let tag_index = LabelListIndex::load(index_store).await?; - Ok(tag_index as Arc) + let index_type = if let DataType::List(_) = col.data_type() { + ScalarIndexType::LabelList } else if dataset.object_store.exists(&bitmap_page_lookup).await? { - let bitmap_index = BitmapIndex::load(index_store).await?; - Ok(bitmap_index as Arc) + ScalarIndexType::Bitmap } else if dataset.object_store.exists(&inverted_list_lookup).await? { - let inverted_index = InvertedIndex::load(index_store).await?; - Ok(inverted_index as Arc) + ScalarIndexType::Inverted } else { - let btree_index = BTreeIndex::load(index_store).await?; - Ok(btree_index as Arc) - } + ScalarIndexType::BTree + }; + + Ok(index_type) } diff --git a/rust/lance/src/session.rs b/rust/lance/src/session.rs index d57c018c58..6a978b0cb2 100644 --- a/rust/lance/src/session.rs +++ b/rust/lance/src/session.rs @@ -31,7 +31,28 @@ pub struct Session { impl std::fmt::Debug for Session { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Session()") + f.debug_struct("Session") + .field( + "index_cache", + &format!( + "IndexCache(items={}, size_bytes={})", + self.index_cache.get_size(), + self.index_cache.deep_size_of() + ), + ) + .field( + "file_metadata_cache", + &format!( + "FileMetadataCache(items={}, size_bytes={})", + self.file_metadata_cache.size(), + self.file_metadata_cache.deep_size_of() + ), + ) + .field( + "index_extensions", + &self.index_extensions.keys().collect::>(), + ) + .finish() } }
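Note (editor's sketch, not part of the diff): the new check in build_scalar_index above rejects LabelList indices on columns that are not List or LargeList instead of failing later with a less helpful error. A minimal Python illustration, assuming create_scalar_index accepts the index_type string "LABEL_LIST".

import tempfile

import lance
import pyarrow as pa

tbl = pa.table(
    {
        "tags": pa.array([["a", "b"], ["b"], ["c"]], type=pa.list_(pa.string())),
        "name": ["x", "y", "z"],
    }
)
ds = lance.write_dataset(tbl, tempfile.mkdtemp())

# Allowed: "tags" is a List<Utf8> column.
ds.create_scalar_index("tags", index_type="LABEL_LIST")

# Rejected by the new validation: "name" is Utf8, not List/LargeList.
try:
    ds.create_scalar_index("name", index_type="LABEL_LIST")
except Exception as err:
    print(err)  # LabelList index can only be created on List or LargeList type columns...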