@hest
Created February 4, 2014 06:08
Fast SQLAlchemy counting (avoid query.count() subquery)
from sqlalchemy import func

def get_count(q):
    # Swap the selected columns for COUNT(*) and drop ORDER BY,
    # so that no wrapping subquery is generated.
    count_q = q.statement.with_only_columns([func.count()]).order_by(None)
    count = q.session.execute(count_q).scalar()
    return count

q = session.query(TestModel).filter(...).order_by(...)

# Slow: SELECT COUNT(*) FROM (SELECT ... FROM TestModel WHERE ...) ...
print(q.count())

# Fast: SELECT COUNT(*) FROM TestModel WHERE ...
print(get_count(q))
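
The two SQL shapes named in the comments above can be compared without an ORM. What follows is a minimal sketch using Python's built-in sqlite3 module; the table test_model, its columns, and the sample rows are hypothetical. It only shows that both shapes return the same number, and says nothing about how fast either one runs on a particular database.

```python
import sqlite3

# Hypothetical table and data, for illustration only.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE test_model (id INTEGER PRIMARY KEY, status TEXT)")
conn.executemany(
    "INSERT INTO test_model (status) VALUES (?)",
    [("active",), ("active",), ("archived",), ("active",)],
)

# Shape described as "Slow": the filtered SELECT is wrapped in a subquery.
subquery_sql = (
    "SELECT COUNT(*) FROM "
    "(SELECT id, status FROM test_model WHERE status = ?) AS anon_1"
)

# Shape described as "Fast": COUNT(*) is applied to the table directly.
direct_sql = "SELECT COUNT(*) FROM test_model WHERE status = ?"

subquery_count = conn.execute(subquery_sql, ("active",)).fetchone()[0]
direct_count = conn.execute(direct_sql, ("active",)).fetchone()[0]
print(subquery_count, direct_count)
```

Whether the two shapes differ in speed depends on the database and its planner, so it is worth measuring on real data before switching.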
@zt50tz

zt50tz commented Nov 8, 2020

Hi all!

If you use ORM queries with a model as the first argument, the other examples in this thread are fine to use.

But in my project I use many queries like:
db.session.query(User.id, UserInfo.name).outerjoin(UserInfo, UserInfo.user_id == User.id)

The previous examples do not work for these.

So I wrote a class that processes the query and removes the joins that the count query does not need.
It inspects the query's joins and WHERE conditions, then drops the unnecessary joins.

Usage:

q = db.session.query(
    User.id, User.nname, UserInfo.theme_id
).outerjoin(
    UserInfo, UserInfo.user_id == User.id
).filter(
    User.id > 0,
    UserInfo.name.like('123')
)

# where=True: inspect the WHERE clause; keep a join if its table appears in a filter() condition
# fields=False: ignore the selected fields; drop a join even if its table appears in them
count_q = DBRemoveJoin(q, fields=False, where=True, debug=False).process()
count_q = count_q.statement.with_only_columns([func.count('*')]).order_by(None)

print(count_q)

Output:

SELECT count(:count_2) AS count_1
FROM "User" LEFT OUTER JOIN "UserInfo" ON "UserInfo".user_id = "User".id
WHERE "User".id > :id_1 AND "UserInfo".name LIKE :name_1

If you comment out the UserInfo filter:

q = db.session.query(
    User.id, User.nname, UserInfo.theme_id
).outerjoin(
    UserInfo, UserInfo.user_id == User.id
).filter(
    User.id > 0,
    # UserInfo.name.like('123')
)

The result will be:

SELECT count(:count_2) AS count_1
FROM "User"
WHERE "User".id > :id_1
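
One caveat worth checking before dropping a join: if the joined table can hold more than one row per user, the join multiplies rows, so the count with and without it differs. A minimal sketch of that effect using stdlib sqlite3; the tables user_account and user_info and their rows are hypothetical stand-ins for User and UserInfo.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE user_account (id INTEGER PRIMARY KEY);
    CREATE TABLE user_info (id INTEGER PRIMARY KEY, user_id INTEGER, name TEXT);
    INSERT INTO user_account (id) VALUES (1), (2);
    -- user 1 has two info rows, user 2 has none
    INSERT INTO user_info (user_id, name) VALUES (1, 'a'), (1, 'b');
    """
)

# Count with the outer join kept: user 1 appears once per matching info row.
with_join = conn.execute(
    "SELECT COUNT(*) FROM user_account "
    "LEFT OUTER JOIN user_info ON user_info.user_id = user_account.id "
    "WHERE user_account.id > 0"
).fetchone()[0]

# Count with the outer join removed: one row per user.
without_join = conn.execute(
    "SELECT COUNT(*) FROM user_account WHERE user_account.id > 0"
).fetchone()[0]

print(with_join, without_join)
```

So removing an unused outer join keeps the count unchanged only when each main row matches at most one joined row.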

Class code:


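import six
from collections import OrderedDict
try:
    # collections.Iterable was removed in Python 3.10
    from collections.abc import Iterable
except ImportError:
    # Python 2 fallback
    from collections import Iterable
from sqlalchemy import Table, Column
from sqlalchemy.orm.util import _ORMJoin, outerjoin


def iterable(arg):
    # True for any iterable except strings
    return (
        isinstance(arg, Iterable)
        and not isinstance(arg, six.string_types)
    )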


class DBRemoveJoin:
    q = None

    fields = True
    where = True
    debug = False

    table_main = None
    table_need = []
    table_joins = {}

    column_items = []
    where_items = []

    def __init__(self, q, fields=True, where=True, debug=False):
        self.fields = fields
        self.where = where
        self.debug = debug
        self.q = q.filter()
        self.table_need = []
        self.column_items = []
        self.where_items = []
        self.table_joins = OrderedDict()

    @staticmethod
    def clause_columns_process(clause):
        ret = []
        column_objs = []
        if hasattr(clause, 'left'):
            column_objs.append(clause.left)
        if hasattr(clause, 'right'):
            column_objs.append(clause.right)
        if column_objs:
            for column_obj in column_objs:
                if isinstance(column_obj, Column):
                    where_item = {
                        'name': column_obj.name,
                        'column': column_obj,
                        'table_name': str(column_obj.table.name)
                    }
                    ret.append(where_item)
        else:
            for tbl in clause._from_objects:
                where_item = {
                    'name': None,
                    'column': None,
                    'table_name': str(tbl.name)
                }
                ret.append(where_item)
        return ret

    @staticmethod
    def clause_columns(clauses):
        ret = []
        if not iterable(clauses):
            clauses = [clauses]
        for clause in clauses:
            ret += DBRemoveJoin.clause_columns_process(clause)
        return ret
            
    def table_joins_process_join(self, table):
        if isinstance(table.left, Table):
            self.table_main = table.left
        else:
            self.table_joins_process_join(table.left)

        right_str = str(table.right)
        if not right_str:
            right_str = str(table.right.selectable.name)

        right_el = {
            'name': right_str,
            'table': table.right,
            'onclause': table.onclause.expression,
            'table_need': [],
            'need': False
        }

        clause_columns = self.clause_columns(table.onclause)
        clause_columns_skip_table_name = [right_el['name'], self.table_main]
        for clause_column in clause_columns:
            if clause_column['table_name'] not in clause_columns_skip_table_name:
                right_el['table_need'].append(clause_column['table_name'])

        self.table_joins[right_str] = right_el

    def table_joins_process(self):
        for table in self.q._from_obj:
            if isinstance(table, _ORMJoin):
                self.table_joins_process_join(table)
            else:
                pass

    def table_need_add(self, table_name):
        if table_name in self.table_need:
            return
        self.table_need.append(table_name)
        table_join = self.table_joins.get(table_name)
        if table_join:
            table_join['need'] = True
            for sub_table_name in table_join['table_need']:
                self.table_need_add(sub_table_name)

    def fields_process(self):
        if not self.fields:
            return
        for column in self.q.statement.columns:
            column_obj = list(column.base_columns)[0]
            column_el = {
                'name': column.name,
                'column': column_obj,
                'table_name': str(column_obj.table.name)
            }
            self.column_items.append(column_el)
            self.table_need_add(column_el['table_name'])

    def where_process_item(self, clause):
        if clause is None:
            return
        if iterable(clause):
            for el in clause:
                self.where_process_item(el)
            return
        items = self.clause_columns(clause)
        for item in items:
            self.table_need_add(item['table_name'])

    def where_process(self):
        if not self.where:
            return
        if not iterable(self.q.whereclause):
            clauses = self.q.whereclause
        else:
            clauses = self.q.whereclause.clauses
        self.where_process_item(clauses)

    def query_process(self):
        if not self.table_joins:
            return

        self.q._from_obj = (self.table_main, )
        for table in self.table_joins.values():
            if not table['need']:
                continue
            self.q._from_obj = (outerjoin(self.q._from_obj[0], table['table'], table['onclause']), )

    def process(self):
        self.table_joins_process()
        self.fields_process()
        self.where_process()
        self.query_process()
        return self.q

Keep in mind that this is code from a home project.

Sorry for my English (=

@ivcuello

ivcuello commented Jan 5, 2021

In case anyone is having issues with the ORM "dropping" the FROM clause altogether: in my case it was solved by referencing a column in the func.count() call. Instead of using a literal like func.count(1), reference a column (func.count(Table.id)), and the ORM won't delete the FROM clause in the resulting query.
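
A point to keep in mind when counting a column instead of a literal: in SQL, COUNT(column) skips rows where that column is NULL, so a non-nullable column such as the primary key is the safe choice. A minimal sketch with stdlib sqlite3; the table item and its rows are hypothetical.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE item (id INTEGER PRIMARY KEY, label TEXT)")
conn.executemany(
    "INSERT INTO item (label) VALUES (?)",
    [("a",), (None,), ("c",)],  # the second row has a NULL label
)

# COUNT(*) and COUNT(id) see every row; COUNT(label) ignores the NULL.
count_star, count_pk, count_label = conn.execute(
    "SELECT COUNT(*), COUNT(id), COUNT(label) FROM item"
).fetchone()
print(count_star, count_pk, count_label)
```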

@kigawas

kigawas commented Apr 9, 2021

Nice function! FYI it does not consider the query limit.

So I extended the function for my usage.

count = q.session.execute(count_q).scalar()
return min(q._limit, count) if q._limit else count

THIS METHOD DOES NOT APPLY TO QUERIES WITH LIMIT AND OFFSET

What if I want to know the total count? Normally in an API's pagination, (page, page_size, total_count) is needed.

Say

SELECT  count(*) AS total_count
FROM user
WHERE user.is_deleted = false
 LIMIT 20 OFFSET 5;

If you execute this in a database, it returns no rows, which becomes None in Python: COUNT(*) produces a single row, and OFFSET 5 skips it. So count_q must not include LIMIT and OFFSET.

Although a workaround exists (shown in an image in the original comment), it still isn't ideal because of the overhead.

Solution

Do not use limit and offset in the count query.
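
The behaviour described here can be reproduced with stdlib sqlite3. This is a minimal sketch under assumptions: the table app_user and its 30 rows are hypothetical, and the page size and offset simply mirror the SQL in the comment.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE app_user (id INTEGER PRIMARY KEY, is_deleted INTEGER)")
conn.executemany("INSERT INTO app_user (is_deleted) VALUES (?)", [(0,)] * 30)

# COUNT(*) yields exactly one row, and OFFSET 5 skips past it,
# so no row comes back at all.
limited_row = conn.execute(
    "SELECT COUNT(*) FROM app_user WHERE is_deleted = 0 LIMIT 20 OFFSET 5"
).fetchone()

# The total for pagination comes from the same filter without LIMIT/OFFSET.
total_count = conn.execute(
    "SELECT COUNT(*) FROM app_user WHERE is_deleted = 0"
).fetchone()[0]

# LIMIT/OFFSET belong only on the query that fetches the page itself.
page = conn.execute(
    "SELECT id FROM app_user WHERE is_deleted = 0 ORDER BY id LIMIT 20 OFFSET 5"
).fetchall()

print(limited_row, total_count, len(page))
```

Keeping the count query and the page query separate gives both values an API needs: the page rows and the total.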

@transfluxus

Hi,
does anybody know how this looks in SQLAlchemy 1.4?
I am getting:

sqlalchemy.exc.ArgumentError: Query has only expression-based entities - can't find property named "template".

Here, template is something from my database, I guess.

@kigawas

kigawas commented Oct 25, 2021

⚠️ YOU DO NOT NEED THIS ⚠️

https://dba.stackexchange.com/questions/168022/performance-of-count-in-subquery

⚠️ IF YOUR QUERY IS SLOW, YOU SHOULD REMOVE ORDER BY FIRST ⚠️

count_stmt = select(func.count()).select_from(statement.order_by(None).subquery())
await session.scalar(count_stmt)

@davidjb99

davidjb99 commented Nov 8, 2021

This worked perfectly for me after I upgraded from SQLAlchemy 1.3 to 1.4 and a count query went from 80ms to 800ms.

By switching from the built-in .count() to the function in the original gist, the query went back to 80ms. I believe the problem was .count() loading all columns into Python, which is not required and is very slow for thousands of rows; using with_only_columns and removing the subquery took it back to 80ms. No idea what broke it in 1.4.

    count_q = q.statement.with_only_columns([func.count()]).order_by(None)
    count = q.session.execute(count_q).scalar()

I'm not sure why people are putting warnings on this gist or writing very long replies expecting help. If you are stuck, ask on Stack Overflow!

My thanks to @hest

@kigawas

kigawas commented Nov 9, 2021

> This worked perfectly for me after I upgraded from SQLAlchemy 1.3 to 1.4 and a count query went from 80ms to 800ms.
>
> By switching from the built-in .count() to the function in the original gist, the query went back to 80ms. I believe the problem was .count() loading all columns into Python, which is not required and is very slow for thousands of rows; using with_only_columns and removing the subquery took it back to 80ms. No idea what broke it in 1.4.
>
>     count_q = q.statement.with_only_columns([func.count()]).order_by(None)
>     count = q.session.execute(count_q).scalar()
>
> I'm not sure why people are putting warnings on this gist or writing very long replies expecting help. If you are stuck, ask on Stack Overflow!
>
> My thanks to @hest

This thread is, in general, a kind of misinformation. If you check the Stack Exchange link above, these two SQL statements are equivalent:

-- on postgres
EXPLAIN ANALYZE SELECT COUNT(*) FROM some_big_table WHERE some_col = 'some_val';
EXPLAIN ANALYZE SELECT COUNT(*) FROM ( SELECT col1, col2, col3, col4 FROM some_big_table WHERE some_col = 'some_val' ) AS sub;

If you find your query executing slowly, the first thing to try is removing ORDER BY.

@davidjb99

> If you find your query executing slowly, the first thing to try is removing ORDER BY.

I tried this, along with the methods outlined in the Stack Exchange link, and it did not work. Please don't say I'm providing misinformation when I describe what worked for me; it is rude.

@kigawas

kigawas commented Dec 20, 2021

@davidjb99

Sorry for the misunderstanding. I meant that the gist (which was posted in 2014) is likely misinformation now, not your comment.

@kannasuresh99

This is the method I'm using:

def get_count(self, model_fields, filter_clause):
    """Note: filter_clause must not be None (NULL) for this method to work."""
    query = self.session.query().with_entities(*model_fields)
    query = query.filter(filter_clause)
    # Replace the selected columns with COUNT(*) and drop any ORDER BY
    count_query = (
        query.statement
        .with_only_columns([func.count()])
        .order_by(None)
    )
    result = query.session.execute(count_query).scalar()
    return result
