Skip to content

Instantly share code, notes, and snippets.

@camertron
Created July 21, 2020 17:09
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save camertron/36f553f0d98de86b34617acc1099f0cd to your computer and use it in GitHub Desktop.
Save camertron/36f553f0d98de86b34617acc1099f0cd to your computer and use it in GitHub Desktop.
Python implementation of Ruby's Enumerable module
import unittest
def enum(obj):
    """Wrap a list or dict so it gains the Enumerable methods.

    Args:
        obj: A list or dict (subclasses included), or an object that is
            already Enumerable.

    Returns:
        ``obj`` itself if it is already Enumerable; otherwise a new
        EnumList / EnumDict built from ``obj`` (a shallow copy of its
        contents).

    Raises:
        TypeError: if ``obj`` is not a list, dict, or Enumerable.
    """
    # Already wrapped: return as-is so enum() is idempotent.
    if isinstance(obj, Enumerable):
        return obj
    # isinstance (not type() ==) so subclasses such as OrderedDict or
    # user-defined list subclasses are accepted too.
    if isinstance(obj, list):
        return EnumList(obj)
    if isinstance(obj, dict):
        return EnumDict(obj)
    # Fail loudly here rather than returning None and failing later with an
    # unrelated AttributeError.
    raise TypeError(
        "enum() expects a list or dict, got %s" % type(obj).__name__
    )
class Enumerable(object):
    """Mixin that adds Ruby-style Enumerable methods to a collection.

    A concrete subclass must implement ``generate()`` so that every call
    returns a fresh iterator over the collection's elements; each method
    below starts its own pass by calling ``self.generate()``.

    Callbacks (``cb``) are ordinary callables. Predicate results are tested
    with Python truthiness, except in ``filter_map`` (see its docstring).
    """

    def generate(self):
        """Return a new iterator over the elements. Subclasses must override."""
        raise NotImplementedError(
            "%s must implement generate()" % type(self).__name__
        )

    @staticmethod
    def _check_size(n):
        """Raise ValueError unless ``n`` is a usable (>= 1) group size."""
        # Without this guard, n == 0 made each_slice loop forever and
        # negative sizes silently produced no groups / size-1 windows.
        if n < 1:
            raise ValueError("invalid size: %r" % (n,))

    def each(self, cb):
        """Call ``cb(elem)`` for every element, in order. Returns None."""
        for elem in self.generate():
            cb(elem)

    def map(self, cb):
        """Return a list of ``cb(elem)`` for every element, in order."""
        return [cb(elem) for elem in self.generate()]

    def filter_map(self, cb):
        """Return the results of ``cb(elem)``, omitting None and False.

        Follows Ruby, where only ``nil`` and ``false`` are falsy: results
        such as ``0``, ``""`` or ``[]`` are kept rather than dropped.
        """
        mapped = (cb(elem) for elem in self.generate())
        # Identity checks, not truthiness, so 0 / "" / [] survive.
        return [res for res in mapped if res is not None and res is not False]

    def inject(self, init_val, cb):
        """Fold the elements left to right via ``memo = cb(memo, elem)``.

        Starts from ``init_val`` and returns the final memo (``init_val``
        itself when there are no elements).
        """
        memo = init_val
        for elem in self.generate():
            memo = cb(memo, elem)
        return memo

    def partition(self, cb):
        """Split the elements into ``[matching, non_matching]`` by ``cb``."""
        matching = []
        non_matching = []
        for elem in self.generate():
            if cb(elem):
                matching.append(elem)
            else:
                non_matching.append(elem)
        return [matching, non_matching]

    def all(self, cb):
        """Return True if ``cb(elem)`` is truthy for every element.

        Stops at the first falsy result; True for an empty collection.
        """
        for elem in self.generate():
            if not cb(elem):
                return False
        return True

    def any(self, cb):
        """Return True if ``cb(elem)`` is truthy for at least one element.

        Stops at the first truthy result; False for an empty collection.
        """
        for elem in self.generate():
            if cb(elem):
                return True
        return False

    def cnt(self, cb):
        """Return the number of elements for which ``cb(elem)`` is truthy.

        Presumably named ``cnt`` rather than ``count`` because EnumList also
        inherits from ``list``, whose ``count`` would take precedence.
        """
        return sum(1 for elem in self.generate() if cb(elem))

    def each_cons(self, n, cb):
        """Call ``cb`` with every run of ``n`` consecutive elements.

        Windows overlap, advancing one element at a time, and each call
        receives a new list. Unlike Ruby, if there are fewer than ``n`` (but
        more than zero) elements, ``cb`` is called once with all of them.

        Raises:
            ValueError: if ``n`` < 1.
        """
        self._check_size(n)
        window = []
        for elem in self.generate():
            if len(window) < n:
                # Still filling the first window; not yet handed to cb.
                window.append(elem)
                if len(window) == n:
                    cb(window)
            else:
                # Build a new list instead of mutating one cb already received.
                window = window[1:] + [elem]
                cb(window)
        # Short input: report the partial window once.
        if 0 < len(window) < n:
            cb(window)

    def each_slice(self, n, cb):
        """Call ``cb`` with consecutive, non-overlapping chunks of ``n`` elements.

        The last chunk may be shorter than ``n``; nothing is called for an
        empty collection.

        Raises:
            ValueError: if ``n`` < 1.
        """
        self._check_size(n)
        chunk = []
        for elem in self.generate():
            chunk.append(elem)
            if len(chunk) == n:
                cb(chunk)
                chunk = []  # fresh list; cb may keep a reference to the old one
        if chunk:
            cb(chunk)  # trailing partial slice

    def find(self, cb):
        """Return the first element for which ``cb(elem)`` is truthy, else None."""
        for elem in self.generate():
            if cb(elem):
                return elem
        return None

    def find_index(self, cb):
        """Return the position of the first element matching ``cb``, else None."""
        for i, elem in enumerate(self.generate()):
            if cb(elem):
                return i
        return None

    def select(self, cb):
        """Return the elements for which ``cb(elem)`` is truthy, in order."""
        return [elem for elem in self.generate() if cb(elem)]

    def reject(self, cb):
        """Return the elements for which ``cb(elem)`` is falsy, in order."""
        return [elem for elem in self.generate() if not cb(elem)]
class EnumList(list, Enumerable):
    """A ``list`` that also provides the Enumerable methods."""

    def generate(self):
        """Lazily yield the list's elements in index order."""
        yield from self
class EnumDict(dict, Enumerable):
    """A ``dict`` that also provides the Enumerable methods.

    Enumerable callbacks receive ``(key, value)`` tuples, in the dict's
    iteration order.
    """

    def generate(self):
        """Lazily yield ``(key, value)`` pairs for every entry."""
        # items() yields the same pairs as iterating keys and indexing
        # self[k], without a second lookup per key.
        yield from self.items()
class ListTest(unittest.TestCase):
    """Unit tests for the Enumerable methods, exercised through enum()."""

    def test_each(self):
        seen = []
        enum([1, 2, 3]).each(seen.append)
        self.assertEqual(seen, [1, 2, 3])

    def test_map(self):
        result = enum([1, 2]).map(lambda elem: elem * 2)
        self.assertEqual(result, [2, 4])

    def test_filter_map(self):
        result = enum([1, 2, 3, 4]).filter_map(lambda elem: elem * 2 if elem % 2 == 0 else None)
        self.assertEqual(result, [4, 8])

    def test_inject(self):
        result = enum([1, 2, 3]).inject(0, lambda memo, elem: memo + elem)
        self.assertEqual(result, 6)

    def test_partition(self):
        first, second = enum([1, 2, 3, 4]).partition(lambda elem: elem % 2 == 0)
        self.assertEqual(first, [2, 4])
        self.assertEqual(second, [1, 3])

    def test_all(self):
        self.assertTrue(enum([1, 2, 3]).all(lambda elem: elem > 0))
        self.assertFalse(enum([-1, 2, 3]).all(lambda elem: elem > 0))

    def test_any(self):
        self.assertTrue(enum([-1, 2, -3]).any(lambda elem: elem > 0))
        self.assertFalse(enum([-1, -2, -3]).any(lambda elem: elem > 0))

    def test_count(self):
        count = enum([1, 2, 3, 4, 5]).cnt(lambda elem: elem % 2 == 1)
        self.assertEqual(count, 3)

    def test_each_cons2(self):
        results = []
        enum([1, 2, 3, 4]).each_cons(2, lambda elems: results.append(elems))
        self.assertEqual(results, [[1, 2], [2, 3], [3, 4]])

    def test_each_cons3(self):
        results = []
        enum([1, 2, 3, 4]).each_cons(3, lambda elems: results.append(elems))
        self.assertEqual(results, [[1, 2, 3], [2, 3, 4]])

    def test_each_cons_empty(self):
        results = []
        enum([]).each_cons(3, lambda elems: results.append(elems))
        self.assertEqual(results, [])

    def test_each_cons_not_enough(self):
        # Deliberate divergence from Ruby: a short input yields one partial window.
        results = []
        enum([1, 2]).each_cons(3, lambda elems: results.append(elems))
        self.assertEqual(results, [[1, 2]])

    def test_each_slice(self):
        results = []
        enum([1, 2, 3, 4]).each_slice(2, lambda elems: results.append(elems))
        self.assertEqual(results, [[1, 2], [3, 4]])

    def test_each_slice_empty(self):
        results = []
        enum([]).each_slice(2, lambda elems: results.append(elems))
        self.assertEqual(results, [])

    def test_each_slice_not_enough(self):
        results = []
        enum([1, 2, 3]).each_slice(2, lambda elems: results.append(elems))
        self.assertEqual(results, [[1, 2], [3]])

    def test_find(self):
        result = enum([1, 2, 3, 4]).find(lambda elem: elem % 2 == 0)
        self.assertEqual(result, 2)

    def test_find_no_match(self):
        result = enum([1, 2, 3, 4]).find(lambda elem: elem == 5)
        self.assertIsNone(result)

    def test_find_index(self):
        result = enum([1, 2, 3, 4]).find_index(lambda elem: elem % 3 == 0)
        self.assertEqual(result, 2)

    def test_find_index_no_match(self):
        result = enum([1, 2, 3, 4]).find_index(lambda elem: elem == 5)
        self.assertIsNone(result)

    def test_select(self):
        result = enum([1, 2, 3, 4]).select(lambda elem: elem % 2 == 0)
        self.assertEqual(result, [2, 4])

    def test_reject(self):
        result = enum([1, 2, 3, 4]).reject(lambda elem: elem % 2 == 0)
        self.assertEqual(result, [1, 3])

    def test_dict_map(self):
        # Dict-backed enumerables hand (key, value) tuples to the callback.
        result = enum({'a': 1, 'b': 2}).map(lambda pair: pair[0] * pair[1])
        self.assertEqual(result, ['a', 'bb'])

    def test_dict_select(self):
        result = enum({'a': 1, 'b': 2, 'c': 3}).select(lambda pair: pair[1] % 2 == 1)
        self.assertEqual(result, [('a', 1), ('c', 3)])
# Script entry point: run the test suite only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    # Discovers and runs the TestCase classes defined in this module.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment