Skip to content

Instantly share code, notes, and snippets.

@dkubb
Last active August 5, 2016 18:47
Show Gist options
  • Save dkubb/874e9d1dc9631c8c16f9022ef0779e7e to your computer and use it in GitHub Desktop.
Save dkubb/874e9d1dc9631c8c16f9022ef0779e7e to your computer and use it in GitHub Desktop.
# Helpers for asserting how many entries a collection contains.
#
# The methods rely on the receiver responding to +take+, +to_a+ and +count+.
module EnumerableExtensions
  # An exception raised when an invalid number of entries is returned
  class InvalidCountError < StandardError
    # Initialize an exception to report an invalid enumerable count
    #
    # @param expectation [#to_s]
    #   description of the expected number of entries (text or a number)
    # @param entries [#count]
    #   the collection whose size is reported in the message
    #
    # @return [undefined]
    #
    # @api public
    def initialize(expectation, entries)
      super(
        format(
          'Found %{count}, expected %{expectation}',
          count:       entries.count,
          expectation: expectation
        )
      )
    end
  end # InvalidCountError

  # Object to represent undefined arguments
  #
  # A dedicated sentinel lets callers pass +nil+ or +false+ as a real default.
  Undefined = Object.new.freeze

  # Return exactly one entry from the enumerable
  #
  # @param default [Object]
  #   value returned when the number of entries is not exactly one
  #
  # @yield [count, entries]
  #
  # @yieldparam [Integer] count
  #   the expected number of entries (always 1)
  # @yieldparam [Enumerable] entries
  #   the receiver
  #
  # @yieldreturn [Object]
  #   return the default from the block, if provided
  #
  # @return [Object]
  #   the single entry, or the default / block value otherwise
  #
  # @raise [InvalidCountError]
  #   raised if zero or more than one entry is found and there is no default
  # @raise [ArgumentError]
  #   raised if both a default and a block are provided
  #
  # @api public
  def one(default = Undefined)
    # Wrap the block's value in an array so every path below returns
    # something that responds to fetch(0)
    block = ->(*block_args) { [yield(*block_args)] } if block_given?

    result =
      if block || default.equal?(Undefined)
        exactly(1, default, &block)
      else
        # Wrap the default for the same reason as the block value
        exactly(1, [default])
      end

    result.fetch(0)
  end

  # Return one or more entries from the enumerable
  #
  # @return [Array]
  #   returned if one or more entries are found
  #
  # @raise [InvalidCountError]
  #   raised if zero entries are found
  #
  # @api public
  def min_one
    entries = to_a
    # empty? rather than none?: none? inspects truthiness, so collections
    # holding only nil/false entries would be reported as having no entries
    fail InvalidCountError.new('one or more', entries) if entries.empty?
    entries
  end

  # Return zero or one entry from the enumerable
  #
  # @return [Object, nil]
  #   the single entry, or nil if there are no entries
  #
  # @raise [InvalidCountError]
  #   raised if more than one entry is found
  #
  # @api public
  def max_one
    # Two entries are enough to know there are too many
    entries = take(2).to_a
    # Plain size comparison keeps this usable without ActiveSupport's many?
    fail InvalidCountError.new('zero or one', entries) if entries.size > 1
    entries.first
  end

  # Return an exact number of entries from the enumerable
  #
  # @param count [Integer]
  #   the required number of entries
  # @param default [Object]
  #   value returned when the number of entries differs from count
  #
  # @yield [count, entries]
  #
  # @yieldparam [Integer] count
  #   the required number of entries
  # @yieldparam [Enumerable] entries
  #   the receiver
  #
  # @yieldreturn [Object]
  #   return the default from the block, if provided
  #
  # @return [Array, Object]
  #   the entries if exactly count entries are found, otherwise the default
  #   or the value of the block
  #
  # @raise [InvalidCountError]
  #   raised if an invalid number of entries is found and neither a default
  #   nor a block is provided
  # @raise [ArgumentError]
  #   raised if both a default and a block are provided
  #
  # @api public
  def exactly(count, default = Undefined, &block)
    assert_default_or_block(default, &block)

    # One entry more than required is enough to detect a surplus
    entries = take(count.succ).to_a
    return entries if entries.size == count
    return default unless default.equal?(Undefined)
    return block.call(count, self) if block

    # The receiver (not the truncated entries) is passed so the message
    # reports the receiver's own count
    fail InvalidCountError.new(count, self)
  end

  private

  # Assert that a block and default argument cannot be provided together
  #
  # @param default [Object]
  #
  # @return [undefined]
  #
  # @raise [ArgumentError]
  #   raised if a block and default value are provided
  #
  # @api private
  def assert_default_or_block(default)
    return unless block_given? && !default.equal?(Undefined)
    fail ArgumentError, 'Must pass in a block or a default argument, not both'
  end
end # EnumerableExtensions
# Make the helpers callable directly on ActiveRecord::Base and its subclasses
ActiveRecord::Base.extend(EnumerableExtensions)

# Mix the helpers into every Array instance
class Array
  include EnumerableExtensions
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment