Skip to content

Instantly share code, notes, and snippets.

@chrislawlor
Last active April 16, 2020 13:41
Show Gist options
  • Save chrislawlor/0612e8cdaa564e7ed25fd8fc7382199d to your computer and use it in GitHub Desktop.
Save chrislawlor/0612e8cdaa564e7ed25fd8fc7382199d to your computer and use it in GitHub Desktop.
Python iterable processing with delayed execution. Inspired by LINQ
Copyright 2020 Christopher Lawlor
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from functools import partial
from typing import Any, Callable, Generator, Hashable, Iterable, List, TypeVar
# Type variable standing for a single element flowing through a query.
T = TypeVar("T")
# A one-argument callable applied to an element; used below to annotate both
# filter predicates and transform selectors.
TExpr = Callable[[T], Any]
class _Element:
    """
    Placeholder for "any given item" of a query's source collection.

    Comparing the placeholder with a value does not return a bool; it
    returns a one-argument predicate that applies the same comparison to
    whatever element it is later called with. For example ``Element > 5``
    yields a function equivalent to ``lambda x: x > 5``.

    Because ``__eq__`` is overridden without ``__hash__``, instances are
    not hashable.
    """

    def __eq__(self, other):
        return lambda x: x == other

    def __ne__(self, other):
        return lambda x: x != other

    def __gt__(self, other):
        return lambda x: x > other

    def __ge__(self, other):
        return lambda x: x >= other

    def __lt__(self, other):
        return lambda x: x < other

    def __le__(self, other):
        return lambda x: x <= other

    def __call__(self, *args, **kwargs):
        # The bare placeholder is not a predicate; it only becomes one
        # once it has been compared against a value.
        raise ValueError("Must be used in an expression")

    def __repr__(self):
        # Report the public name (the class is published as ``Element``).
        return "<class 'linq.Element'>"


# ``Element`` is the public object that client code uses. It would have been
# preferable to implement the dunder overrides on the class itself as class or
# static methods, but that doesn't seem to work. Since the placeholder holds
# no state, client code never needs to instantiate it, so exposing a single
# shared instance is a reasonable compromise.
_Element.__name__ = "Element"
Element = _Element()
class Linq:
    """
    Lazily evaluated query over an iterable, inspired by LINQ.

    Chainable methods (``distinct``, ``where``, ``apply``) only record an
    operation on an internal stack, mutate the query in place and return
    ``self``. Nothing is read from the source collection until the query is
    iterated or a terminating method (``any``, ``count``) is called.
    """

    def __init__(self, collection: "Iterable[T]"):
        self._collection = collection

        # Seed the operations stack with an identity operator so the query
        # is iterable before any operation has been chained, and so
        # __iter__ never has to special-case an empty stack.
        def identity(stream):
            yield from stream

        self._operations: "List[TExpr]" = [identity]

    # Chainable methods

    def distinct(self, hash_=False) -> "Linq":
        """
        Remove duplicates, keeping the first occurrence of each element.

        This requires keeping a set of every seen element in memory, so
        elements must be hashable.

        Since we only care about membership, and never need to retrieve
        the values from the set of seen elements, ``distinct`` supports
        a ``hash_`` option. If set, only the hash of each element is stored
        in the set of seen elements. For cases where the hash is smaller
        than the element, this can significantly reduce memory
        requirements. The trade-off is that two different elements whose
        hashes collide are treated as duplicates.
        """

        def _distinct(collection):
            seen = set()
            for item in collection:
                # Conditional expression: hash() is only called when the
                # hash_ option is enabled.
                comparator = hash(item) if hash_ else item
                if comparator not in seen:
                    yield item
                    seen.add(comparator)

        self._operations.append(_distinct)
        return self

    def where(self, filter_: "TExpr") -> "Linq":
        """
        Allow only elements for which ``filter_`` returns a truthy value.
        """

        def where_(collection):
            yield from filter(filter_, collection)

        self._operations.append(where_)
        return self

    def apply(self, selector: "TExpr") -> "Linq":
        """
        Apply some transform on received elements.
        """

        def select_(collection):
            return (selector(i) for i in collection)

        self._operations.append(select_)
        return self

    # Terminating methods

    def any(self, filter_: "TExpr") -> bool:
        """
        Return True if at least one query result satisfies ``filter_``.

        The predicate is evaluated against the results of the chained
        operations, not against the raw source collection, and evaluation
        stops at the first match. It is the predicate's result that is
        tested, so falsy elements such as ``0`` or ``""`` can match.
        """
        return any(filter_(item) for item in self)

    def count(self) -> int:
        """
        Get the final count of the query's results.

        This consumes the query.
        """
        return len(self)

    # Operators

    @staticmethod
    def eq(value) -> "Callable[[T], bool]":
        return lambda x: x == value

    @staticmethod
    def ne(value) -> "Callable[[T], bool]":
        return lambda x: x != value

    @staticmethod
    def gt(value) -> "Callable[[T], bool]":
        return lambda x: x > value

    @staticmethod
    def gte(value) -> "Callable[[T], bool]":
        return lambda x: x >= value

    @staticmethod
    def lt(value) -> "Callable[[T], bool]":
        return lambda x: x < value

    @staticmethod
    def lte(value) -> "Callable[[T], bool]":
        return lambda x: x <= value

    # Internals

    def __len__(self):
        # Gets the query result count without storing results in memory.
        # NOTE: this runs the whole pipeline. Python also falls back to
        # __len__ for truth-testing, so ``if query:`` does the same and
        # will exhaust a one-shot source such as a generator.
        return sum(1 for _ in self)

    def __iter__(self):
        # Converts our list of operations from this:
        #   [f0(iter), f1(iter), f2(iter), ... fn(iter)]
        # to this:
        #   fn(...f2(f1(f0(iter))))
        # Each operation wraps the stream produced by the previous one;
        # the identity operator guarantees the stack is never empty.
        stream = self._collection
        for operation in self._operations:
            stream = operation(stream)
        yield from stream
import pytest
from linq import Linq, Element
def test_any():
    """``any`` is True when an element satisfies a plain callable."""

    def is_foo(value):
        return value == "foo"

    found = Linq(["foo"]).any(is_foo)
    assert found is True
def test_eq():
    """``Linq.eq`` builds a predicate matching an equal element."""
    matches_foo = Linq.eq("foo")
    assert Linq(["foo"]).any(matches_foo) is True
def test_ne():
    """``Linq.ne`` rejects the only element when it equals the value."""
    differs_from_foo = Linq.ne("foo")
    assert Linq(["foo"]).any(differs_from_foo) is False
def test_gt():
    """``Linq.gt`` keeps only elements strictly above the bound."""
    query = Linq([5, 10]).where(Linq.gt(5))
    remaining = list(query)
    assert remaining == [10]
def test_gte():
    """``Linq.gte`` keeps elements at or above the bound."""
    query = Linq([5, 6, 10]).where(Linq.gte(6))
    remaining = list(query)
    assert remaining == [6, 10]
def test_lt():
    """``Linq.lt`` keeps only elements strictly below the bound."""
    query = Linq([5, 10]).where(Linq.lt(10))
    remaining = list(query)
    assert remaining == [5]
def test_lte():
    """``Linq.lte`` keeps elements at or below the bound."""
    query = Linq([5, 9, 10]).where(Linq.lte(9))
    remaining = list(query)
    assert remaining == [5, 9]
def test_count():
    """``count`` reports how many elements survive the filter."""
    query = Linq(range(10)).where(Linq.gt(5))
    # Survivors are 6, 7, 8 and 9.
    assert query.count() == 4
def test_always_iterable():
    """A query with no chained operations yields its source unchanged."""
    query = Linq(["foo"])
    assert list(query) == ["foo"]
def test_distinct():
    """``distinct`` collapses repeated elements into one."""
    query = Linq(["foo", "foo"]).distinct()
    assert list(query) == ["foo"]
def test_distinct_with_hash():
    """``distinct`` also de-duplicates when tracking hashes only."""
    query = Linq(["foo", "foo"]).distinct(hash_=True)
    remaining = list(query)
    assert remaining == ["foo"]
def test_chain():
    """A terminating ``any`` can follow a chainable ``distinct``."""

    def is_bar(name):
        return name == "bar"

    found = Linq(["foo", "foo", "bar"]).distinct().any(is_bar)
    assert found is True
def test_where():
    """``where`` keeps only elements accepted by a plain callable."""

    def is_alice(name):
        return name == "alice"

    query = Linq(["alice", "bob"]).where(is_alice)
    assert list(query) == ["alice"]
def test_apply():
    """``apply`` transforms elements before later operations see them."""
    name_lengths = Linq(["alice", "bob"]).apply(len)
    query = name_lengths.where(Linq.gt(3))
    assert list(query) == [5]
def test_Q_eq():
    """``Element == value`` produces an equality predicate."""
    equals_foo = Element == "foo"
    assert equals_foo("foo") is True
    assert equals_foo("bar") is False
def test_Q_ne():
    """``Element != value`` produces an inequality predicate."""
    differs_from_foo = Element != "foo"
    assert differs_from_foo("foo") is False
    assert differs_from_foo("bar") is True
def test_Q_gt():
    """``Element > value`` is true only strictly above the bound."""
    above_five = Element > 5
    assert above_five(6) is True
    assert above_five(5) is False
    assert above_five(4) is False
def test_Q_gte():
    """``Element >= value`` is true at and above the bound."""
    at_least_five = Element >= 5
    assert at_least_five(5) is True
    assert at_least_five(6) is True
    assert at_least_five(4) is False
def test_any_with_Q():
    """An ``Element`` expression can be passed to ``any`` as the predicate."""
    query = Linq(["foo", "bar"])
    found = query.any(Element == "foo")
    assert found is True
def test_value_error_with_naked_Element():
    """Passing the bare ``Element`` placeholder as a predicate raises."""
    query = Linq(["a"])
    with pytest.raises(ValueError):
        query.any(Element)
if __name__ == "__main__":
    # pytest.main() returns the run's exit code; hand it to the interpreter
    # so that running this file as a script exits non-zero when tests fail.
    raise SystemExit(pytest.main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment