-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
620 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,4 +3,7 @@ Cargo.lock | |
.idea/ | ||
venv/ | ||
target/ | ||
rust-toolchain.toml | ||
rust-toolchain.toml | ||
*.so | ||
**/*.pyc | ||
__pycache__/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
|
||
SHELL=/bin/bash | ||
|
||
venv: ## Set up virtual environment | ||
python3 -m venv venv | ||
venv/bin/pip install -r requirements.txt | ||
|
||
install: venv | ||
unset CONDA_PREFIX && \ | ||
source venv/bin/activate && maturin develop -m io_plugin/Cargo.toml | ||
|
||
install-release: venv | ||
unset CONDA_PREFIX && \ | ||
source venv/bin/activate && maturin develop --release -m io_plugin/Cargo.toml | ||
|
||
clean: | ||
-@rm -r venv | ||
-@cd extend_polars && cargo clean | ||
|
||
|
||
run: install | ||
source venv/bin/activate && python run.py | ||
|
||
run-release: install-release | ||
source venv/bin/activate && python run.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
[package] | ||
name = "io_plugin" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
[lib] | ||
name = "io_plugin" | ||
crate-type = ["cdylib"] | ||
|
||
[dependencies] | ||
polars = { workspace = true, features = ["fmt", "dtype-date", "timezones", "lazy"], default-features = false } | ||
pyo3 = { version = "0.22.2", features = ["abi3-py38"] } | ||
pyo3-polars = { version = "*", path = "../../../pyo3-polars", features = ["derive", "lazy"] } | ||
rand = { version = "0.8.5", features = [] } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from .io_plugin import new_bernoulli, new_uniform, RandomSource | ||
from typing import Any, Iterator | ||
from polars.io.plugins import register_io_source | ||
import polars as pl | ||
|
||
|
||
def scan_random(samplers: list[Any], size: int = 1000) -> pl.LazyFrame: | ||
def source_generator( | ||
with_columns: list[str] | None, | ||
predicate: pl.Expr | None, | ||
n_rows: int | None, | ||
batch_size: int | None, | ||
) -> Iterator[pl.DataFrame]: | ||
""" | ||
Generator function that creates the source. | ||
This function will be registered as IO source. | ||
""" | ||
|
||
new_size = size | ||
if n_rows is not None and n_rows < size: | ||
new_size = n_rows | ||
|
||
src = RandomSource(samplers, batch_size, new_size) | ||
if with_columns is not None: | ||
src.set_with_columns(with_columns) | ||
|
||
# Set the predicate. | ||
predicate_set = True | ||
if predicate is not None: | ||
try: | ||
src.try_set_predicate(predicate) | ||
except pl.exceptions.ComputeError: | ||
predicate_set = False | ||
|
||
while (out := src.next()) is not None: | ||
# If the source could not apply the predicate | ||
# (because it wasn't able to deserialize it), we do it here. | ||
if not predicate_set and predicate is not None: | ||
out = out.filter(predicate) | ||
|
||
yield out | ||
|
||
# create src again to compute the schema | ||
src = RandomSource(samplers, 0, 0) | ||
return register_io_source(callable=source_generator, schema=src.schema()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
mod samplers; | ||
|
||
use crate::samplers::PySampler; | ||
use polars::prelude::*; | ||
use pyo3::prelude::*; | ||
use pyo3_polars::error::PyPolarsErr; | ||
use pyo3_polars::{PyDataFrame, PyExpr, PySchema}; | ||
|
||
#[pyclass] | ||
pub struct RandomSource { | ||
columns: Vec<PySampler>, | ||
size_hint: usize, | ||
n_rows: usize, | ||
predicate: Option<Expr>, | ||
with_columns: Option<Vec<usize>>, | ||
} | ||
|
||
#[pymethods] | ||
impl RandomSource { | ||
#[new] | ||
#[pyo3(signature = (columns, size_hint, n_rows))] | ||
fn new_source( | ||
columns: Vec<PySampler>, | ||
size_hint: Option<usize>, | ||
n_rows: Option<usize>, | ||
) -> Self { | ||
let n_rows = n_rows.unwrap_or(usize::MAX); | ||
let size_hint = size_hint.unwrap_or(10_000); | ||
|
||
Self { | ||
columns, | ||
size_hint, | ||
n_rows, | ||
predicate: None, | ||
with_columns: None, | ||
} | ||
} | ||
|
||
fn schema(&self) -> PySchema { | ||
let schema = self | ||
.columns | ||
.iter() | ||
.map(|s| { | ||
let s = s.0.lock().unwrap(); | ||
Field::new(s.name(), s.dtype()) | ||
}) | ||
.collect::<Schema>(); | ||
PySchema(Arc::new(schema)) | ||
} | ||
|
||
fn try_set_predicate(&mut self, predicate: PyExpr) { | ||
self.predicate = Some(predicate.0); | ||
} | ||
|
||
fn set_with_columns(&mut self, columns: Vec<String>) { | ||
let schema = self.schema().0; | ||
|
||
let indexes = columns | ||
.iter() | ||
.map(|name| { | ||
schema | ||
.index_of(name.as_ref()) | ||
.expect("schema should be correct") | ||
}) | ||
.collect(); | ||
|
||
self.with_columns = Some(indexes) | ||
} | ||
|
||
fn next(&mut self) -> PyResult<Option<PyDataFrame>> { | ||
if self.n_rows > 0 { | ||
// Apply projection pushdown. | ||
// This prevents unneeded sampling. | ||
let s_iter = if let Some(idx) = &self.with_columns { | ||
Box::new(idx.iter().copied().map(|i| &self.columns[i])) | ||
as Box<dyn Iterator<Item = _>> | ||
} else { | ||
Box::new(self.columns.iter()) | ||
}; | ||
|
||
let columns = s_iter | ||
.map(|s| { | ||
let mut s = s.0.lock().unwrap(); | ||
|
||
// Apply slice pushdown. | ||
// This prevents unneeded sampling. | ||
s.next_n(std::cmp::min(self.size_hint, self.n_rows)) | ||
}) | ||
.collect::<Vec<_>>(); | ||
|
||
let mut df = DataFrame::new(columns).map_err(PyPolarsErr::from)?; | ||
self.n_rows = self.n_rows.saturating_sub(self.size_hint); | ||
|
||
// Apply predicate pushdown. | ||
// This is done after the fact, but there could be sources where this could be applied | ||
// lower. | ||
if let Some(predicate) = &self.predicate { | ||
df = df | ||
.lazy() | ||
.filter(predicate.clone()) | ||
._with_eager(true) | ||
.collect() | ||
.map_err(PyPolarsErr::from)?; | ||
} | ||
|
||
Ok(Some(PyDataFrame(df))) | ||
} else { | ||
Ok(None) | ||
} | ||
} | ||
} | ||
|
||
#[pymodule] | ||
fn io_plugin(m: &Bound<PyModule>) -> PyResult<()> { | ||
m.add_class::<RandomSource>().unwrap(); | ||
m.add_class::<PySampler>().unwrap(); | ||
m.add_wrapped(wrap_pyfunction!(samplers::new_bernoulli)) | ||
.unwrap(); | ||
m.add_wrapped(wrap_pyfunction!(samplers::new_uniform)) | ||
.unwrap(); | ||
|
||
Ok(()) | ||
} |
Oops, something went wrong.