

@hest
Created February 4, 2014 06:08
Fast SQLAlchemy counting (avoid query.count() subquery)
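from sqlalchemy import func


def get_count(q):
    # Keep the query's FROM/WHERE, replace the SELECT list with COUNT(*),
    # and drop ORDER BY, which a count doesn't need.
    # (SQLAlchemy 2.0 requires the columns positionally: with_only_columns(func.count()))
    count_q = q.statement.with_only_columns([func.count()]).order_by(None)
    count = q.session.execute(count_q).scalar()
    return count


q = session.query(TestModel).filter(...).order_by(...)
# Slow: SELECT COUNT(*) FROM (SELECT ... FROM TestModel WHERE ...) ...
print(q.count())
# Fast: SELECT COUNT(*) FROM TestModel WHERE ...
print(get_count(q))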
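A quick way to convince yourself that the two shapes count the same rows is to run both against a toy table. The sketch below uses only the standard-library `sqlite3` module; the `test_model` table, its `score` column and the `score > 10` filter are stand-ins for `TestModel` and the elided `filter(...)`, not part of the original gist.

```python
import sqlite3

# Hypothetical table standing in for TestModel.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE test_model (id INTEGER PRIMARY KEY, score INTEGER)")
conn.executemany(
    "INSERT INTO test_model (score) VALUES (?)",
    [(s,) for s in (5, 12, 7, 30, 18)],
)

# Shape produced by Query.count(): the whole query wrapped in a subquery.
slow_sql = (
    "SELECT COUNT(*) FROM ("
    "  SELECT id, score FROM test_model WHERE score > 10 ORDER BY score"
    ") AS anon_1"
)
# Shape produced by get_count(): same FROM/WHERE, no subquery, no ORDER BY.
fast_sql = "SELECT COUNT(*) FROM test_model WHERE score > 10"

slow_count = conn.execute(slow_sql).fetchone()[0]
fast_count = conn.execute(fast_sql).fetchone()[0]
print(slow_count, fast_count)  # 3 3
```

It only checks that the results agree; whether the flat form is actually faster depends on the database's query planner.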
@Bryant-Yang

Bryant-Yang commented Aug 17, 2020

Replying to @mononobi: "I've implemented this in a subclass of Session with count() overridden, so now I could use query.count() and it does not make a sub-query ..."

Why not just override the default Query's count()?

The original count() in the default Query:

    def count(self):
        r"""Return a count of rows that the SQL formed by this :class:`Query`
        would return.

        This generates the SQL for this Query as follows::

            SELECT count(1) AS count_1 FROM (
                SELECT <rest of query follows...>
            ) AS anon_1

        The above SQL returns a single row, which is the aggregate value
        of the count function; the :meth:`_query.Query.count`
        method then returns
        that single integer value.

        .. warning::

            It is important to note that the value returned by
            count() is **not the same as the number of ORM objects that this
            Query would return from a method such as the .all() method**.
            The :class:`_query.Query` object,
            when asked to return full entities,
            will **deduplicate entries based on primary key**, meaning if the
            same primary key value would appear in the results more than once,
            only one object of that primary key would be present.  This does
            not apply to a query that is against individual columns.

            .. seealso::

                :ref:`faq_query_deduplicating`

                :ref:`orm_tutorial_query_returning`

        For fine grained control over specific columns to count, to skip the
        usage of a subquery or otherwise control of the FROM clause, or to use
        other aggregate functions, use :attr:`~sqlalchemy.sql.expression.func`
        expressions in conjunction with :meth:`~.Session.query`, i.e.::

            from sqlalchemy import func

            # count User records, without
            # using a subquery.
            session.query(func.count(User.id))

            # return count of user "id" grouped
            # by "name"
            session.query(func.count(User.id)).\
                    group_by(User.name)

            from sqlalchemy import distinct

            # count distinct "name" values
            session.query(func.count(distinct(User.name)))

        """
        col = sql.func.count(sql.literal_column("*"))
        return self.from_self(col).scalar()
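To illustrate the warning above in plain SQL, here is a sketch with the standard-library `sqlite3` module and made-up `users`/`addresses` tables: a join yields one row per address, so `COUNT(*)` over the joined rows is larger than the number of distinct users, which is roughly what an entity query would hand back after deduplicating by primary key.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE addresses (id INTEGER PRIMARY KEY, user_id INTEGER, email TEXT);
    INSERT INTO users (id, name) VALUES (1, 'ann'), (2, 'bob');
    -- ann has two addresses, bob has one
    INSERT INTO addresses (user_id, email) VALUES
        (1, 'ann@a.example'), (1, 'ann@b.example'), (2, 'bob@example');
""")

join_sql = "FROM users JOIN addresses ON addresses.user_id = users.id"

# What count() measures: rows produced by the SQL (one per address).
row_count = conn.execute("SELECT COUNT(*) " + join_sql).fetchone()[0]

# Roughly what an entity query returns after primary-key deduplication.
distinct_users = conn.execute("SELECT COUNT(DISTINCT users.id) " + join_sql).fetchone()[0]

print(row_count, distinct_users)  # 3 2
```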

The overridden count() in a custom Query that subclasses BaseQuery:

    # Requires: from sqlalchemy import distinct, func
    #           from sqlalchemy.orm import lazyload
    def count(self):
        disable_group_by = False
        if len(self._entities) > 1:
            # currently only one entity is supported
            raise Exception('only one entity is supported for get_count, got: %s' % self)
        entity = self._entities[0]
        if hasattr(entity, 'column'):
            # _ColumnEntity has a column attr: the query(Model.column)... case
            col = entity.column
            if self._group_by and self._distinct:
                # which query can have both?
                raise NotImplementedError
            if self._group_by or self._distinct:
                col = distinct(col)
            if self._group_by:
                # disable group_by and count distinct values instead; possible because there is only 1 entity
                disable_group_by = True
            count_func = func.count(col)
        else:
            # _MapperEntity doesn't have a column attr: the query(Model)... case
            count_func = func.count()
        if self._group_by and not disable_group_by:
            count_func = count_func.over(None)
        count_q = self.options(lazyload('*')).statement.with_only_columns([count_func]).order_by(None)
        if disable_group_by:
            count_q = count_q.group_by(None)
        return self.session.execute(count_q).scalar()
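The override relies on two SQL equivalences: for a single grouped column, the number of groups equals `COUNT(DISTINCT col)`, and for a grouped entity query `COUNT(*) OVER ()` repeats the number of groups on every output row, so `.scalar()` reading the first row gets it. This stdlib `sqlite3` sketch (made-up `items` table; assumes the bundled SQLite is 3.25 or newer for window functions) checks both, and shows one caveat: `COUNT(DISTINCT col)` skips NULL, while `GROUP BY` puts NULLs in a group of their own.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE items (id INTEGER PRIMARY KEY, category TEXT)")
conn.executemany(
    "INSERT INTO items (category) VALUES (?)",
    [("a",), ("a",), ("b",), ("c",), (None,)],
)


def one(sql):
    # Return the first column of the first row, like Result.scalar().
    return conn.execute(sql).fetchone()[0]


# Reference answer: count the groups via a subquery (what Query.count() does).
groups = one("SELECT COUNT(*) FROM (SELECT category FROM items GROUP BY category)")

# Rewrite used for query(Model.column).group_by(...): COUNT(DISTINCT column).
distinct_count = one("SELECT COUNT(DISTINCT category) FROM items")

# Rewrite used for query(Model).group_by(...): COUNT(*) OVER () keeps GROUP BY.
window_count = one("SELECT COUNT(*) OVER () FROM items GROUP BY category")

print(groups, distinct_count, window_count)  # 4 3 4
```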

I have to say, the original count() with self.from_self(col).scalar() is clear enough, and it causes no trouble with GROUP BY, DISTINCT, or even more advanced queries.

@zt50tz

zt50tz commented Nov 8, 2020

Hi all!

If you are using ORM queries with a model as the first argument, the comments above work well.

But in my project I use many queries like:

db.session.query(User.id, UserInfo.name).outerjoin(UserInfo, UserInfo.user_id == User.id)

and the previous examples can't be used with them.

So I wrote a class that processes the query and removes the joins a count query doesn't need.
It inspects the query's joins and WHERE conditions, then removes the unnecessary joins.

Usage:

q = db.session.query(
    User.id, User.nname, UserInfo.theme_id
).outerjoin(
    UserInfo, UserInfo.user_id == User.id
).filter(
    User.id > 0,
    UserInfo.name.like('123')
)

# where=True: inspect the WHERE clause; joins referenced in filter() conditions are kept
# fields=False: ignore the selected columns; joins referenced only in the SELECT list can be dropped
count_q = DBRemoveJoin(q, fields=False, where=True, debug=False).process()
count_q = count_q.statement.with_only_columns([func.count('*')]).order_by(None)

print(count_q)

Output:

SELECT count(:count_2) AS count_1
FROM "User" LEFT OUTER JOIN "UserInfo" ON "UserInfo".user_id = "User".id
WHERE "User".id > :id_1 AND "UserInfo".name LIKE :name_1

If you comment out the UserInfo filter:

q = db.session.query(
    User.id, User.nname, UserInfo.theme_id
).outerjoin(
    UserInfo, UserInfo.user_id == User.id
).filter(
    User.id > 0,
    # UserInfo.name.like('123')
)

The result will be:

SELECT count(:count_2) AS count_1
FROM "User"
WHERE "User".id > :id_1
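Dropping the join in the second example is only safe because it is to-one: each `User` has at most one `UserInfo` row. If the joined side can match several rows, a LEFT OUTER JOIN multiplies rows and removing it changes the count. A stdlib `sqlite3` sketch with made-up data for the two tables:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE "User" (id INTEGER PRIMARY KEY);
    CREATE TABLE "UserInfo" (id INTEGER PRIMARY KEY, user_id INTEGER, theme_id INTEGER);
    INSERT INTO "User" (id) VALUES (1), (2), (3);
    INSERT INTO "UserInfo" (user_id, theme_id) VALUES (1, 10), (2, 20);
""")


def count(sql):
    return conn.execute(sql).fetchone()[0]


with_join = (
    'SELECT COUNT(*) FROM "User" '
    'LEFT OUTER JOIN "UserInfo" ON "UserInfo".user_id = "User".id '
    'WHERE "User".id > 0'
)
without_join = 'SELECT COUNT(*) FROM "User" WHERE "User".id > 0'

# At most one UserInfo per User: the join does not change the count.
one_to_one = (count(with_join), count(without_join))

# A second UserInfo row for user 1 makes the join one-to-many.
conn.execute('INSERT INTO "UserInfo" (user_id, theme_id) VALUES (1, 11)')
one_to_many = (count(with_join), count(without_join))

print(one_to_one, one_to_many)  # (3, 3) (4, 3)
```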

Class code:


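import collections.abc
from collections import OrderedDict

from sqlalchemy import Table, Column
# _ORMJoin is private; this class relies on private Query internals
# (_from_obj, _ORMJoin) that may change between SQLAlchemy versions
from sqlalchemy.orm.util import _ORMJoin, outerjoin


def iterable(arg):
    # True for lists, tuples and SQL ClauseLists, but not for strings
    return (
        isinstance(arg, collections.abc.Iterable)
        and not isinstance(arg, str)
    )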


class DBRemoveJoin:
    q = None

    fields = True
    where = True
    debug = False

    table_main = None
    table_need = []
    table_joins = {}

    column_items = []
    where_items = []

    def __init__(self, q, fields=True, where=True, debug=False):
        self.fields = fields
        self.where = where
        self.debug = debug
        self.q = q.filter()
        self.table_need = []
        self.column_items = []
        self.where_items = []
        self.table_joins = OrderedDict()

    @staticmethod
    def clause_columns_process(clause):
        ret = []
        column_objs = []
        if hasattr(clause, 'left'):
            column_objs.append(clause.left)
        if hasattr(clause, 'right'):
            column_objs.append(clause.right)
        if column_objs:
            for column_obj in column_objs:
                if isinstance(column_obj, Column):
                    where_item = {
                        'name': column_obj.name,
                        'column': column_obj,
                        'table_name': str(column_obj.table.name)
                    }
                    ret.append(where_item)
        else:
            for tbl in clause._from_objects:
                where_item = {
                    'name': None,
                    'column': None,
                    'table_name': str(tbl.name)
                }
                ret.append(where_item)
        return ret

    @staticmethod
    def clause_columns(clauses):
        ret = []
        if not iterable(clauses):
            clauses = [clauses]
        for clause in clauses:
            ret += DBRemoveJoin.clause_columns_process(clause)
        return ret
            
    def table_joins_process_join(self, table):
        if isinstance(table.left, Table):
            self.table_main = table.left
        else:
            self.table_joins_process_join(table.left)

        right_str = str(table.right)
        if not right_str:
            right_str = str(table.right.selectable.name)

        right_el = {
            'name': right_str,
            'table': table.right,
            'onclause': table.onclause.expression,
            'table_need': [],
            'need': False
        }

        clause_columns = self.clause_columns(table.onclause)
        # compare names with names: table_main is a Table object, so use its name
        clause_columns_skip_table_name = [right_el['name'], str(self.table_main.name)]
        for clause_column in clause_columns:
            if clause_column['table_name'] not in clause_columns_skip_table_name:
                right_el['table_need'].append(clause_column['table_name'])

        self.table_joins[right_str] = right_el

    def table_joins_process(self):
        for table in self.q._from_obj:
            if isinstance(table, _ORMJoin):
                self.table_joins_process_join(table)
            else:
                pass

    def table_need_add(self, table_name):
        if table_name in self.table_need:
            return
        self.table_need.append(table_name)
        table_join = self.table_joins.get(table_name)
        if table_join:
            table_join['need'] = True
            for sub_table_name in table_join['table_need']:
                self.table_need_add(sub_table_name)

    def fields_process(self):
        if not self.fields:
            return
        for column in self.q.statement.columns:
            column_obj = list(column.base_columns)[0]
            column_el = {
                'name': column.name,
                'column': column_obj,
                'table_name': str(column_obj.table.name)
            }
            self.column_items.append(column_el)
            self.table_need_add(column_el['table_name'])

    def where_process_item(self, clause):
        if clause is None:
            return
        if iterable(clause):
            for el in clause:
                self.where_process_item(el)
            return
        items = self.clause_columns(clause)
        for item in items:
            self.table_need_add(item['table_name'])

    def where_process(self):
        if not self.where:
            return
        if not iterable(self.q.whereclause):
            clauses = self.q.whereclause
        else:
            clauses = self.q.whereclause.clauses
        self.where_process_item(clauses)

    def query_process(self):
        if not self.table_joins:
            return

        self.q._from_obj = (self.table_main, )
        for table in self.table_joins.values():
            if not table['need']:
                continue
            self.q._from_obj = (outerjoin(self.q._from_obj[0], table['table'], table['onclause']), )

    def process(self):
        self.table_joins_process()
        self.fields_process()
        self.where_process()
        self.query_process()
        return self.q

Keep in mind that this is code from a home project.

Sorry for my English (=

@ivcuello

ivcuello commented Jan 5, 2021

If anyone is having issues with the ORM "dropping" the FROM clause altogether: in my case it was solved by referencing a column in the func.count() call. Instead of using a literal like func.count(1), reference a column (func.count(Table.id)) and the ORM won't remove the FROM clause from the resulting query.
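A dropped FROM clause is easy to miss because the statement still runs; it just counts the wrong thing. With no FROM, `count(*)` aggregates over a single implicit row, so it returns 1 however large the table is. A stdlib `sqlite3` sketch with a hypothetical `users` table of 100 rows:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY)")
conn.executemany("INSERT INTO users (id) VALUES (?)", [(i,) for i in range(1, 101)])

# What a count query looks like after the FROM clause has been dropped.
no_from = conn.execute("SELECT count(*) AS count_1").fetchone()[0]

# Referencing a column (func.count(Table.id)) keeps the table in the FROM list.
with_from = conn.execute("SELECT count(users.id) AS count_1 FROM users").fetchone()[0]

print(no_from, with_from)  # 1 100
```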

@kigawas

kigawas commented Apr 9, 2021

Nice function! FYI, it does not take the query's LIMIT into account.

So I extended the function for my use case:

count = q.session.execute(count_q).scalar()
return min(q._limit, count) if q._limit else count
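One caveat: `min(q._limit, count)` ignores OFFSET, so it overestimates the last page. A sketch of the number of rows a LIMIT/OFFSET query actually returns, given its un-paginated total; the helper name `page_row_count` is made up, and `None` stands for "not set":

```python
def page_row_count(total, limit=None, offset=None):
    """Rows returned by a query with the given LIMIT/OFFSET, given the
    row count of the same query without them (hypothetical helper)."""
    offset = offset or 0
    remaining = max(total - offset, 0)  # rows left after skipping OFFSET rows
    return remaining if limit is None else min(limit, remaining)


print(page_row_count(45, limit=20, offset=40))  # 5
print(page_row_count(45, limit=20))  # 20
print(page_row_count(45))  # 45
```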

THIS METHOD DOES NOT APPLY TO QUERIES WITH LIMIT AND OFFSET

What if I want to know the total count? Normally in an API's pagination, (page, page_size, total_count) is needed.

Say

SELECT  count(*) AS total_count
FROM user
WHERE user.is_deleted = false
 LIMIT 20 OFFSET 5;

If you execute this in a DB, it returns no rows, which becomes None in Python. In fact, count_q cannot include LIMIT and OFFSET.


Although you can work around this, it still isn't ideal because of the overhead.

Solution

Avoid using LIMIT and OFFSET in the count query.
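A sketch of that approach with the standard-library `sqlite3` module and a made-up `user` table (100 rows, every tenth one soft-deleted): the count runs without LIMIT/OFFSET to get `total_count`, and only the page query is paginated. The first query reproduces the problem above: an aggregate without GROUP BY yields exactly one row, and `OFFSET 5` skips it.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "user" (id INTEGER PRIMARY KEY, is_deleted BOOLEAN)')
conn.executemany(
    'INSERT INTO "user" (is_deleted) VALUES (?)',
    [(i % 10 == 0,) for i in range(1, 101)],  # every 10th user is deleted
)

where = 'FROM "user" WHERE is_deleted = 0'

# Aggregate + LIMIT/OFFSET: the single COUNT row is skipped, so nothing comes back.
broken = conn.execute("SELECT count(*) " + where + " LIMIT 20 OFFSET 5").fetchone()

# Pagination: total from an un-paginated count, rows from the paginated query.
total_count = conn.execute("SELECT count(*) " + where).fetchone()[0]
page = conn.execute("SELECT id " + where + " ORDER BY id LIMIT 20 OFFSET 5").fetchall()

print(broken, total_count, len(page))  # None 90 20
```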

@transfluxus
Copy link

Hi,
does anybody have an idea what this looks like in SQLAlchemy 1.4?
I am getting:

sqlalchemy.exc.ArgumentError: Query has only expression-based entities - can't find property named "template".

where "template" is, I guess, something from my DB.

@kigawas

kigawas commented Oct 25, 2021

⚠️ YOU DO NOT NEED THIS ⚠️

https://dba.stackexchange.com/questions/168022/performance-of-count-in-subquery

⚠️ IF YOUR QUERY IS SLOW, YOU SHOULD REMOVE ORDER BY FIRST ⚠️

count_stmt = select(func.count()).select_from(statement.order_by(None).subquery())
await session.scalar(count_stmt)

@davidjb99
Copy link

davidjb99 commented Nov 8, 2021

This worked perfectly for me after I upgraded from SQLAlchemy 1.3 to 1.4 and a count query went from 80 ms to 800 ms.

By switching from the built-in .count() to the approach in the original gist, the query went back to 80 ms. I believe the problem was .count() loading all columns into Python, which is not required and very slow for thousands of rows; using with_only_columns and removing the subquery took it back to 80 ms. No idea what broke it in 1.4.

    count_q = q.statement.with_only_columns([func.count()]).order_by(None)
    count = q.session.execute(count_q).scalar()

I'm not sure why people are putting warnings on this gist or writing very long replies expecting help. If you are stuck, ask on Stack Overflow!

My thanks to @hest

@kigawas

kigawas commented Nov 9, 2021

Replying to @davidjb99's comment above:

This thread is largely misinformation by now. If you check the Stack Exchange link above, these two queries perform exactly the same:

-- on postgres
EXPLAIN ANALYZE SELECT COUNT(*) FROM some_big_table WHERE some_col = 'some_val'
EXPLAIN ANALYZE SELECT COUNT(*) FROM ( SELECT col1, col2, col3, col4 FROM some_big_table WHERE some_col = 'some_val' ) AS sub

If your query is slow, the first thing to try is removing ORDER BY.

@davidjb99
Copy link

Quoting @kigawas: "If you find your query executing slow, the first is to try removing order by."

I tried this, and the methods outlined in the Stack Exchange link, and it did not work. Please don't say I'm providing misinformation by describing what worked for me; it is rude.

@kigawas

kigawas commented Dec 20, 2021

@davidjb99

Sorry for the misunderstanding. I meant that the gist (which was posted in 2014) is likely misinformation now, not your comment.

@kannasuresh99

This is the method I'm using:

def get_count(self, model_fields, filter_clause):
        """ Note: filter_clause should not be 'None' or 'Null' for this method to work """
        query = self.session.query().with_entities(*model_fields)
        query = query.filter(filter_clause)
        count_query = query.statement \
                .with_only_columns([func.count()]) \
                .order_by(None)
        result = query.session.execute(count_query).scalar()
        return result
