Skip to content

Instantly share code, notes, and snippets.

@0x0dea
Created June 17, 2015 12:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save 0x0dea/fe313cd4b0518cd190db to your computer and use it in GitHub Desktop.
Save 0x0dea/fe313cd4b0518cd190db to your computer and use it in GitHub Desktop.
Add case equality checks to Enumerable#any?/all?/none?/one?.
diff --git a/enum.c b/enum.c
index 4b1e119..c717540 100644
--- a/enum.c
+++ b/enum.c
@@ -1043,7 +1043,7 @@ enum_sort_by(VALUE obj)
return ary;
}
-#define ENUMFUNC(name) rb_block_given_p() ? name##_iter_i : name##_i
+#define ENUMFUNC(name, argc) ((argc) ? name##_eqq : rb_block_given_p() ? name##_iter_i : name##_i) /* an argument takes precedence over a block */
#define DEFINE_ENUMFUNCS(name) \
static VALUE enum_##name##_func(VALUE result, struct MEMO *memo); \
@@ -1061,6 +1061,12 @@ name##_iter_i(RB_BLOCK_CALL_FUNC_ARGLIST(i, memo)) \
} \
\
static VALUE \
+name##_eqq(RB_BLOCK_CALL_FUNC_ARGLIST(i, memo)) /* each-callback selected by ENUMFUNC when argc is nonzero */ \
+{ /* sends id_eqq to memo->v2 with i as the argument; NOTE(review): only i is tested - confirm handling of multi-value yields */ \
+ return enum_##name##_func(rb_funcall(MEMO_CAST(memo)->v2, id_eqq, 1, i), MEMO_CAST(memo)); \
+} \
+\
+static VALUE \
enum_##name##_func(VALUE result, struct MEMO *memo)
DEFINE_ENUMFUNCS(all)
@@ -1090,10 +1096,13 @@ DEFINE_ENUMFUNCS(all)
*/
static VALUE
-enum_all(VALUE obj)
+enum_all(int argc, VALUE *argv, VALUE obj)
{
- struct MEMO *memo = MEMO_NEW(Qtrue, 0, 0);
- rb_block_call(obj, id_each, 0, 0, ENUMFUNC(all), (VALUE)memo);
+ struct MEMO *memo = MEMO_NEW(Qtrue, (argc > 0 ? *argv : 0), 0); /* argv[0] may only be read when an argument was passed */
+
+ rb_check_arity(argc, 0, 1); /* zero or one argument */
+ rb_block_call(obj, id_each, 0, 0, ENUMFUNC(all, argc), (VALUE)memo);
+
return memo->v1;
}
@@ -1124,10 +1133,13 @@ DEFINE_ENUMFUNCS(any)
*/
static VALUE
-enum_any(VALUE obj)
+enum_any(int argc, VALUE *argv, VALUE obj)
{
- struct MEMO *memo = MEMO_NEW(Qfalse, 0, 0);
- rb_block_call(obj, id_each, 0, 0, ENUMFUNC(any), (VALUE)memo);
+ struct MEMO *memo = MEMO_NEW(Qfalse, (argc > 0 ? *argv : 0), 0); /* argv[0] may only be read when an argument was passed */
+
+ rb_check_arity(argc, 0, 1); /* zero or one argument */
+ rb_block_call(obj, id_each, 0, 0, ENUMFUNC(any, argc), (VALUE)memo);
+
return memo->v1;
}
@@ -1367,12 +1379,14 @@ nmin_run(VALUE obj, VALUE num, int by, int rev)
*
*/
static VALUE
-enum_one(VALUE obj)
+enum_one(int argc, VALUE *argv, VALUE obj)
{
- struct MEMO *memo = MEMO_NEW(Qundef, 0, 0);
+ struct MEMO *memo = MEMO_NEW(Qundef, (argc > 0 ? *argv : 0), 0); /* argv[0] may only be read when an argument was passed */
VALUE result;
- rb_block_call(obj, id_each, 0, 0, ENUMFUNC(one), (VALUE)memo);
+ rb_check_arity(argc, 0, 1); /* zero or one argument */
+ rb_block_call(obj, id_each, 0, 0, ENUMFUNC(one, argc), (VALUE)memo);
+
result = memo->v1;
if (result == Qundef) return Qfalse;
return result;
@@ -1403,10 +1417,13 @@ DEFINE_ENUMFUNCS(none)
* [nil, false].none? #=> true
*/
static VALUE
-enum_none(VALUE obj)
+enum_none(int argc, VALUE *argv, VALUE obj)
{
- struct MEMO *memo = MEMO_NEW(Qtrue, 0, 0);
- rb_block_call(obj, id_each, 0, 0, ENUMFUNC(none), (VALUE)memo);
+ struct MEMO *memo = MEMO_NEW(Qtrue, (argc > 0 ? *argv : 0), 0); /* argv[0] may only be read when an argument was passed */
+
+ rb_check_arity(argc, 0, 1); /* zero or one argument */
+ rb_block_call(obj, id_each, 0, 0, ENUMFUNC(none, argc), (VALUE)memo);
+
return memo->v1;
}
@@ -3501,10 +3518,10 @@ Init_Enumerable(void)
rb_define_method(rb_mEnumerable, "partition", enum_partition, 0);
rb_define_method(rb_mEnumerable, "group_by", enum_group_by, 0);
rb_define_method(rb_mEnumerable, "first", enum_first, -1);
- rb_define_method(rb_mEnumerable, "all?", enum_all, 0);
- rb_define_method(rb_mEnumerable, "any?", enum_any, 0);
- rb_define_method(rb_mEnumerable, "one?", enum_one, 0);
- rb_define_method(rb_mEnumerable, "none?", enum_none, 0);
+ rb_define_method(rb_mEnumerable, "all?", enum_all, -1); /* -1 on these four: the functions now take (int argc, VALUE *argv, VALUE obj) */
+ rb_define_method(rb_mEnumerable, "any?", enum_any, -1);
+ rb_define_method(rb_mEnumerable, "one?", enum_one, -1);
+ rb_define_method(rb_mEnumerable, "none?", enum_none, -1);
rb_define_method(rb_mEnumerable, "min", enum_min, -1);
rb_define_method(rb_mEnumerable, "max", enum_max, -1);
rb_define_method(rb_mEnumerable, "minmax", enum_minmax, 0);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment