@sgrif
Last active November 24, 2015 17:57

I think that an ORM that is truly in the spirit of Rust should endeavor to statically check as much as possible at compile time. I am writing this with PostgreSQL (PG) specifically in mind, and I think coupling to PG for the first iteration is a good way to go about it. In my experience, adding MySQL and SQLite support is trivial once PG support works. Other back ends, such as SQL Server, have interesting constraints of their own that could be difficult to handle and would also be great to check at compile time.

The first thing I'd like to find a good API for is reads. This document focuses only on the simplest case (SELECT * FROM queries), but I believe the same structure can support an API in which where clauses are also statically checked at compile time.

In my mind, reading from the database will consist of two steps. The transition should be from (raw query results) -> (tuple of Rust "primitives") -> (user defined structure).
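To make the two steps concrete, here is a minimal sketch. It assumes, purely for illustration, that a raw row arrives as one text value per column; `row_to_tuple` and `tuple_to_user` are hypothetical names, not part of the proposed API.

```rust
#[derive(Debug, PartialEq)]
struct User {
    id: i32,
    name: String,
}

// Step 1: raw query results -> tuple of Rust "primitives".
// Assumption: the driver hands back one text value per column.
fn row_to_tuple(raw: &[&str]) -> Option<(i32, String)> {
    match raw {
        [id, name] => Some((id.parse::<i32>().ok()?, name.to_string())),
        _ => None, // unexpected number of columns
    }
}

// Step 2: tuple -> user defined structure.
fn tuple_to_user(row: (i32, String)) -> User {
    User { id: row.0, name: row.1 }
}
```

The rest of this document is about moving the failure cases of step 1 (wrong column count, wrong column type) from run time to compile time.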

Type safe queries

I believe we can structure the types in such a way that we reject any query at compile time where we do not know for certain that we can handle every native database type that the query will return. This is the general structure that I'm imagining, with the setup to handle id and string columns.

trait NativeSqlType {}

mod types {
    use super::NativeSqlType;

    pub struct Serial;
    impl NativeSqlType for Serial {}
    pub struct VarChar;
    impl NativeSqlType for VarChar {}
}

// Real impl would replace with a macro to define for all tuple sizes
impl<A: NativeSqlType, B: NativeSqlType> NativeSqlType for (A, B) {}

trait FromSql<Source: NativeSqlType> {
}

// Real impl would replace with a macro to define for all tuple sizes
impl<A, B, SA, SB> FromSql<(SA, SB)> for (A, B) where
    A: FromSql<SA>,
    B: FromSql<SB>,
    SA: NativeSqlType,
    SB: NativeSqlType,
{}

impl FromSql<types::Serial> for i32 {}
impl FromSql<types::VarChar> for String {}

trait QuerySource {
    type SqlType: NativeSqlType;
}


trait Table: QuerySource {
    fn name() -> &'static str;
}
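As a rough illustration of how these traits reject a mismatched query, the sketch below adds a hypothetical `assert_loadable` helper that only type checks when the Rust target can be built from the SQL source. For brevity the marker types live at the top level here instead of in a `types` module.

```rust
trait NativeSqlType {}

struct Serial;
impl NativeSqlType for Serial {}
struct VarChar;
impl NativeSqlType for VarChar {}

impl<A: NativeSqlType, B: NativeSqlType> NativeSqlType for (A, B) {}

trait FromSql<Source: NativeSqlType> {}

impl<A, B, SA, SB> FromSql<(SA, SB)> for (A, B)
where
    A: FromSql<SA>,
    B: FromSql<SB>,
    SA: NativeSqlType,
    SB: NativeSqlType,
{
}

impl FromSql<Serial> for i32 {}
impl FromSql<VarChar> for String {}

// Hypothetical helper: compiles only if `Target` can be loaded from `Source`.
fn assert_loadable<Target, Source>() -> bool
where
    Source: NativeSqlType,
    Target: FromSql<Source>,
{
    true
}

fn checks() -> bool {
    // Accepted: (Serial, VarChar) maps onto (i32, String).
    assert_loadable::<(i32, String), (Serial, VarChar)>()
    // Rejected by the compiler if uncommented, because there is no
    // `impl FromSql<VarChar> for i32`:
    // assert_loadable::<(i32, String), (VarChar, Serial)>()
}
```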

We can also likely implement a generic join method at this level, with a signature such as this (ignoring boxing; in practice there would be a concrete, known return type):

trait QuerySource {
  ...

  fn join<A: QuerySource>(other: A) ->
    QuerySource<SqlType=(Self::SqlType, A::SqlType)> {
    ...
  }
}
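One way the "concrete known type" could look is a `Join` struct that is itself a `QuerySource`. This is only a sketch: the `Join` struct, the `self` receiver on `join`, the `has_sql_type` helper, and the two-column `PostTable` are all assumptions made for the example.

```rust
trait NativeSqlType {}

struct Serial;
impl NativeSqlType for Serial {}
struct VarChar;
impl NativeSqlType for VarChar {}

impl<A: NativeSqlType, B: NativeSqlType> NativeSqlType for (A, B) {}

trait QuerySource {
    type SqlType: NativeSqlType;

    // Returning the concrete `Join` type avoids boxing a trait object.
    fn join<Other: QuerySource>(self, other: Other) -> Join<Self, Other>
    where
        Self: Sized,
    {
        Join { left: self, right: other }
    }
}

// Hypothetical concrete type representing `left JOIN right`.
struct Join<A, B> {
    left: A,
    right: B,
}

impl<A: QuerySource, B: QuerySource> QuerySource for Join<A, B> {
    // A joined row is the pair of both sides' rows.
    type SqlType = (A::SqlType, B::SqlType);
}

struct UserTable;
impl QuerySource for UserTable {
    type SqlType = (Serial, VarChar);
}

// Simplified to two columns (id, title) so the two-tuple impl suffices.
struct PostTable;
impl QuerySource for PostTable {
    type SqlType = (Serial, VarChar);
}

// Compiles only if `QS::SqlType` is exactly `ST`.
fn has_sql_type<QS, ST>(_source: &QS) -> bool
where
    QS: QuerySource<SqlType = ST>,
    ST: NativeSqlType,
{
    true
}
```

With this shape, `UserTable.join(PostTable)` has the SQL type `((Serial, VarChar), (Serial, VarChar))`, so the existing tuple impls of `FromSql` would apply to joined rows without extra work.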

It is unclear to me at this time how we actually want to represent joined records for a has-many association. I don't think Vec is the right type to use here, since we don't want to force heap allocation automatically, but it might be required. Perhaps just returning an Iterator is the right call. Either way, this can be figured out during implementation.

At this stage, I would like the interface that represents associations to live at the query level and not invade the model. This also makes it very clear what data is actually being accessed, and makes it much more difficult to accidentally introduce an N+1 bug. So the type of a user which has many posts would be (User, Vec<Post>).
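As a sketch of what producing `(User, Vec<Post>)` from a single joined query might involve, the hypothetical `group_has_many` function below folds flat `(User, Post)` rows into one entry per user. It assumes the rows arrive ordered by user id, and it uses `Vec` only because that is the simplest thing to show; an iterator-based version may well be the better choice.

```rust
#[derive(Debug, PartialEq)]
struct User {
    id: i32,
    name: String,
}

#[derive(Debug, PartialEq)]
struct Post {
    id: i32,
    user_id: i32,
    title: String,
}

// Assumption: `rows` is ordered by user id, so a user's posts are adjacent.
fn group_has_many(rows: Vec<(User, Post)>) -> Vec<(User, Vec<Post>)> {
    let mut grouped: Vec<(User, Vec<Post>)> = Vec::new();
    for (user, post) in rows {
        // Does this row belong to the user we are currently collecting?
        let same_user = grouped
            .last()
            .map_or(false, |(last, _)| last.id == user.id);
        if same_user {
            if let Some((_, posts)) = grouped.last_mut() {
                posts.push(post);
            }
        } else {
            grouped.push((user, vec![post]));
        }
    }
    grouped
}
```

Because everything comes from one query, there is no per-user follow-up query to forget about, which is the N+1 protection described above.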

User consumption of QuerySource

This can be written into a macro pretty trivially, and with a compiler plugin we could generate calls to that macro automatically at compile time.

table! {
    UserTable {
        users
        id -> types::Serial,
        name -> types::VarChar,
    }
}

That would expand into roughly:

struct UserTable;

impl QuerySource for UserTable {
    type SqlType = (types::Serial, types::VarChar);
}

impl Table for UserTable {
    fn name() -> &'static str {
        "users"
    }
}
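For reference, a `macro_rules!` definition that performs roughly this expansion could look like the sketch below. It is one possible shape, not a settled design; the marker types are declared at the top level for brevity, and only the two-tuple `NativeSqlType` impl is provided, so it handles two-column tables only.

```rust
trait NativeSqlType {}

struct Serial;
impl NativeSqlType for Serial {}
struct VarChar;
impl NativeSqlType for VarChar {}

impl<A: NativeSqlType, B: NativeSqlType> NativeSqlType for (A, B) {}

trait QuerySource {
    type SqlType: NativeSqlType;
}

trait Table: QuerySource {
    fn name() -> &'static str;
}

macro_rules! table {
    ($struct_name:ident {
        $table_name:ident
        $($column:ident -> $sql_type:ty,)+
    }) => {
        struct $struct_name;

        impl QuerySource for $struct_name {
            // The column SQL types, in declaration order, as a tuple.
            type SqlType = ($($sql_type,)+);
        }

        impl Table for $struct_name {
            fn name() -> &'static str {
                stringify!($table_name)
            }
        }
    };
}

table! {
    UserTable {
        users
        id -> Serial,
        name -> VarChar,
    }
}
```

The column names are accepted but not yet used in the expansion; they would matter once where clauses need to refer to individual columns.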

Mapping to models

Now we want to populate the following struct from a SQL query:

struct User {
    id: i32,
    name: String,
}

Again, using a compiler plugin, we can automatically derive this:

trait Queriable<A: QuerySource> {
    type Row: FromSql<A::SqlType>;

    fn build(row: Self::Row) -> Self;
}

impl<QS: QuerySource> Queriable<QS> for User where
    (i32, String): FromSql<QS::SqlType>,
{
    type Row = (i32, String);

    fn build(row: (i32, String)) -> Self {
        User {
            id: row.0,
            name: row.1,
        }
    }
}

Useful usage

This then allows us to write a generic function to actually query the values, statically verifying that we're building a query that can actually be used properly (note: this is an over-simplified case ignoring any query other than SELECT * FROM ...)

impl ConnectionAdapter {
  fn query<M, QS>(&self, query_source: &QS) -> Result<Cursor<M, QS>>
  where M: Queriable<QS>, QS: QuerySource {
      unimplemented!()
  }
}

struct Cursor<M, QS> where M: Queriable<QS>, QS: QuerySource {
    internal_cursor: Rows,
    _model_marker: PhantomData<M>,
    _query_marker: PhantomData<QS>,
}

impl<M, QS> Iterator for Cursor<M, QS> where M: Queriable<QS>, QS: QuerySource {
    type Item = M;

    fn next(&mut self) -> Option<Self::Item> {
        let maybe_row: Option<M::Row> = unimplemented!();
        maybe_row.map(|r| Self::Item::build(r))
    }
}

The assumption here is that QuerySource would likely implement some form of a to_sql() function, and that adding things like where clauses would create a new QuerySource, with the values involved still statically checked. There would also be an unsafe fn that does the same thing but takes a raw string of SQL.
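A minimal sketch of that idea follows, under several assumptions: `to_sql`, `filter`, `Filtered`, `Column`, `ToSqlLiteral`, and `users_id` are hypothetical names, only equality is supported, and the value is inlined into the SQL text where a real implementation would use bind parameters.

```rust
use std::marker::PhantomData;

trait NativeSqlType {}

struct Serial;
impl NativeSqlType for Serial {}
struct VarChar;
impl NativeSqlType for VarChar {}

impl<A: NativeSqlType, B: NativeSqlType> NativeSqlType for (A, B) {}

// Hypothetical: a Rust value that may be compared against SQL type `ST`.
trait ToSqlLiteral<ST: NativeSqlType> {
    fn to_sql_literal(&self) -> String;
}

impl ToSqlLiteral<Serial> for i32 {
    fn to_sql_literal(&self) -> String {
        self.to_string()
    }
}

// A column name tagged with its SQL type.
struct Column<ST> {
    name: &'static str,
    _sql_type: PhantomData<ST>,
}

trait QuerySource {
    type SqlType: NativeSqlType;

    fn to_sql(&self) -> String;

    // Only values whose type matches the column's SQL type are accepted.
    fn filter<ST, V>(self, column: Column<ST>, value: V) -> Filtered<Self>
    where
        Self: Sized,
        ST: NativeSqlType,
        V: ToSqlLiteral<ST>,
    {
        Filtered {
            source: self,
            predicate: format!("{} = {}", column.name, value.to_sql_literal()),
        }
    }
}

// Refining a source yields a new source with the same row type.
// Simplification: chaining two filters is not handled here.
struct Filtered<S> {
    source: S,
    predicate: String,
}

impl<S: QuerySource> QuerySource for Filtered<S> {
    type SqlType = S::SqlType;

    fn to_sql(&self) -> String {
        format!("{} WHERE {}", self.source.to_sql(), self.predicate)
    }
}

struct UserTable;

impl QuerySource for UserTable {
    type SqlType = (Serial, VarChar);

    fn to_sql(&self) -> String {
        "SELECT * FROM users".to_string()
    }
}

fn users_id() -> Column<Serial> {
    Column { name: "id", _sql_type: PhantomData }
}
```

Here `UserTable.filter(users_id(), 1)` type checks, while passing a string for the `id` column would not, since no `ToSqlLiteral<Serial>` impl exists for string types.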

Usage at this stage would now look like this:

for user in connection.query::<User, _>(&UserTable).unwrap() {
    println!("id: {}, name: {}", user.id, user.name);
}

If Rust were to add implicits at some point in the future, we would probably want to move the connection to an implicit argument.

We could likely add some sugar on top of this to make it prettier as the common patterns of duplication become apparent. For example, once specialization lands, we might be able to add a more specific impl for Queriable:

impl Queriable<UserTable> for User {
    fn all(connection: ConnectionAdapter) -> Result<Cursor<Self, UserTable>> {
        connection.query(&UserTable)
    }
}

fn main() {
    let connection: ConnectionAdapter = unimplemented!();
    for user in User::all(connection).unwrap() {
        ...
    }
}

Additionally, if we were to have a second table, for example posts, attempting to query from it for the wrong type would fail to compile:

table! {
    PostTable {
        posts
        id -> types::Serial,
        user_id -> types::Integer,
        title -> types::VarChar,
    }
}

fn main() {
    let connection: ConnectionAdapter = unimplemented!();
    connection.query::<User, _>(&PostTable).unwrap(); // This would fail to compile
}

We can also type check an arbitrary SQL query at compile time, if we're willing to have a database connection. Hypothetical code might look like:

// library code
struct SqlQuery<A: NativeSqlType> {
    query: &'static str,
    _marker: PhantomData<A>,
}

impl<A: NativeSqlType> QuerySource for SqlQuery<A> {
    type SqlType = A;
}

fn main() {
    let source = query!("SELECT * FROM users");
    let connection: ConnectionAdapter = unimplemented!();
    for user in connection.query::<User, _>(&source).unwrap() {
        do_stuff_with_user(user);
    }
}

That macro itself would end up expanding to:

let source;
unsafe {
    source = SqlQuery::<(types::Serial, types::VarChar)>::new("SELECT * FROM users");
}
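The expansion relies on a `SqlQuery::new` constructor that is not defined above. A plausible sketch, assuming it is an `unsafe fn` because nothing inside it can verify that the SQL text really returns rows of the claimed type:

```rust
use std::marker::PhantomData;

trait NativeSqlType {}

struct Serial;
impl NativeSqlType for Serial {}
struct VarChar;
impl NativeSqlType for VarChar {}

impl<A: NativeSqlType, B: NativeSqlType> NativeSqlType for (A, B) {}

trait QuerySource {
    type SqlType: NativeSqlType;
}

struct SqlQuery<A: NativeSqlType> {
    query: &'static str,
    _marker: PhantomData<A>,
}

impl<A: NativeSqlType> SqlQuery<A> {
    // Hypothetical constructor. Marked `unsafe` because the caller, not the
    // compiler, vouches that `query` returns rows of SQL type `A`.
    unsafe fn new(query: &'static str) -> Self {
        SqlQuery { query, _marker: PhantomData }
    }
}

impl<A: NativeSqlType> QuerySource for SqlQuery<A> {
    type SqlType = A;
}

// Stand-in for what `query!("SELECT * FROM users")` would expand to after
// the macro has asked the database for the result column types.
fn users_source() -> SqlQuery<(Serial, VarChar)> {
    unsafe { SqlQuery::new("SELECT * FROM users") }
}
```

The macro would be the only intended caller of `new`, which keeps the unchecked claim about the row type in one place.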

While I do think a DSL for where, having, and joins makes sense, we can probably avoid one for things like select clauses, which tend to be more static.

One case where I want to consider more carefully what the API should look like is:

let source = query!("SELECT users.*, COUNT(posts) FROM users INNER JOIN posts ON
                     users.id = posts.user_id GROUP BY users.id");

Presumably something like this could look like:

// I have no clue where the group by would live at this point
let source = select!("users.*, COUNT(posts)", UserTable.join(PostTable));

However, actually verifying that query at compile time requires some knowledge that UserTable and PostTable are compile time constants. We should also think about how we actually derive that join. Maybe it would make sense to generate a known type for the join between the tables.

Concerns

Two tables with the same column types will both compile successfully if we add a non-specific impl of Queriable. One option is to only generate specific sources; however, I don't intend for every QuerySource to be as specific as UserTable (for example, one produced by adding a where clause). However, we can likely separate refining a query from the original source, so that the only operations that actually change the type of the QuerySource are those that change the select clause or the from clause. In both cases, I believe you would end up changing the return value as well (e.g. joining would change it to a tuple).

Another solution to that would be to find a way to separate "users"."id" from "posts"."id" at the type level, but I feel like that would end up causing more problems than it would solve.
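The first concern can be reproduced in a few lines. In the sketch below, `TagTable` is a hypothetical second table whose columns happen to have the same SQL types as `users`, and `accepts` is a hypothetical helper that only type checks when the model can be queried from the source. The marker types are at the top level for brevity.

```rust
trait NativeSqlType {}

struct Serial;
impl NativeSqlType for Serial {}
struct VarChar;
impl NativeSqlType for VarChar {}

impl<A: NativeSqlType, B: NativeSqlType> NativeSqlType for (A, B) {}

trait FromSql<Source: NativeSqlType> {}

impl<A, B, SA, SB> FromSql<(SA, SB)> for (A, B)
where
    A: FromSql<SA>,
    B: FromSql<SB>,
    SA: NativeSqlType,
    SB: NativeSqlType,
{
}

impl FromSql<Serial> for i32 {}
impl FromSql<VarChar> for String {}

trait QuerySource {
    type SqlType: NativeSqlType;
}

struct UserTable;
impl QuerySource for UserTable {
    type SqlType = (Serial, VarChar);
}

// Hypothetical table with the same column types as `users`.
struct TagTable;
impl QuerySource for TagTable {
    type SqlType = (Serial, VarChar);
}

trait Queriable<QS: QuerySource> {
    type Row: FromSql<QS::SqlType>;

    fn build(row: Self::Row) -> Self;
}

struct User {
    id: i32,
    name: String,
}

// The non-specific impl: any source with a compatible row type qualifies.
impl<QS: QuerySource> Queriable<QS> for User
where
    (i32, String): FromSql<QS::SqlType>,
{
    type Row = (i32, String);

    fn build(row: (i32, String)) -> Self {
        User { id: row.0, name: row.1 }
    }
}

// Compiles only if `M` can be queried from `QS`.
fn accepts<M: Queriable<QS>, QS: QuerySource>(_source: &QS) -> bool {
    true
}
```

Both `accepts::<User, _>(&UserTable)` and `accepts::<User, _>(&TagTable)` compile, which is exactly the hole described above: the check is structural, so it cannot tell users from tags.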

Ultimately while this shows all of the internal setup, I think we can push almost all of it into the compiler via macros (possibly automatically invoked), and auto-derivation of traits.

While this is the bare bones for the simplest possible case, I think this is a good place to start, and would give us a good baseline to build off of. The code below is a skeleton of the type structure of everything mentioned in this document, and compiles successfully.

use std::marker::PhantomData;

trait NativeSqlType {}

trait FromSql<Source: NativeSqlType> {
}

mod types {
    use super::NativeSqlType;

    pub struct Serial;
    impl NativeSqlType for Serial {}
    pub struct VarChar;
    impl NativeSqlType for VarChar {}
}

macro_rules! tuple_impls {
    ($(
        $Tuple:ident {
            $(($idx:tt) -> $T:ident, $ST:ident,)+
        }
    )+) => {
        $(
            impl<$($T:NativeSqlType),+> NativeSqlType for ($($T,)+) {}
            impl<$($T),+,$($ST),+> FromSql<($($ST),+)> for ($($T),+) where
                $($T: FromSql<$ST>),+,
                $($ST: NativeSqlType),+
            {}
        )+
    }
}

tuple_impls! {
    T2 {
        (0) -> A, SA,
        (1) -> B, SB,
    }
    T3 {
        (0) -> A, SA,
        (1) -> B, SB,
        (2) -> C, SC,
    }
    T4 {
        (0) -> A, SA,
        (1) -> B, SB,
        (2) -> C, SC,
        (3) -> D, SD,
    }
    T5 {
        (0) -> A, SA,
        (1) -> B, SB,
        (2) -> C, SC,
        (3) -> D, SD,
        (4) -> E, SE,
    }
    T6 {
        (0) -> A, SA,
        (1) -> B, SB,
        (2) -> C, SC,
        (3) -> D, SD,
        (4) -> E, SE,
        (5) -> F, SF,
    }
    T7 {
        (0) -> A, SA,
        (1) -> B, SB,
        (2) -> C, SC,
        (3) -> D, SD,
        (4) -> E, SE,
        (5) -> F, SF,
        (6) -> G, SG,
    }
    T8 {
        (0) -> A, SA,
        (1) -> B, SB,
        (2) -> C, SC,
        (3) -> D, SD,
        (4) -> E, SE,
        (5) -> F, SF,
        (6) -> G, SG,
        (7) -> H, SH,
    }
    T9 {
        (0) -> A, SA,
        (1) -> B, SB,
        (2) -> C, SC,
        (3) -> D, SD,
        (4) -> E, SE,
        (5) -> F, SF,
        (6) -> G, SG,
        (7) -> H, SH,
        (8) -> I, SI,
    }
    T10 {
        (0) -> A, SA,
        (1) -> B, SB,
        (2) -> C, SC,
        (3) -> D, SD,
        (4) -> E, SE,
        (5) -> F, SF,
        (6) -> G, SG,
        (7) -> H, SH,
        (8) -> I, SI,
        (9) -> J, SJ,
    }
    T11 {
        (0) -> A, SA,
        (1) -> B, SB,
        (2) -> C, SC,
        (3) -> D, SD,
        (4) -> E, SE,
        (5) -> F, SF,
        (6) -> G, SG,
        (7) -> H, SH,
        (8) -> I, SI,
        (9) -> J, SJ,
        (10) -> K, SK,
    }
    T12 {
        (0) -> A, SA,
        (1) -> B, SB,
        (2) -> C, SC,
        (3) -> D, SD,
        (4) -> E, SE,
        (5) -> F, SF,
        (6) -> G, SG,
        (7) -> H, SH,
        (8) -> I, SI,
        (9) -> J, SJ,
        (10) -> K, SK,
        (11) -> L, SL,
    }
}

impl FromSql<types::Serial> for i32 {}
impl FromSql<types::VarChar> for String {}

trait QuerySource {
    type SqlType: NativeSqlType;
}

trait Table: QuerySource {
    fn name() -> &'static str;
}

// table! {
//     UserTable {
//         users
//         id -> types::Serial,
//         name -> types::VarChar,
//     }
// }
//
// table! {
//     PostTable {
//         posts
//         id -> types::Serial,
//         user_id -> types::Integer,
//         title -> types::VarChar,
//     }
// }

// Expanded code from macro
struct UserTable;

impl QuerySource for UserTable {
    type SqlType = (types::Serial, types::VarChar);
}

impl Table for UserTable {
    fn name() -> &'static str {
        "users"
    }
}

struct PostTable;

impl QuerySource for PostTable {
    // id, user_id, title
    type SqlType = (types::Serial, types::Serial, types::VarChar);
}

trait Queriable<A: QuerySource> {
    type Row: FromSql<A::SqlType>;

    fn build(row: Self::Row) -> Self;
}

struct User {
    id: i32,
    name: String,
}

impl<QS: QuerySource> Queriable<QS> for User where
    (i32, String): FromSql<QS::SqlType>,
{
    type Row = (i32, String);

    fn build(row: (i32, String)) -> Self {
        User {
            id: row.0,
            name: row.1,
        }
    }
}

struct ConnectionAdapter;

impl ConnectionAdapter {
    fn query<M, QS>(&self, query_source: &QS) -> Result<Cursor<M, QS>>
    where M: Queriable<QS>, QS: QuerySource {
        unimplemented!()
    }
}

struct Cursor<M, QS> where M: Queriable<QS>, QS: QuerySource {
    _model_marker: PhantomData<M>,
    _query_marker: PhantomData<QS>,
}

#[derive(Debug)]
enum OrmError {}

type Result<A> = ::std::result::Result<A, OrmError>;

impl<M, QS> Iterator for Cursor<M, QS> where M: Queriable<QS>, QS: QuerySource {
    type Item = M;

    fn next(&mut self) -> Option<Self::Item> {
        let maybe_row: Option<M::Row> = unimplemented!();
        maybe_row.map(|r| Self::Item::build(r))
    }
}

fn main() {
    let connection = ConnectionAdapter;
    for user in connection.query(&UserTable).unwrap() {
        print_a_user(user)
    }
    // The following line would fail to compile if uncommented.
    // connection.query::<User, _>(&PostTable).unwrap();
}

fn print_a_user(user: User) {
    println!("id: {}, name: {}", user.id, user.name);
}
@emptyflash

I don't know Rust, but this looks like a really cool idea!