@hest
Created February 4, 2014 06:08
Fast SQLAlchemy counting (avoid query.count() subquery)
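from sqlalchemy import func

def get_count(q):
    # Replace the SELECT list with COUNT(*) and drop ORDER BY; the WHERE clause is kept.
    # (List form as in SQLAlchemy 1.x; 1.4+ takes the columns positionally.)
    count_q = q.statement.with_only_columns([func.count()]).order_by(None)
    count = q.session.execute(count_q).scalar()
    return count

q = session.query(TestModel).filter(...).order_by(...)
# Slow: SELECT COUNT(*) FROM (SELECT ... FROM TestModel WHERE ...) ...
print(q.count())
# Fast: SELECT COUNT(*) FROM TestModel WHERE ...
print(get_count(q))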
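To try the gist end to end, here is a minimal self-contained sketch. It assumes SQLAlchemy 1.4 or later and an in-memory SQLite database. The Item model and its rows are hypothetical stand-ins for TestModel, and with_only_columns() is called with a positional argument (the 1.4+ calling style) instead of the list shown above.

```python
from sqlalchemy import Column, Integer, String, create_engine, func
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Item(Base):
    # Hypothetical model standing in for TestModel from the gist
    __tablename__ = "item"
    id = Column(Integer, primary_key=True)
    name = Column(String)


def count_statement(q):
    # Swap the SELECT list for COUNT(*) and drop ORDER BY; the WHERE clause
    # (and with it the FROM table it references) is kept.
    return q.statement.with_only_columns(func.count()).order_by(None)


def get_count(q):
    return q.session.execute(count_statement(q)).scalar()


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = Session(engine)
session.add_all([Item(name=n) for n in ["a", "b", "b", "c"]])
session.commit()

q = session.query(Item).filter(Item.name != "a").order_by(Item.name)
print(count_statement(q))  # a single SELECT, no nested subquery
print(get_count(q), q.count())
```

The FROM clause survives here only because the filter references the table. The later comments about views and unfiltered queries show what happens without one.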
@wingyiu

wingyiu commented Aug 10, 2015

nice jj

@mark-keaton

This is amazing.

@lloydzhou

👍

@panuhorsmalahti

Thanks. I used it like this:

db.session.execute(
    db.session
    .query(TableName)
    .filter_by(x_id=y.id)
    .statement.with_only_columns([func.count()]).order_by(None)
).scalar()

@hagai26

hagai26 commented Sep 1, 2016

With joined (eagerly loaded) relationships it can produce bad queries, so I switch all joined loading to lazy loading:

count_q = q.options(lazyload('*')).statement.with_only_columns([func.count()]).order_by(None)
count = q.session.execute(count_q).scalar()

@nchudleigh

nchudleigh commented Nov 18, 2016

@hagai26 I am running into the same bad queries with joins.

.options(lazyload('*')) doesn't change the result or the resulting SQL:

count = get_count(
        User
        .join(Partnership)
        .filter(
            Customer.partnership_id == Partnership.id,
            Partnership.user_id == user_id
        )
    )

which returns 1.

If I use .count(), it returns 2 (the correct answer).

Any ideas?

@nchudleigh

nchudleigh commented Nov 18, 2016

I ended up using with_entities instead:

def get_count(q):
    return q.with_entities(func.count()).scalar()
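Here is a minimal sketch of the with_entities() variant. It assumes SQLAlchemy 1.4+ and SQLite, and Item is a hypothetical model. The .order_by(None) call is an addition that is not part of the comment above. Without it, the query's ORDER BY survives next to COUNT(*). SQLite tolerates that, but PostgreSQL rejects an ORDER BY on a non-aggregated column in this position.

```python
from sqlalchemy import Column, Integer, String, create_engine, func
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Item(Base):
    # Hypothetical model used only for this illustration
    __tablename__ = "item"
    id = Column(Integer, primary_key=True)
    name = Column(String)


def get_count(q):
    # Replace the selected entities with COUNT(*); filters and joins stay.
    # order_by(None) clears any ORDER BY carried over from the original query.
    return q.with_entities(func.count()).order_by(None).scalar()


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = Session(engine)
session.add_all([Item(name=n) for n in ["a", "b", "b", "c"]])
session.commit()

q = session.query(Item).filter(Item.name == "b").order_by(Item.name)
print(get_count(q))
```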

@hagai26

hagai26 commented Dec 21, 2017

I added support for more cases (column queries, DISTINCT and GROUP BY). Updated code:

from sqlalchemy import func, distinct
from sqlalchemy.orm import lazyload

def get_count(q):
    disable_group_by = False
    if len(q._entities) > 1:
        # currently support only one entity
        raise Exception('only one entity is supported for get_count, got: %s' % q)
    entity = q._entities[0]
    if hasattr(entity, 'column'):
        # _ColumnEntity has column attr - on case: query(Model.column)...
        col = entity.column
        if q._group_by and q._distinct:
            # which query can have both?
            raise NotImplementedError
        if q._group_by or q._distinct:
            col = distinct(col)
        if q._group_by:
            # need to disable group_by and enable distinct - we can do this because we have only 1 entity
            disable_group_by = True
        count_func = func.count(col)
    else:
        # _MapperEntity doesn't have column attr - on case: query(Model)...
        count_func = func.count()
    if q._group_by and not disable_group_by:
        count_func = count_func.over(None)
    count_q = q.options(lazyload('*')).statement.with_only_columns([count_func]).order_by(None)
    if disable_group_by:
        count_q = count_q.group_by(None)
    return q.session.execute(count_q).scalar()
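The GROUP BY branches above rely on a few SQL facts. The sketch below shows them with the standard-library sqlite3 module; window functions need SQLite 3.25 or later, and the item table and its data are hypothetical. A bare COUNT(*) under GROUP BY returns one row per group. COUNT(*) OVER () returns the number of groups on every row. For a single grouped column, COUNT(DISTINCT col) gives that same number. One caveat: GROUP BY puts NULLs in their own group, while COUNT(DISTINCT col) ignores NULLs, so the two can differ by one when the column is nullable.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE item (id INTEGER PRIMARY KEY, category TEXT)")
conn.executemany(
    "INSERT INTO item (category) VALUES (?)",
    [("a",), ("a",), ("b",), ("c",), ("c",), ("c",)],
)

# COUNT(*) under GROUP BY yields one row per group (that group's size),
# so taking only the first row would report a single group's size.
per_group = conn.execute(
    "SELECT count(*) FROM item GROUP BY category ORDER BY category"
).fetchall()

# The window function runs after grouping, so every row carries the
# number of groups; the first row is enough.
groups_via_window = conn.execute(
    "SELECT count(*) OVER () FROM item GROUP BY category"
).fetchone()[0]

# With one grouped column, GROUP BY col matches DISTINCT col.
groups_via_distinct = conn.execute(
    "SELECT count(DISTINCT category) FROM item"
).fetchone()[0]

print(per_group, groups_via_window, groups_via_distinct)
```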

@veuncent

Very helpful, thanks!

@rdammkoehler

rdammkoehler commented Feb 21, 2018

This is pretty cool, but I can't seem to get it working against a view. The underlying query I'm testing is the equivalent of select * from myview with no query params (yet), but the query it generates is select count(*) as count_1, which always returns 1. It seems the FROM clause just gets dropped. Do you have any thoughts about this?

UPDATE: If there are no filters in the query, the solution above produces select count(*) as count_1, and adding a filter like '1=1' changes the query to select count(*) as count_1 where 1=1, so there still isn't a table to count from.

SOLUTION: I added the following code to force a filter to exist when the query has none:

    if hasattr(model_class, 'columns'):
        q = q.filter(model_class.columns.get('id') == model_class.columns.get('id'))
    else:
        q = q.filter(getattr(model_class, 'id') == getattr(model_class, 'id'))

It's not a great solution, but it seems to be working. I'll let you know if I can refine this (which I'm sure I can)
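One alternative to the always-true filter is to name the FROM explicitly with select_from() after replacing the columns. The sketch below uses SQLAlchemy Core (1.4+) and SQLite. The myview table is a hypothetical stand-in for the view, and this is not what the commenter used. SQLAlchemy 1.4.23+ also has a maintain_column_froms=True flag on with_only_columns() that addresses the same problem.

```python
from sqlalchemy import Column, Integer, MetaData, Table, create_engine, func, select

metadata = MetaData()
# Hypothetical table standing in for the view in the comment
myview = Table("myview", metadata, Column("id", Integer, primary_key=True))

engine = create_engine("sqlite://")
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(myview.insert(), [{"id": i} for i in range(1, 6)])

stmt = select(myview)  # no WHERE clause at all

# Naming the FROM explicitly keeps it after the column list is replaced.
count_stmt = stmt.with_only_columns(func.count()).select_from(myview)

with engine.connect() as conn:
    total = conn.execute(count_stmt).scalar()

print(count_stmt)
print(total)
```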

@miron0xff

Thank you!

@gdaolewe

gdaolewe commented Oct 29, 2018

The solution in the gist was working, but I was getting warnings like SAWarning: Column 'id' on table being replaced by Column('id', Integer(), table=, primary_key=True, nullable=False), which has the same key. Consider use_labels for select() statements..

Using with_entities as suggested in @nchudleigh's comment got rid of those warnings.

@ewalid

ewalid commented Mar 27, 2019

Not All Heroes wear capes

@ansarisufiyan777

I was stuck on this problem for hours and finally found the solution. Really appreciate it.

@jbolotin

jbolotin commented Aug 6, 2019

If sharding is at play, make sure to pass in shard_id=YOUR_SHARD_ID as a param to execute as follows: q.session.execute(count_q, shard_id=YOUR_SHARD_ID).scalar()

@mononobi

Amazing work.
It's odd that query.count() makes a subquery; what's the reason behind that?

@provinzio

provinzio commented Feb 6, 2020

Nice function! FYI, it does not take the query's LIMIT into account.

So I extended the function for my usage:

count = q.session.execute(count_q).scalar()
return min(q._limit, count) if q._limit else count
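The min() above accounts for LIMIT but not OFFSET. The arithmetic can be sketched as a hypothetical helper that takes the unpaginated total plus the limit and offset as plain integers. Reading those values from the query relies on private attributes such as q._limit, and their names may differ between SQLAlchemy versions, so the helper leaves that part to the caller.

```python
from typing import Optional


def clamp_count(total: int, limit: Optional[int] = None, offset: int = 0) -> int:
    """Rows a LIMIT/OFFSET query returns, given the count without LIMIT/OFFSET.

    Hypothetical helper for illustration only.
    """
    remaining = max(total - (offset or 0), 0)  # rows left after skipping OFFSET
    if limit is None:
        return remaining
    return min(limit, remaining)


print(clamp_count(47, limit=20, offset=40))
```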

@mononobi

mononobi commented Feb 15, 2020

I've implemented this in a subclass of Session with count() overridden, so now I can use query.count() and it does not produce a subquery.
You should implement this method in a subclass of Session. I also subclassed the Column class, but that isn't really needed if you don't want to.
You must pass the subclassed Session class to your sessionmaker.

THIS IS INSIDE SUB-CLASSED SESSION:

def count(self, **options):
    """
    returns the count of rows the sql formed by this `Query` would return.
    this method is overridden to prevent inefficient count() of sqlalchemy `Query`
    which produces a subquery.

    this method generates a single sql query like below:
    select count(column, ...)
    from table
    where ...

    :keyword bool distinct: specifies that count should
                            be executed on distinct select.
                            defaults to False if not provided.

    :keyword bool fallback: specifies that count should
                            be executed using original sqlalchemy
                            function which produces a subquery.
                            defaults to False if not provided.

    :rtype: int
    """

    fallback = options.get('fallback', False)
    needs_fallback = False
    columns = []
    for single_column in self.selectable.columns:
        if not isinstance(single_column, CoreColumn):
            if fallback is False:
                raise UnsupportedQueryStyleError('Current query does not have pure columns '
                                                 'in its expression. if you need to apply a '
                                                 'keyword like "DISTINCT", you should apply '
                                                 'it by passing "distinct=True" keyword to '
                                                 'query method and do not apply it in query '
                                                 'structure itself. for example instead of '
                                                 'writing '
                                                 '"store.query(distinct(Entity.id)).count()" '
                                                 'you should write this in the following form '
                                                 '"store.query(Entity.id).count()" and then '
                                                 'pass "distinct=True" in options of query '
                                                 'method. if you want the sqlalchemy original '
                                                 'style of count() which produces a subquery, '
                                                 'it is also possible to fallback to that '
                                                 'default sqlalchemy count() but keep in '
                                                 'mind that, that method is not efficient. '
                                                 'you could pass "fallback=True" in options '
                                                 'to fallback to default mode if CoreSession '
                                                 'failed to execute count().')
            else:
                needs_fallback = True
                break

        fullname = single_column.get_table_fullname()
        if fullname not in (None, ''):
            columns.append(fullname)

    if needs_fallback is True:
        return super().count()

    func_count = func.count()
    if len(columns) > 0:
        distinct = options.get('distinct', False)
        column_clause = ', '.join(columns)
        if distinct is True:
            column_clause = 'distinct {clause}'.format(clause=column_clause)
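        # literal_column() renders the names as SQL; a plain string would become a bound parameter
        func_count = func.count(literal_column(column_clause))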

    statement = self.options(lazyload('*')).statement.with_only_columns(
        [func_count]).order_by(None)

    store = get_current_store()
    return store.execute(statement).scalar()
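One pitfall in count() as originally posted (func.count(column_clause)) is that column_clause is a Python string. func.count() turns a plain string argument into a bound parameter rather than SQL text. The COUNT then runs over a constant, so the distinct option has no effect. Here is a compile-only sketch, assuming SQLAlchemy is installed; some_table.field1 is a hypothetical name.

```python
from sqlalchemy import func, literal_column

# A plain Python string becomes a bound parameter, so the database
# counts a constant value rather than the named column.
as_bound = func.count("some_table.field1")

# literal_column() renders the string verbatim as SQL text instead.
as_sql = func.count(literal_column("some_table.field1"))

print(as_bound)
print(as_sql)
```

literal_column() inserts the text without quoting or escaping, so it should only carry trusted identifiers. Also, count(a, b) with several columns is not accepted by every database.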

THIS IS INSIDE SUB-CLASSED COLUMN:

def get_table_fullname(self):
    """
    gets the current column's table fullname, if the
    column has no table, it returns None.

    :rtype: str
    """

    if self.table is None:
        return None

    for single_column in self.table.columns:
        if single_column.name == self.name:
            for base_column in single_column.base_columns:
                return '{table}.{column}'.format(table=base_column.table.fullname,
                                                 column=self._get_real_name())

    return None

def _get_real_name(self):
    """
    gets column's real name.

    :rtype: str
    """

    for column in self.base_columns:
        return column.name

Then you can use it like a normal query:
count = session.query(SomeEntity.Field1, SomeEntity.Field2).count()
and this will produce:
select count(field1, field2) from some_table

The full project is also available on GitHub:
https://github.com/mononobi/pyrin

@transfluxus

I am not an expert on this, but it doesn't give me the right result for a query that has a join, a filter and a contains_eager option.

@Bryant-Yang

Bryant-Yang commented Aug 17, 2020

@mononobi

I've implemented this in a subclass of Session with count() overridden, so now I could use query.count() and it does not make a sub-query ...

Why not just override the default Query's count()?

The original count() in the default Query:

    def count(self):
        r"""Return a count of rows this the SQL formed by this :class:`Query`
        would return.

        This generates the SQL for this Query as follows::

            SELECT count(1) AS count_1 FROM (
                SELECT <rest of query follows...>
            ) AS anon_1

        The above SQL returns a single row, which is the aggregate value
        of the count function; the :meth:`_query.Query.count`
        method then returns
        that single integer value.

        .. warning::

            It is important to note that the value returned by
            count() is **not the same as the number of ORM objects that this
            Query would return from a method such as the .all() method**.
            The :class:`_query.Query` object,
            when asked to return full entities,
            will **deduplicate entries based on primary key**, meaning if the
            same primary key value would appear in the results more than once,
            only one object of that primary key would be present.  This does
            not apply to a query that is against individual columns.

            .. seealso::

                :ref:`faq_query_deduplicating`

                :ref:`orm_tutorial_query_returning`

        For fine grained control over specific columns to count, to skip the
        usage of a subquery or otherwise control of the FROM clause, or to use
        other aggregate functions, use :attr:`~sqlalchemy.sql.expression.func`
        expressions in conjunction with :meth:`~.Session.query`, i.e.::

            from sqlalchemy import func

            # count User records, without
            # using a subquery.
            session.query(func.count(User.id))

            # return count of user "id" grouped
            # by "name"
            session.query(func.count(User.id)).\
                    group_by(User.name)

            from sqlalchemy import distinct

            # count distinct "name" values
            session.query(func.count(distinct(User.name)))

        """
        col = sql.func.count(sql.literal_column("*"))
        return self.from_self(col).scalar()
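The warning in this docstring is easy to reproduce in plain SQL. The sketch below uses the standard-library sqlite3 module with hypothetical users/addresses tables. Counting joined rows gives one row per (user, address) pair. The number of users that .all() would return after primary-key deduplication matches COUNT(DISTINCT users.id) instead.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE addresses (id INTEGER PRIMARY KEY, user_id INTEGER, email TEXT);
    INSERT INTO users VALUES (1, 'ann'), (2, 'bob');
    INSERT INTO addresses VALUES (1, 1, 'a@x'), (2, 1, 'a@y'), (3, 2, 'b@x');
    """
)

join_sql = "SELECT users.id FROM users JOIN addresses ON addresses.user_id = users.id"

# Counting the joined rows: one row per (user, address) pair.
row_count = conn.execute(f"SELECT count(*) FROM ({join_sql}) AS j").fetchone()[0]

# Counting distinct users: what .all() yields after the ORM deduplicates
# full entities by primary key.
user_count = conn.execute(
    "SELECT count(DISTINCT users.id) FROM users "
    "JOIN addresses ON addresses.user_id = users.id"
).fetchone()[0]

print(row_count, user_count)
```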

The overridden count() in a custom query class (a subclass of BaseQuery):

    def count(self):
        disable_group_by = False
        if len(self._entities) > 1:
            # currently support only one entity
            raise Exception('only one entity is supported for count, got: %s' % self)
        entity = self._entities[0]
        if hasattr(entity, 'column'):
            # _ColumnEntity has column attr - on case: query(Model.column)...
            col = entity.column
            if self._group_by and self._distinct:
                # which query can have both?
                raise NotImplementedError
            if self._group_by or self._distinct:
                col = distinct(col)
            if self._group_by:
                # need to disable group_by and enable distinct - we can do this because we have only 1 entity
                disable_group_by = True
            count_func = func.count(col)
        else:
            # _MapperEntity doesn't have column attr - on case: query(Model)...
            count_func = func.count()
        if self._group_by and not disable_group_by:
            count_func = count_func.over(None)
        count_q = self.options(lazyload('*')).statement.with_only_columns([count_func]).order_by(None)
        if disable_group_by:
            count_q = count_q.group_by(None)
        return self.session.execute(count_q).scalar()

That said, the original count() with self.from_self(col).scalar() is clear enough and causes no trouble with GROUP BY, DISTINCT, or more advanced queries.

@zt50tz

zt50tz commented Nov 8, 2020

Hi all!

If you are using ORM queries with a model as the first argument, the comments above work well.

But in my project I use many queries like:
db.session.query(User.id, UserInfo.name).outerjoin(UserInfo, UserInfo.user_id == User.id)

and the previous examples don't work for them.

So I wrote a class that processes the query and removes the joins a count query doesn't need.
It inspects the query's joins and WHERE conditions, then removes the unnecessary joins.

Usage:

q = db.session.query(
    User.id, User.nname, UserInfo.theme_id
).outerjoin(
    UserInfo, UserInfo.user_id == User.id
).filter(
    User.id > 0,
    UserInfo.name.like('123')
)

# where=True: inspect the WHERE clause; joins referenced in filter() conditions are kept
# fields=False: ignore the selected columns; joins used only by those columns are dropped
count_q = DBRemoveJoin(q, fields=False, where=True, debug=False).process()
count_q = count_q.statement.with_only_columns([func.count('*')]).order_by(None)

print(count_q)

Output:

SELECT count(:count_2) AS count_1
FROM "User" LEFT OUTER JOIN "UserInfo" ON "UserInfo".user_id = "User".id
WHERE "User".id > :id_1 AND "UserInfo".name LIKE :name_1

If you comment out the UserInfo filter:

q = db.session.query(
    User.id, User.nname, UserInfo.theme_id
).outerjoin(
    UserInfo, UserInfo.user_id == User.id
).filter(
    User.id > 0,
    # UserInfo.name.like('123')
)

Result will be:

SELECT count(:count_2) AS count_1
FROM "User"
WHERE "User".id > :id_1

Class code:


import collections.abc
import six
from sqlalchemy.orm.util import _ORMJoin, outerjoin
from sqlalchemy import Table, Column

from collections import OrderedDict


def iterable(arg):
    # collections.Iterable was removed in Python 3.10; use collections.abc
    return (
        isinstance(arg, collections.abc.Iterable)
        and not isinstance(arg, six.string_types)
    )


class DBRemoveJoin:
    q = None

    fields = True
    where = True
    debug = False

    table_main = None
    table_need = []
    table_joins = {}

    column_items = []
    where_items = []

    def __init__(self, q, fields=True, where=True, debug=False):
        self.fields = fields
        self.where = where
        self.debug = debug
        self.q = q.filter()
        self.table_need = []
        self.column_items = []
        self.where_items = []
        self.table_joins = OrderedDict()

    @staticmethod
    def clause_columns_process(clause):
        ret = []
        column_objs = []
        if hasattr(clause, 'left'):
            column_objs.append(clause.left)
        if hasattr(clause, 'right'):
            column_objs.append(clause.right)
        if column_objs:
            for column_obj in column_objs:
                if isinstance(column_obj, Column):
                    where_item = {
                        'name': column_obj.name,
                        'column': column_obj,
                        'table_name': str(column_obj.table.name)
                    }
                    ret.append(where_item)
        else:
            for tbl in clause._from_objects:
                where_item = {
                    'name': None,
                    'column': None,
                    'table_name': str(tbl.name)
                }
                ret.append(where_item)
        return ret

    @staticmethod
    def clause_columns(clauses):
        ret = []
        if not iterable(clauses):
            clauses = [clauses]
        for clause in clauses:
            ret += DBRemoveJoin.clause_columns_process(clause)
        return ret
            
    def table_joins_process_join(self, table):
        if isinstance(table.left, Table):
            self.table_main = table.left
        else:
            self.table_joins_process_join(table.left)

        right_str = str(table.right)
        if not right_str:
            right_str = str(table.right.selectable.name)

        right_el = {
            'name': right_str,
            'table': table.right,
            'onclause': table.onclause.expression,
            'table_need': [],
            'need': False
        }

        clause_columns = self.clause_columns(table.onclause)
        clause_columns_skip_table_name = [right_el['name'], self.table_main]
        for clause_column in clause_columns:
            if clause_column['table_name'] not in clause_columns_skip_table_name:
                right_el['table_need'].append(clause_column['table_name'])

        self.table_joins[right_str] = right_el

    def table_joins_process(self):
        for table in self.q._from_obj:
            if isinstance(table, _ORMJoin):
                self.table_joins_process_join(table)
            else:
                pass

    def table_need_add(self, table_name):
        if table_name in self.table_need:
            return
        self.table_need.append(table_name)
        table_join = self.table_joins.get(table_name)
        if table_join:
            table_join['need'] = True
            for sub_table_name in table_join['table_need']:
                self.table_need_add(sub_table_name)

    def fields_process(self):
        if not self.fields:
            return
        for column in self.q.statement.columns:
            column_obj = list(column.base_columns)[0]
            column_el = {
                'name': column.name,
                'column': column_obj,
                'table_name': str(column_obj.table.name)
            }
            self.column_items.append(column_el)
            self.table_need_add(column_el['table_name'])

    def where_process_item(self, clause):
        if clause is None:
            return
        if iterable(clause):
            for el in clause:
                self.where_process_item(el)
            return
        items = self.clause_columns(clause)
        for item in items:
            self.table_need_add(item['table_name'])

    def where_process(self):
        if not self.where:
            return
        if not iterable(self.q.whereclause):
            clauses = self.q.whereclause
        else:
            clauses = self.q.whereclause.clauses
        self.where_process_item(clauses)

    def query_process(self):
        if not self.table_joins:
            return

        self.q._from_obj = (self.table_main, )
        for table in self.table_joins.values():
            if not table['need']:
                continue
            self.q._from_obj = (outerjoin(self.q._from_obj[0], table['table'], table['onclause']), )

    def process(self):
        self.table_joins_process()
        self.fields_process()
        self.where_process()
        self.query_process()
        return self.q

Keep in mind that this is code from a personal project.

Sorry for my English (=
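Here is why dropping a join can be safe, sketched in plain SQL with sqlite3. The users and user_info tables are hypothetical, and user_info has at most one row per user. An unfiltered LEFT OUTER JOIN to such a table cannot add or remove base rows, so removing it leaves the count unchanged. Once the WHERE clause references the joined table, the join is needed. The class above does not appear to check this cardinality. With a one-to-many join, dropping the join would change the count.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE users (id INTEGER PRIMARY KEY);
    CREATE TABLE user_info (id INTEGER PRIMARY KEY, user_id INTEGER UNIQUE, name TEXT);
    INSERT INTO users VALUES (1), (2), (3);
    INSERT INTO user_info (user_id, name) VALUES (1, '123'), (2, 'abc');
    """
)


def scalar(sql):
    return conn.execute(sql).fetchone()[0]


# user_info has at most one row per user (UNIQUE user_id), so an unfiltered
# LEFT OUTER JOIN keeps exactly one row per user: the join can be dropped.
with_join = scalar(
    "SELECT count(*) FROM users LEFT OUTER JOIN user_info ON user_info.user_id = users.id"
)
without_join = scalar("SELECT count(*) FROM users")

# Once the WHERE clause references user_info, the join is required.
filtered = scalar(
    "SELECT count(*) FROM users LEFT OUTER JOIN user_info "
    "ON user_info.user_id = users.id WHERE user_info.name LIKE '123'"
)

print(with_join, without_join, filtered)
```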

@ivcuello

ivcuello commented Jan 5, 2021

If anyone is having issues with the ORM "dropping" the FROM clause altogether: in my case it was solved by referencing a column in the func.count() call. Instead of a literal such as func.count(1), reference a column (func.count(Table.id)) and the ORM won't drop the FROM clause from the resulting query.
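Here is a compile-only sketch of this tip, assuming SQLAlchemy 1.4+ and a hypothetical item table. func.count() references no table, so replacing the column list leaves nothing to put in FROM. func.count(item.c.id) brings the table back. Note that COUNT(col) skips NULLs, so the column should be a non-nullable one such as the primary key.

```python
from sqlalchemy import Column, Integer, MetaData, Table, func, select

metadata = MetaData()
# Hypothetical table used only for this illustration
item = Table("item", metadata, Column("id", Integer, primary_key=True))

stmt = select(item)  # no WHERE clause

# COUNT(*) references no table, so nothing keeps "item" in the FROM list.
count_star = stmt.with_only_columns(func.count())

# Counting a column of the table pulls the table back into FROM.
count_col = stmt.with_only_columns(func.count(item.c.id))

print(count_star)
print(count_col)
```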

@kigawas

kigawas commented Apr 9, 2021

Replying to @provinzio's LIMIT extension above (min(q._limit, count)):

THIS METHOD DOES NOT APPLY TO QUERIES WITH LIMIT AND OFFSET

What if I want to know the total count? Normally in an API's pagination, (page, page_size, total_count) is needed.

Say

SELECT  count(*) AS total_count
FROM user
WHERE user.is_deleted = false
 LIMIT 20 OFFSET 5;

If you execute this in a database, it returns no rows, which is None in Python. In other words, count_q cannot include LIMIT and OFFSET.

You can work around this, but it still isn't ideal because of the overhead.

Solution

Don't use LIMIT and OFFSET in the count query.
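The sketch below shows both points with sqlite3 and a hypothetical users table. The aggregate produces a single row, so OFFSET 5 skips it and fetchone() returns None. The total for pagination comes from the same WHERE clause without LIMIT/OFFSET. The page itself is a separate LIMIT/OFFSET query.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, is_deleted INTEGER)")
conn.executemany(
    "INSERT INTO users (is_deleted) VALUES (?)", [(0,)] * 30 + [(1,)] * 5
)

# The aggregate yields exactly one row; OFFSET 5 skips past it.
paged = conn.execute(
    "SELECT count(*) FROM users WHERE is_deleted = 0 LIMIT 20 OFFSET 5"
).fetchone()

# Total for pagination metadata: same WHERE, no LIMIT/OFFSET.
total_count = conn.execute(
    "SELECT count(*) FROM users WHERE is_deleted = 0"
).fetchone()[0]

# The page itself is a separate query that does use LIMIT/OFFSET.
page_rows = conn.execute(
    "SELECT id FROM users WHERE is_deleted = 0 ORDER BY id LIMIT 20 OFFSET 5"
).fetchall()

print(paged, total_count, len(page_rows))
```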

@transfluxus

Hi,
does anybody have an idea what this looks like in SQLAlchemy 1.4?
I am getting:

sqlalchemy.exc.ArgumentError: Query has only expression-based entities - can't find property named "template".

where "template" is, I guess, something from my DB.

@kigawas

kigawas commented Oct 25, 2021

⚠️ YOU DO NOT NEED THIS ⚠️

https://dba.stackexchange.com/questions/168022/performance-of-count-in-subquery

⚠️ IF YOUR QUERY IS SLOW, YOU SHOULD REMOVE ORDER BY FIRST ⚠️

count_stmt = select(func.count()).select_from(statement.order_by(None).subquery())
await session.scalar(count_stmt)
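Here is a synchronous version of this snippet. It assumes SQLAlchemy 1.4+ with a regular Session instead of an AsyncSession, and the User model and its rows are hypothetical. Unlike the gist, this approach keeps the subquery and only strips ORDER BY from it.

```python
from sqlalchemy import Column, Integer, String, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class User(Base):
    # Hypothetical model used only for this illustration
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    name = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = Session(engine)
session.add_all([User(name=n) for n in ["ann", "bob", "cat"]])
session.commit()

statement = select(User).where(User.name != "bob").order_by(User.name)

# Wrap the original statement, minus its ORDER BY, in a subquery and count it.
count_stmt = select(func.count()).select_from(statement.order_by(None).subquery())
total = session.scalar(count_stmt)
print(total)
```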

@davidjb99

davidjb99 commented Nov 8, 2021

This worked perfectly for me after I upgraded from SQLAlchemy 1.3 to 1.4 and a count query went from 80ms to 800ms.

By switching from the built-in .count() to the gist's original approach, the query went back to 80ms. I believe the problem was .count() loading all columns into Python, which is not required and very slow for thousands of rows; using with_only_columns and removing the subquery took it back to 80ms. No idea what broke it in 1.4.

    count_q = q.statement.with_only_columns([func.count()]).order_by(None)
    count = q.session.execute(count_q).scalar()

I'm not sure why people are putting warnings on this gist or writing very long replies expecting help. If you are stuck ask on Stackoverflow!

My thanks to @hest

@kigawas

kigawas commented Nov 9, 2021

Replying to @davidjb99's comment above:

This thread is largely misinformation. If you check the Stack Exchange link above, these two queries are effectively the same:

-- on postgres
EXPLAIN ANALYZE SELECT COUNT(*) FROM some_big_table WHERE some_col = 'some_val'
EXPLAIN ANALYZE SELECT COUNT(*) FROM ( SELECT col1, col2, col3, col4 FROM some_big_table WHERE some_col = 'some_val' ) AS sub

If your query is slow, the first thing to try is removing ORDER BY.
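Both forms always return the same number, whatever their performance. The quick check below uses sqlite3 on a hypothetical some_big_table. The execution plans themselves depend on the database, so comparing EXPLAIN output on your own database is the way to settle the performance question.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE some_big_table (col1, col2, col3, col4, some_col TEXT)")
conn.executemany(
    "INSERT INTO some_big_table VALUES (?, ?, ?, ?, ?)",
    [(i, i, i, i, "some_val" if i % 3 else "other") for i in range(30)],
)

# Flat form: count straight from the filtered table.
flat = conn.execute(
    "SELECT COUNT(*) FROM some_big_table WHERE some_col = 'some_val'"
).fetchone()[0]

# Wrapped form: count the rows of a subquery (what Query.count() emits).
wrapped = conn.execute(
    "SELECT COUNT(*) FROM (SELECT col1, col2, col3, col4 FROM some_big_table "
    "WHERE some_col = 'some_val') AS sub"
).fetchone()[0]

print(flat, wrapped)
```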

@davidjb99

If you find your query executing slow, the first is to try removing order by.

I tried this, and the methods outlined in the Stack Exchange link, and it did not work. Please don't say I'm providing misinformation by describing what worked for me; it is rude.

@kigawas

kigawas commented Dec 20, 2021

@davidjb99

Sorry for the misunderstanding. I meant that the gist (which was posted in 2014) is likely misinformation now, not your comment.

@kannasuresh99

This is the method I'm using:

def get_count(self, model_fields, filter_clause):
        """ Note: filter_clause should not be 'None' or 'Null' for this method to work """
        query = self.session.query().with_entities(*model_fields)
        query = query.filter(filter_clause)
        count_query = query.statement \
                .with_only_columns([func.count()]) \
                .order_by(None)
        result = query.session.execute(count_query).scalar()
        return result
