diff --git a/Cargo.lock b/Cargo.lock index 883067b..455c621 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -428,7 +428,7 @@ checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" [[package]] name = "sqlite-loadable" -version = "0.0.5" +version = "0.0.6" dependencies = [ "bitflags 1.3.2", "libsqlite3-sys", diff --git a/Cargo.toml b/Cargo.toml index cb60c9a..e8c1a61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,10 @@ crate-type = ["cdylib"] name = "characters" crate-type = ["cdylib"] +[[example]] +name = "load_permanent" +crate-type = ["cdylib"] + [[example]] name = "in" crate-type = ["cdylib"] diff --git a/examples/load_permanent.rs b/examples/load_permanent.rs new file mode 100644 index 0000000..59dc6c3 --- /dev/null +++ b/examples/load_permanent.rs @@ -0,0 +1,28 @@ +//! cargo build --example load_permanent +//! sqlite3 :memory: '.read examples/test.sql' + +use sqlite_loadable::prelude::*; +use sqlite_loadable::{api, define_scalar_function, Result}; + +// This function will be registered as a scalar function named "hello", and will be called on +// every invocation. It's goal is to return a string of "hello, NAME!" where NAME is the +// text value of the 1st argument. +pub fn hello(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> { + let name = api::value_text(values.get(0).expect("1st argument as name"))?; + + api::result_text(context, format!("hello, {}!", name))?; + Ok(()) +} + +// Exposes a extern C function named "sqlite3_hello_init" in the compiled dynamic library, +// the "entrypoint" that SQLite will use to load the extension. +// Notice the naming sequence - "sqlite3_" followed by "hello" then "_init". Since the +// compiled file is named "libhello.dylib" (or .so/.dll depending on your operating system), +// SQLite by default will look for an entrypoint called "sqlite3_hello_init". +// See "Loading an Extension" for more details +#[sqlite_entrypoint_permanent] +pub fn sqlite3_hello_init(db: *mut sqlite3) -> Result<()> { + let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC; + define_scalar_function(db, "hello", 1, hello, flags)?; + Ok(()) +} diff --git a/sqlite-loadable-macros/src/lib.rs b/sqlite-loadable-macros/src/lib.rs index b09bbd2..b2c02a4 100644 --- a/sqlite-loadable-macros/src/lib.rs +++ b/sqlite-loadable-macros/src/lib.rs @@ -5,7 +5,7 @@ use syn::{parse_macro_input, spanned::Spanned, Item}; use proc_macro::TokenStream; use quote::quote_spanned; -/// Wraps an entrypoint function to expose an unsafe extern "C" function of the same name. +/// Wraps an entrypoint function to expose an unsafe extern "C" function of the same name. #[proc_macro_attribute] pub fn sqlite_entrypoint(_attr: TokenStream, item: TokenStream) -> TokenStream { let ast = parse_macro_input!(item as syn::Item); @@ -26,7 +26,7 @@ pub fn sqlite_entrypoint(_attr: TokenStream, item: TokenStream) -> TokenStream { /// # Safety /// - /// Should only be called by underlying SQLite C APIs, + /// Should only be called by underlying SQLite C APIs, /// like sqlite3_auto_extension and sqlite3_cancel_auto_extension. #[no_mangle] pub unsafe extern "C" fn #c_entrypoint( @@ -44,3 +44,43 @@ pub fn sqlite_entrypoint(_attr: TokenStream, item: TokenStream) -> TokenStream { _ => panic!("Only function items are allowed on sqlite_entrypoint"), } } + +/// Wraps an entrypoint function to expose an unsafe extern "C" function of the same name. +#[proc_macro_attribute] +pub fn sqlite_entrypoint_permanent(_attr: TokenStream, item: TokenStream) -> TokenStream { + let ast = parse_macro_input!(item as syn::Item); + match ast { + Item::Fn(mut func) => { + let c_entrypoint = func.sig.ident.clone(); + + let original_funcname = func.sig.ident.to_string(); + func.sig.ident = Ident::new( + format!("_{}", original_funcname).as_str(), + func.sig.ident.span(), + ); + + let prefixed_original_function = func.sig.ident.clone(); + + quote_spanned! {func.span()=> + #func + + /// # Safety + /// + /// Should only be called by underlying SQLite C APIs, + /// like sqlite3_auto_extension and sqlite3_cancel_auto_extension. + #[no_mangle] + pub unsafe extern "C" fn #c_entrypoint( + db: *mut sqlite3, + pz_err_msg: *mut *mut c_char, + p_api: *mut sqlite3_api_routines, + ) -> c_uint { + register_entrypoint_load_permanently(db, pz_err_msg, p_api, #prefixed_original_function) + } + + + } + .into() + } + _ => panic!("Only function items are allowed on sqlite_entrypoint"), + } +} diff --git a/src/entrypoints.rs b/src/entrypoints.rs index b662350..d34c22c 100644 --- a/src/entrypoints.rs +++ b/src/entrypoints.rs @@ -26,3 +26,24 @@ where Err(err) => err.code_extended(), } } + +/// Low-level wrapper around an entrypoint to a SQLite extension that loads permanently. You +/// shouldn't have to use this directly - the sqlite_entrypoint_permanent macro will do this +/// for you. +pub fn register_entrypoint_load_permanently( + db: *mut sqlite3, + _pz_err_msg: *mut *mut c_char, + p_api: *mut sqlite3_api_routines, + callback: F, +) -> c_uint +where + F: Fn(*mut sqlite3) -> Result<()>, +{ + unsafe { + faux_sqlite_extension_init2(p_api); + } + match callback(db) { + Ok(()) => 256, // https://www.sqlite.org/rescode.html#ok_load_permanently + Err(err) => err.code_extended(), + } +} diff --git a/src/prelude.rs b/src/prelude.rs index be63df0..f1e5e6f 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -3,11 +3,14 @@ #[doc(inline)] pub use crate::entrypoints::register_entrypoint; #[doc(inline)] +pub use crate::entrypoints::register_entrypoint_load_permanently; +#[doc(inline)] pub use sqlite3ext_sys::{ sqlite3, sqlite3_api_routines, sqlite3_context, sqlite3_value, sqlite3_vtab, sqlite3_vtab_cursor, }; pub use sqlite_loadable_macros::sqlite_entrypoint; +pub use sqlite_loadable_macros::sqlite_entrypoint_permanent; pub use std::os::raw::{c_char, c_uint};