Skip to content

Instantly share code, notes, and snippets.

@gwenn
Last active July 5, 2023 17:40
Show Gist options
  • Star 0 (you must be signed in to star a gist)
  • Fork 0 (you must be signed in to fork a gist)
  • Save gwenn/dbd03af8c4189cf7acc765348aa17893 to your computer and use it in GitHub Desktop.
Generate bindgen_bundled_version_ext.rs
# Manifest for the one-shot generator that writes `bindgen_bundled_version_ext.rs`
# from `sqlite3ext.h` (see src/main.rs).
[package]
name = "sqlite-ext"
version = "0.1.0"
edition = "2021"
[dependencies]
# Runs bindgen on sqlite3ext.h; default features are disabled and only `runtime` is enabled.
bindgen = { version = "0.66", default-features = false, features = ["runtime"] }
# Builds the token streams for the generated wrappers (`quote!`, `format_ident!`).
quote = { version = "1", default-features = false }
# Parses the bindgen output (`syn::File`, `syn::ItemStruct`, `syn::TypeBareFn`).
# NOTE(review): `extra-traits` / `visit-mut` do not appear to be used by the visible code — confirm.
syn = { version = "2.0", features = ["full", "extra-traits", "visit-mut"] }
# Extracts `#define sqlite3_xyz sqlite3_api->abc` mappings from sqlite3ext.h.
regex = { version = "1.8", default-features = false, features = ["std"] }
# Formats each generated item before it is appended to the output file.
prettyplease = "0.2"
use std::collections::HashMap;
use std::fs;
use bindgen::callbacks::{IntKind, ParseCallbacks};
use bindgen::MacroTypeVariation;
/// bindgen parse callback choosing the Rust integer type of selected SQLite
/// `#define` constants.
///
/// `SQLITE_SERIALIZE_NOCOPY` and every `SQLITE_DESERIALIZE_*` /
/// `SQLITE_PREPARE_*` macro are emitted as `IntKind::UInt`; any other macro
/// returns `None`, leaving the choice to bindgen's configured default.
#[derive(Debug)]
struct SqliteTypeChooser;
impl ParseCallbacks for SqliteTypeChooser {
    fn int_macro(&self, name: &str, _value: i64) -> Option<IntKind> {
        // Families of flag macros that SQLite's C API takes as unsigned ints.
        const UNSIGNED_PREFIXES: [&str; 2] = ["SQLITE_DESERIALIZE_", "SQLITE_PREPARE_"];
        let is_unsigned = name == "SQLITE_SERIALIZE_NOCOPY"
            || UNSIGNED_PREFIXES
                .iter()
                .any(|&prefix| name.starts_with(prefix));
        if is_unsigned {
            Some(IntKind::UInt)
        } else {
            None
        }
    }
}
/// Reads `sqlite3ext.h` from the current directory and collects every
/// `#define sqlite3_xyz sqlite3_api->abc` line into a map `{abc => sqlite3_xyz}`.
///
/// If the same routine name `abc` appears on several lines, the last one wins.
/// See https://github.com/rust-lang/rust-bindgen/issues/2544
///
/// # Panics
///
/// Panics if the header cannot be opened or a line cannot be read.
fn parse_macros() -> HashMap<String, String> {
    use regex::Regex;
    use std::fs::File;
    use std::io::{BufRead, BufReader};
    // Group 1: public C name (`sqlite3_xyz`); group 2: routine field (`abc`).
    let define_re = Regex::new(r"^#define\s+(sqlite3_\w+)\s+sqlite3_api->(\w+)").unwrap();
    let reader = BufReader::new(File::open("sqlite3ext.h").expect("could not read sqlite3ext.h"));
    reader
        .lines()
        .map(|line| line.expect("could not read line"))
        .filter_map(|line| {
            define_re
                .captures(&line)
                .map(|caps| (caps[2].to_owned(), caps[1].to_owned()))
        })
        .collect()
}
/// Extracts the bare function type wrapped by a bindgen field type such as
/// `Option<unsafe extern "C" fn(..) -> ..>`.
///
/// Looks at the last path segment of `ty`, takes its first angle-bracketed
/// generic argument, and returns it if it is a bare `fn` type. Returns `None`
/// for any other shape.
fn extract_method(ty: &syn::Type) -> Option<&syn::TypeBareFn> {
    let last_segment = match ty {
        syn::Type::Path(type_path) => type_path.path.segments.last()?,
        _ => return None,
    };
    let first_generic = match &last_segment.arguments {
        syn::PathArguments::AngleBracketed(generics) => generics.args.first()?,
        _ => return None,
    };
    if let syn::GenericArgument::Type(syn::Type::BareFn(bare_fn)) = first_generic {
        Some(bare_fn)
    } else {
        None
    }
}
/// Generates `bindgen_bundled_version_ext.rs`, the bindings used by a SQLite
/// loadable extension.
///
/// 1. Runs bindgen on `sqlite3ext.h`. Functions are ignored; types and
///    constants are kept.
/// 2. Finds the `sqlite3_api_routines` struct in the bindgen output.
/// 3. For each function-pointer field `abc`, mapped to its public name
///    `sqlite3_xyz` by `parse_macros`, appends a `static __SQLITE3_XYZ` slot
///    and a `pub unsafe fn sqlite3_xyz(..)` wrapper that calls through it.
/// 4. Appends `InitError` and `rusqlite_extension_init2`, which fill every
///    slot from the `sqlite3_api_routines` pointer handed to the extension.
///
/// # Panics
///
/// Panics on any I/O, bindgen or parse failure, because this is a one-shot
/// code generator.
fn main() {
    let header = "sqlite3ext.h";
    let mappings = parse_macros();
    let out_path = "bindgen_bundled_version_ext.rs";

    let mut output = Vec::new();
    bindgen::builder()
        .trust_clang_mangling(false)
        .header(header)
        .parse_callbacks(Box::new(SqliteTypeChooser))
        .opaque_type("(?i).*va_list.*")
        .blocklist_type("va_list")
        .blocklist_item("__.*")
        .default_macro_constant_type(MacroTypeVariation::Signed)
        .disable_nested_struct_naming()
        .ignore_functions() // TODO only for ext
        .layout_tests(false)
        .generate()
        .unwrap_or_else(|e| panic!("could not run bindgen on header {header}: {e:?}"))
        .write(Box::new(&mut output))
        .expect("could not write output of bindgen");
    let mut output = String::from_utf8(output).expect("bindgen output was not UTF-8?!");

    let ast: syn::File = syn::parse_str(&output).expect("could not parse bindgen output");
    let sqlite3_api_routines: syn::ItemStruct = ast
        .items
        .into_iter()
        .find_map(|item| match item {
            syn::Item::Struct(s) if s.ident == "sqlite3_api_routines" => Some(s),
            _ => None,
        })
        .expect("could not find sqlite3_api_routines");
    let sqlite3_api_routines_ident = sqlite3_api_routines.ident;
    let p_api = quote::format_ident!("p_api");
    let mut stores = Vec::new();
    for field in sqlite3_api_routines.fields {
        let ident = field.ident.expect("unnamed field");
        let span = ident.span();
        let name = ident.to_string();
        if name == "vmprintf" || name == "xvsnprintf" || name == "str_vappendf" {
            // FIXME va_list
            continue;
        }
        let sqlite3_name = mappings
            .get(&name)
            .unwrap_or_else(|| panic!("no mapping for {name}"));
        let ptr_name = syn::Ident::new(format!("__{}", sqlite3_name.to_uppercase()).as_ref(), span);
        let sqlite3_fn_name = syn::Ident::new(sqlite3_name, span);
        let method =
            extract_method(&field.ty).unwrap_or_else(|| panic!("unexpected type for {name}"));
        let arg_names: syn::punctuated::Punctuated<&syn::Ident, syn::token::Comma> = method
            .inputs
            .iter()
            .map(|i| {
                &i.name
                    .as_ref()
                    .unwrap_or_else(|| panic!("unnamed argument in {name}"))
                    .0
            })
            .collect();
        let args = &method.inputs;
        let varargs = &method.variadic;
        let ty = &method.output;
        // Variadic C functions cannot be defined in stable Rust, so the variadic
        // routines that are needed get fixed trailing parameters instead:
        // - db_config/sqlite3_db_config: `(int, int*)`;
        // - log/sqlite3_log: one message string.
        // mprintf/sqlite3_mprintf, xsnprintf/sqlite3_snprintf,
        // test_control/sqlite3_test_control, str_appendf/sqlite3_str_appendf: unused.
        // vtab_config/sqlite3_vtab_config: ok without extra arguments.
        // The bindgen `inputs` of a variadic fn end with a comma, so the extra
        // parameters can be appended directly after `#args`.
        let (extra_params, extra_args) = match name.as_str() {
            "db_config" => (
                quote::quote! { arg3: ::std::os::raw::c_int, arg4: *mut ::std::os::raw::c_int },
                quote::quote! { , arg3, arg4 },
            ),
            "log" => (
                quote::quote! { arg3: *const ::std::os::raw::c_char },
                quote::quote! { , arg3 },
            ),
            _ => (quote::quote! {}, quote::quote! {}),
        };
        // The slot holds the function pointer itself, type-erased to `*mut ()`.
        // SAFETY (generated code): the slot is only ever written by
        // `rusqlite_extension_init2` with a pointer to a function of exactly
        // this signature, so transmuting it back is sound once it is non-null.
        let tokens = quote::quote! {
            static #ptr_name: ::std::sync::atomic::AtomicPtr<()> = ::std::sync::atomic::AtomicPtr::new(::std::ptr::null_mut());
            pub unsafe fn #sqlite3_fn_name(#args #extra_params) #ty {
                let ptr = #ptr_name.load(::std::sync::atomic::Ordering::Acquire);
                assert!(!ptr.is_null(), "SQLite API not initialized");
                let fun: unsafe extern "C" fn(#args #varargs) #ty = ::std::mem::transmute(ptr);
                (fun)(#arg_names #extra_args)
            }
        };
        output.push_str(&prettyplease::unparse(
            &syn::parse2(tokens).expect("could not parse quote output"),
        ));
        output.push('\n');
        // Store the function pointer value, not the address of a temporary
        // copy of it: a temporary would be dropped at the end of the statement
        // and leave the slot dangling.
        stores.push(quote::quote! {
            {
                let fun = (*#p_api).#ident.ok_or(InitError::NullFunctionPointer)?;
                #ptr_name.store(fun as *mut (), ::std::sync::atomic::Ordering::Release);
            }
        });
    }
    let tokens = quote::quote! {
        #[derive(Clone, Copy, Debug, PartialEq, Eq)]
        #[non_exhaustive]
        pub enum InitError {
            NullApiPointer,
            VersionMismatch{compile_time: i32, runtime: i32},
            NullFunctionPointer,
        }
        impl ::std::fmt::Display for InitError {
            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
                match *self {
                    InitError::NullApiPointer => write!(f, "Invalid sqlite3_api_routines pointer"),
                    InitError::VersionMismatch{compile_time, runtime} => write!(f, "SQLite version mismatch: {runtime} < {compile_time}"),
                    InitError::NullFunctionPointer => write!(f, "Some sqlite3_api_routines fields are null"),
                }
            }
        }
        impl ::std::error::Error for InitError {}
        /// Like SQLITE_EXTENSION_INIT2 macro
        pub unsafe fn rusqlite_extension_init2(#p_api: *mut #sqlite3_api_routines_ident) -> ::std::result::Result<(),InitError> {
            if #p_api.is_null() {
                return Err(InitError::NullApiPointer);
            }
            if let Some(fun) = (*#p_api).libversion_number {
                let version = fun();
                if SQLITE_VERSION_NUMBER > version {
                    return Err(InitError::VersionMismatch{compile_time: SQLITE_VERSION_NUMBER, runtime: version});
                }
            } else {
                return Err(InitError::NullFunctionPointer);
            }
            #(#stores)*
            Ok(())
        }
    };
    output.push_str(&prettyplease::unparse(
        &syn::parse2(tokens).expect("could not parse quote output"),
    ));
    output.push('\n');
    fs::write(out_path, output.as_bytes())
        .unwrap_or_else(|e| panic!("Could not write to {out_path:?}: {e}"));
}
diff --git a/Cargo.toml b/Cargo.toml
index e8a668c..df1caae 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -50,6 +50,7 @@ bundled-sqlcipher = ["libsqlite3-sys/bundled-sqlcipher", "bundled"]
bundled-sqlcipher-vendored-openssl = ["libsqlite3-sys/bundled-sqlcipher-vendored-openssl", "bundled-sqlcipher"]
buildtime_bindgen = ["libsqlite3-sys/buildtime_bindgen"]
limits = []
+loadable_extension = []
hooks = []
i128_blob = []
sqlcipher = ["libsqlite3-sys/sqlcipher"]
@@ -109,6 +110,7 @@ modern-full = [
]
bundled-full = ["modern-full", "bundled"]
+default = ["modern_sqlite", "loadable_extension"]
[dependencies]
time = { version = "0.3.0", features = ["formatting", "macros", "parsing"], optional = true }
diff --git a/src/inner_connection.rs b/src/inner_connection.rs
index a7487a8..6a5c7ba 100644
--- a/src/inner_connection.rs
+++ b/src/inner_connection.rs
@@ -4,7 +4,7 @@ use std::os::raw::{c_char, c_int};
use std::path::Path;
use std::ptr;
use std::str;
-use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::atomic::AtomicBool;
use std::sync::{Arc, Mutex};
use super::ffi;
@@ -390,7 +390,7 @@ impl Drop for InnerConnection {
}
}
-#[cfg(not(any(target_arch = "wasm32")))]
+#[cfg(not(any(target_arch = "wasm32", feature = "loadable_extension")))]
static SQLITE_INIT: std::sync::Once = std::sync::Once::new();
pub static BYPASS_SQLITE_INIT: AtomicBool = AtomicBool::new(false);
@@ -440,7 +440,9 @@ fn ensure_safe_sqlite_threading_mode() -> Result<()> {
Ok(())
}
} else {
+ #[cfg(not(feature = "loadable_extension"))]
SQLITE_INIT.call_once(|| {
+ use std::sync::atomic::Ordering;
if BYPASS_SQLITE_INIT.load(Ordering::Relaxed) {
return;
}
diff --git a/src/trace.rs b/src/trace.rs
index ce4c80b..7317a0c 100644
--- a/src/trace.rs
+++ b/src/trace.rs
@@ -8,8 +8,7 @@ use std::ptr;
use std::time::Duration;
use super::ffi;
-use crate::error::error_from_sqlite_code;
-use crate::{Connection, Result};
+use crate::Connection;
/// Set up the process-wide SQLite error logging callback.
///
@@ -25,7 +24,8 @@ use crate::{Connection, Result};
/// * It must be threadsafe if SQLite is used in a multithreaded way.
///
/// cf [The Error And Warning Log](http://sqlite.org/errlog.html).
-pub unsafe fn config_log(callback: Option<fn(c_int, &str)>) -> Result<()> {
+#[cfg(not(feature = "loadable_extension"))]
+pub unsafe fn config_log(callback: Option<fn(c_int, &str)>) -> crate::Result<()> {
extern "C" fn log_callback(p_arg: *mut c_void, err: c_int, msg: *const c_char) {
let c_slice = unsafe { CStr::from_ptr(msg).to_bytes() };
let callback: fn(c_int, &str) = unsafe { mem::transmute(p_arg) };
@@ -48,7 +48,7 @@ pub unsafe fn config_log(callback: Option<fn(c_int, &str)>) -> Result<()> {
if rc == ffi::SQLITE_OK {
Ok(())
} else {
- Err(error_from_sqlite_code(rc, None))
+ Err(crate::error::error_from_sqlite_code(rc, None))
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment