Last active
July 5, 2023 17:40
-
-
Save gwenn/dbd03af8c4189cf7acc765348aa17893 to your computer and use it in GitHub Desktop.
Generate bindgen_bundled_version_ext.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[package] | |
name = "sqlite-ext" | |
version = "0.1.0" | |
edition = "2021" | |
[dependencies] | |
bindgen = { version = "0.66", default-features = false, features = ["runtime"] } | |
quote = { version = "1", default-features = false } | |
syn = { version = "2.0", features = ["full", "extra-traits", "visit-mut"] } | |
regex = { version = "1.8", default-features = false, features = ["std"] } | |
prettyplease = "0.2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::collections::HashMap; | |
use std::fs; | |
use bindgen::callbacks::{IntKind, ParseCallbacks}; | |
use bindgen::MacroTypeVariation; | |
#[derive(Debug)] | |
struct SqliteTypeChooser; | |
impl ParseCallbacks for SqliteTypeChooser { | |
fn int_macro(&self, name: &str, _value: i64) -> Option<IntKind> { | |
if name == "SQLITE_SERIALIZE_NOCOPY" | |
|| name.starts_with("SQLITE_DESERIALIZE_") | |
|| name.starts_with("SQLITE_PREPARE_") | |
{ | |
Some(IntKind::UInt) | |
} else { | |
None | |
} | |
} | |
} | |
// Load all `#define sqlite3_xyz sqlite3_api->abc` in sqlite3ext.h | |
// as a map `{abc => sqlite3_xyz}` | |
// See https://github.com/rust-lang/rust-bindgen/issues/2544 | |
fn parse_macros() -> HashMap<String, String> { | |
use regex::Regex; | |
use std::fs::File; | |
use std::io::{BufRead, BufReader}; | |
let re = Regex::new(r"^#define\s+(sqlite3_\w+)\s+sqlite3_api->(\w+)").unwrap(); | |
let f = File::open("sqlite3ext.h").expect("could not read sqlite3ext.h"); | |
let f = BufReader::new(f); | |
let mut mappings = HashMap::new(); | |
for line in f.lines() { | |
let line = line.expect("could not read line"); | |
if let Some(caps) = re.captures(&line) { | |
mappings.insert( | |
caps.get(2).unwrap().as_str().to_owned(), | |
caps.get(1).unwrap().as_str().to_owned(), | |
); | |
} | |
} | |
mappings | |
} | |
fn extract_method(ty: &syn::Type) -> Option<&syn::TypeBareFn> { | |
match ty { | |
syn::Type::Path(tp) => tp.path.segments.last(), | |
_ => None, | |
} | |
.map(|seg| match &seg.arguments { | |
syn::PathArguments::AngleBracketed(args) => args.args.first(), | |
_ => None, | |
})? | |
.map(|arg| match arg { | |
syn::GenericArgument::Type(t) => Some(t), | |
_ => None, | |
})? | |
.map(|ty| match ty { | |
syn::Type::BareFn(r) => Some(r), | |
_ => None, | |
})? | |
} | |
fn main() { | |
let header = "sqlite3ext.h"; | |
let mappings = parse_macros(); | |
let out_path = "bindgen_bundled_version_ext.rs"; | |
let mut output = Vec::new(); | |
let mut bindings = bindgen::builder() | |
.trust_clang_mangling(false) | |
.header(header.clone()) | |
.parse_callbacks(Box::new(SqliteTypeChooser)); | |
bindings = bindings | |
.opaque_type("(?i).*va_list.*") | |
.blocklist_type("va_list") | |
.blocklist_item("__.*"); | |
bindings | |
.default_macro_constant_type(MacroTypeVariation::Signed) | |
.disable_nested_struct_naming() | |
.ignore_functions() // TODO only for ext | |
.layout_tests(false) | |
.generate() | |
.unwrap_or_else(|_| panic!("could not run bindgen on header {}", header)) | |
.write(Box::new(&mut output)) | |
.expect("could not write output of bindgen"); | |
let mut output = String::from_utf8(output).expect("bindgen output was not UTF-8?!"); | |
let ast: syn::File = syn::parse_str(&output).expect("could not parse bindgen output"); | |
let sqlite3_api_routines: syn::ItemStruct = ast | |
.items | |
.into_iter() | |
.find_map(|i| { | |
if let syn::Item::Struct(s) = i { | |
if s.ident == "sqlite3_api_routines" { | |
Some(s) | |
} else { | |
None | |
} | |
} else { | |
None | |
} | |
}) | |
.expect("could not find sqlite3_api_routines"); | |
let sqlite3_api_routines_ident = sqlite3_api_routines.ident; | |
let p_api = quote::format_ident!("p_api"); | |
let mut stores = Vec::new(); | |
for field in sqlite3_api_routines.fields { | |
let ident = field.ident.expect("unamed field"); | |
let span = ident.span(); | |
let name = ident.to_string(); | |
if name == "vmprintf" || name == "xvsnprintf" || name == "str_vappendf" { | |
// FIXME va_list | |
continue; | |
} | |
let sqlite3_name = mappings | |
.get(&name) | |
.unwrap_or_else(|| panic!("no mapping for {name}")); | |
let ptr_name = syn::Ident::new(format!("__{}", sqlite3_name.to_uppercase()).as_ref(), span); | |
let sqlite3_fn_name = syn::Ident::new(sqlite3_name, span); | |
let method = | |
extract_method(&field.ty).unwrap_or_else(|| panic!("unexpected type for {name}")); | |
let arg_names: syn::punctuated::Punctuated<&syn::Ident, syn::token::Comma> = method | |
.inputs | |
.iter() | |
.map(|i| &i.name.as_ref().unwrap().0) | |
.collect(); | |
let args = &method.inputs; | |
// mprintf/sqlite3_mprintf, xsnprintf/sqlite3_snprintf, | |
// test_control/sqlite3_test_control, str_appendf/sqlite3_str_appendf: unused | |
// vtab_config/sqlite3_vtab_config: ok | |
let varargs = &method.variadic; | |
let ty = &method.output; | |
let tokens = if "db_config" == name { | |
quote::quote! { | |
static #ptr_name: ::std::sync::atomic::AtomicPtr<unsafe extern "C" fn(#args #varargs) #ty> = ::std::sync::atomic::AtomicPtr::new(::std::ptr::null_mut()); | |
pub unsafe fn #sqlite3_fn_name(#args arg3: ::std::os::raw::c_int, arg4: *mut ::std::os::raw::c_int) #ty { | |
let fun = #ptr_name.load(::std::sync::atomic::Ordering::Acquire); | |
assert!(!fun.is_null(), "SQLite API not initialized"); | |
(*fun)(#arg_names, arg3, arg4) | |
} | |
} | |
} else if "log" == name { | |
quote::quote! { | |
static #ptr_name: ::std::sync::atomic::AtomicPtr<unsafe extern "C" fn(#args #varargs) #ty> = ::std::sync::atomic::AtomicPtr::new(::std::ptr::null_mut()); | |
pub unsafe fn #sqlite3_fn_name(#args arg3: *const ::std::os::raw::c_char) #ty { | |
let fun = #ptr_name.load(::std::sync::atomic::Ordering::Acquire); | |
assert!(!fun.is_null(), "SQLite API not initialized"); | |
(*fun)(#arg_names, arg3) | |
} | |
} | |
} else { | |
quote::quote! { | |
static #ptr_name: ::std::sync::atomic::AtomicPtr<unsafe extern "C" fn(#args #varargs) #ty> = ::std::sync::atomic::AtomicPtr::new(::std::ptr::null_mut()); | |
pub unsafe fn #sqlite3_fn_name(#args) #ty { | |
let fun = #ptr_name.load(::std::sync::atomic::Ordering::Acquire); | |
assert!(!fun.is_null(), "SQLite API not initialized"); | |
(*fun)(#arg_names) | |
} | |
} | |
}; | |
output.push_str(&prettyplease::unparse( | |
&syn::parse2(tokens).expect("could not parse quote output"), | |
)); | |
output.push_str("\n"); | |
stores.push(quote::quote! { | |
#ptr_name.store( | |
&mut (*#p_api).#ident.ok_or(InitError::NullFunctionPointer)?, | |
::std::sync::atomic::Ordering::Release, | |
); | |
}); | |
} | |
let tokens = quote::quote! { | |
#[derive(Clone, Copy, Debug, PartialEq, Eq)] | |
#[non_exhaustive] | |
pub enum InitError { | |
NullApiPointer, | |
VersionMismatch{compile_time: i32, runtime: i32}, | |
NullFunctionPointer, | |
} | |
impl ::std::fmt::Display for InitError { | |
fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { | |
match *self { | |
InitError::NullApiPointer => write!(f, "Invalid sqlite3_api_routines pointer"), | |
InitError::VersionMismatch{compile_time, runtime} => write!(f, "SQLite version mismatch: {runtime} < {compile_time}"), | |
InitError::NullFunctionPointer => write!(f, "Some sqlite3_api_routines fields are null"), | |
} | |
} | |
} | |
/// Like SQLITE_EXTENSION_INIT2 macro | |
pub unsafe fn rusqlite_extension_init2(#p_api: *mut #sqlite3_api_routines_ident) -> ::std::result::Result<(),InitError> { | |
if #p_api.is_null() { | |
return Err(InitError::NullApiPointer); | |
} | |
if let Some(fun) = (*#p_api).libversion_number { | |
let version = fun(); | |
if SQLITE_VERSION_NUMBER > version { | |
return Err(InitError::VersionMismatch{compile_time: SQLITE_VERSION_NUMBER, runtime: version}); | |
} | |
} else { | |
return Err(InitError::NullFunctionPointer); | |
} | |
#(#stores)* | |
Ok(()) | |
} | |
}; | |
output.push_str(&prettyplease::unparse( | |
&syn::parse2(tokens).expect("could not parse quote output"), | |
)); | |
output.push_str("\n"); | |
fs::write(out_path, output.as_bytes()) | |
.unwrap_or_else(|_| panic!("Could not write to {:?}", out_path)); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/Cargo.toml b/Cargo.toml | |
index e8a668c..df1caae 100644 | |
--- a/Cargo.toml | |
+++ b/Cargo.toml | |
@@ -50,6 +50,7 @@ bundled-sqlcipher = ["libsqlite3-sys/bundled-sqlcipher", "bundled"] | |
bundled-sqlcipher-vendored-openssl = ["libsqlite3-sys/bundled-sqlcipher-vendored-openssl", "bundled-sqlcipher"] | |
buildtime_bindgen = ["libsqlite3-sys/buildtime_bindgen"] | |
limits = [] | |
+loadable_extension = [] | |
hooks = [] | |
i128_blob = [] | |
sqlcipher = ["libsqlite3-sys/sqlcipher"] | |
@@ -109,6 +110,7 @@ modern-full = [ | |
] | |
bundled-full = ["modern-full", "bundled"] | |
+default = ["modern_sqlite", "loadable_extension"] | |
[dependencies] | |
time = { version = "0.3.0", features = ["formatting", "macros", "parsing"], optional = true } | |
diff --git a/src/inner_connection.rs b/src/inner_connection.rs | |
index a7487a8..6a5c7ba 100644 | |
--- a/src/inner_connection.rs | |
+++ b/src/inner_connection.rs | |
@@ -4,7 +4,7 @@ use std::os::raw::{c_char, c_int}; | |
use std::path::Path; | |
use std::ptr; | |
use std::str; | |
-use std::sync::atomic::{AtomicBool, Ordering}; | |
+use std::sync::atomic::AtomicBool; | |
use std::sync::{Arc, Mutex}; | |
use super::ffi; | |
@@ -390,7 +390,7 @@ impl Drop for InnerConnection { | |
} | |
} | |
-#[cfg(not(any(target_arch = "wasm32")))] | |
+#[cfg(not(any(target_arch = "wasm32", feature = "loadable_extension")))] | |
static SQLITE_INIT: std::sync::Once = std::sync::Once::new(); | |
pub static BYPASS_SQLITE_INIT: AtomicBool = AtomicBool::new(false); | |
@@ -440,7 +440,9 @@ fn ensure_safe_sqlite_threading_mode() -> Result<()> { | |
Ok(()) | |
} | |
} else { | |
+ #[cfg(not(feature = "loadable_extension"))] | |
SQLITE_INIT.call_once(|| { | |
+ use std::sync::atomic::Ordering; | |
if BYPASS_SQLITE_INIT.load(Ordering::Relaxed) { | |
return; | |
} | |
diff --git a/src/trace.rs b/src/trace.rs | |
index ce4c80b..7317a0c 100644 | |
--- a/src/trace.rs | |
+++ b/src/trace.rs | |
@@ -8,8 +8,7 @@ use std::ptr; | |
use std::time::Duration; | |
use super::ffi; | |
-use crate::error::error_from_sqlite_code; | |
-use crate::{Connection, Result}; | |
+use crate::Connection; | |
/// Set up the process-wide SQLite error logging callback. | |
/// | |
@@ -25,7 +24,8 @@ use crate::{Connection, Result}; | |
/// * It must be threadsafe if SQLite is used in a multithreaded way. | |
/// | |
/// cf [The Error And Warning Log](http://sqlite.org/errlog.html). | |
-pub unsafe fn config_log(callback: Option<fn(c_int, &str)>) -> Result<()> { | |
+#[cfg(not(feature = "loadable_extension"))] | |
+pub unsafe fn config_log(callback: Option<fn(c_int, &str)>) -> crate::Result<()> { | |
extern "C" fn log_callback(p_arg: *mut c_void, err: c_int, msg: *const c_char) { | |
let c_slice = unsafe { CStr::from_ptr(msg).to_bytes() }; | |
let callback: fn(c_int, &str) = unsafe { mem::transmute(p_arg) }; | |
@@ -48,7 +48,7 @@ pub unsafe fn config_log(callback: Option<fn(c_int, &str)>) -> Result<()> { | |
if rc == ffi::SQLITE_OK { | |
Ok(()) | |
} else { | |
- Err(error_from_sqlite_code(rc, None)) | |
+ Err(crate::error::error_from_sqlite_code(rc, None)) | |
} | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment