ruby-changes:53733
From: knu <ko1@a...>
Date: Sat, 24 Nov 2018 17:38:40 +0900 (JST)
Subject: [ruby-changes:53733] knu:r65949 (trunk): Implement Enumerator#+ and Enumerable#chain [Feature #15144]
knu 2018-11-24 17:38:35 +0900 (Sat, 24 Nov 2018) New Revision: 65949 https://svn.ruby-lang.org/cgi-bin/viewvc.cgi?view=revision&revision=65949 Log: Implement Enumerator#+ and Enumerable#chain [Feature #15144] They return an Enumerator::Chain object which is a subclass of Enumerator, which represents a chain of enumerables that works as a single enumerator. ```ruby e = (1..3).chain([4, 5]) e.to_a #=> [1, 2, 3, 4, 5] e = (1..3).each + [4, 5] e.to_a #=> [1, 2, 3, 4, 5] ``` Modified files: trunk/enumerator.c trunk/test/ruby/test_enumerator.rb Index: enumerator.c =================================================================== --- enumerator.c (revision 65948) +++ enumerator.c (revision 65949) @@ -12,6 +12,7 @@ https://github.com/ruby/ruby/blob/trunk/enumerator.c#L12 ************************************************/ +#include "ruby/ruby.h" #include "internal.h" #include "id.h" @@ -161,6 +162,13 @@ struct proc_entry { https://github.com/ruby/ruby/blob/trunk/enumerator.c#L162 static VALUE generator_allocate(VALUE klass); static VALUE generator_init(VALUE obj, VALUE proc); +static VALUE rb_cEnumChain; + +struct enum_chain { + VALUE enums; /* frozen Array of the chained enumerables; Qundef until the chain is initialized */ + long pos; /* index of the enumerable #each reached last; -1 before iteration and once #rewind has walked back */ +}; + static VALUE rb_cArithSeq; /* @@ -2412,6 +2420,300 @@ stop_result(VALUE self) https://github.com/ruby/ruby/blob/trunk/enumerator.c#L2420 } /* + * Document-class: Enumerator::Chain + * + * Enumerator::Chain is a subclass of Enumerator representing a + * chain of enumerables that works as a single enumerator. + * + * Objects of this type are created by Enumerable#chain and + * Enumerator#+.
+ */
+
+/* GC mark callback: keeps the Array of chained enumerables alive. */
+static void
+enum_chain_mark(void *p)
+{
+    const struct enum_chain *chain = p;
+
+    rb_gc_mark(chain->enums);
+}
+
+#define enum_chain_free RUBY_TYPED_DEFAULT_FREE
+
+/* dsize callback: the wrapped struct has a fixed size. */
+static size_t
+enum_chain_memsize(const void *p)
+{
+    return sizeof(struct enum_chain);
+}
+
+static const rb_data_type_t enum_chain_data_type = {
+    "chain",                     /* wrap_struct_name */
+    {
+        enum_chain_mark,         /* dmark */
+        enum_chain_free,         /* dfree */
+        enum_chain_memsize,      /* dsize */
+    },
+    0, 0,                        /* parent, data */
+    RUBY_TYPED_FREE_IMMEDIATELY  /* flags */
+};
+
+/*
+ * Returns the chain struct wrapped by obj.  Raises ArgumentError when
+ * nothing is wrapped or the chain has not been initialized yet.
+ */
+static struct enum_chain *
+enum_chain_ptr(VALUE obj)
+{
+    struct enum_chain *chain;
+
+    TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, chain);
+    if (chain == NULL || chain->enums == Qundef) {
+        rb_raise(rb_eArgError, "uninitialized chain");
+    }
+    return chain;
+}
+
+/* :nodoc: */
+static VALUE
+enum_chain_allocate(VALUE klass)
+{
+    struct enum_chain *chain;
+    VALUE wrapper = TypedData_Make_Struct(klass, struct enum_chain, &enum_chain_data_type, chain);
+
+    /* Qundef flags the chain as not yet initialized (see enum_chain_ptr). */
+    chain->enums = Qundef;
+    chain->pos = -1;
+    return wrapper;
+}
+
+/*
+ * call-seq:
+ *   Enumerator::Chain.new(*enums) -> enum
+ *
+ * Creates a new enumerator that iterates over the elements of the
+ * given enumerable objects, one after another.
+ *
+ *   e = Enumerator::Chain.new(1..3, [4, 5])
+ *   e.to_a #=> [1, 2, 3, 4, 5]
+ *   e.size #=> 5
+ */
+static VALUE
+enum_chain_initialize(VALUE obj, VALUE enums)
+{
+    struct enum_chain *ptr;
+
+    rb_check_frozen(obj);
+    TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr);
+
+    if (!ptr) rb_raise(rb_eArgError, "unallocated chain");
+
+    /* Freeze the argument Array so the chained list cannot change later. */
+    ptr->enums = rb_obj_freeze(enums);
+    ptr->pos = -1;
+
+    return obj;
+}
+
+/* :nodoc: */
+static VALUE
+enum_chain_init_copy(VALUE obj, VALUE orig)
+{
+    struct enum_chain *ptr0, *ptr1;
+
+    if (!OBJ_INIT_COPY(obj, orig)) return obj;
+    ptr0 = enum_chain_ptr(orig);
+
+    TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr1);
+
+    if (!ptr1) rb_raise(rb_eArgError, "unallocated chain");
+
+    /* The enums Array is frozen, so the copy can share it. */
+    ptr1->enums = ptr0->enums;
+    ptr1->pos = ptr0->pos;
+
+    return obj;
+}
+
+/*
+ * Sums the sizes of the enumerables in the frozen Array enums.
+ * Returns the first size that is nil or an infinite Float as-is,
+ * and nil as soon as any other non-Integer size is seen.
+ *
+ * Elements are fetched with RARRAY_AREF() rather than through
+ * RARRAY_PTR_USE(): the loop calls back into Ruby (enum_size() on each
+ * element and rb_funcall() for '+'), and returning from inside a
+ * RARRAY_PTR_USE() region would skip its end marker.
+ */
+static VALUE
+enum_chain_total_size(VALUE enums)
+{
+    VALUE total = INT2FIX(0);
+    long i;
+
+    for (i = 0; i < RARRAY_LEN(enums); i++) {
+        VALUE size = enum_size(RARRAY_AREF(enums, i));
+
+        if (NIL_P(size) || (RB_TYPE_P(size, T_FLOAT) && isinf(NUM2DBL(size)))) {
+            return size;
+        }
+        if (!RB_INTEGER_TYPE_P(size)) {
+            return Qnil;
+        }
+
+        total = rb_funcall(total, '+', 1, size);
+    }
+
+    return total;
+}
+
+/*
+ * call-seq:
+ *   obj.size -> integer, Float::INFINITY or nil
+ *
+ * Returns the total size of the enumerator chain calculated by
+ * summing up the size of each enumerable in the chain. If any of the
+ * enumerables reports its size as nil or Float::INFINITY, that value
+ * is returned as the total size; any other non-Integer size makes the
+ * total nil.
+ */
+static VALUE
+enum_chain_size(VALUE obj)
+{
+    return enum_chain_total_size(enum_chain_ptr(obj)->enums);
+}
+
+/* Size function for the enumerator returned by a block-less #each. */
+static VALUE
+enum_chain_enum_size(VALUE obj, VALUE args, VALUE eobj)
+{
+    return enum_chain_size(obj);
+}
+
+/*
+ * rb_block_call() callback: forwards whatever one chained enumerable
+ * yields to the block Proc passed as the callback argument.  Declared
+ * with RB_BLOCK_CALL_FUNC_ARGLIST() so its type matches the callback
+ * type rb_block_call() invokes it through.
+ */
+static VALUE
+enum_chain_yield_block(RB_BLOCK_CALL_FUNC_ARGLIST(arg, block))
+{
+    return rb_funcallv(block, rb_intern("call"), argc, argv);
+}
+
+/* Size function used when #each got arguments: no size is reported. */
+static VALUE
+enum_chain_enum_no_size(VALUE obj, VALUE args, VALUE eobj)
+{
+    return Qnil;
+}
+
+/*
+ * call-seq:
+ *   obj.each(*args) { |...| ... } -> obj
+ *   obj.each(*args) -> enumerator
+ *
+ * Iterates over the elements of the first enumerable by calling the
+ * "each" method on it with the given arguments, then proceeds to the
+ * following enumerables in sequence until all of the enumerables are
+ * exhausted.
+ *
+ * If no block is given, returns an enumerator.
+ */
+static VALUE
+enum_chain_each(int argc, VALUE *argv, VALUE obj)
+{
+    VALUE enums, block;
+    struct enum_chain *objptr;
+    long i;
+
+    RETURN_SIZED_ENUMERATOR(obj, argc, argv, argc > 0 ? enum_chain_enum_no_size : enum_chain_enum_size);
+
+    objptr = enum_chain_ptr(obj);
+    enums = objptr->enums;
+    block = rb_block_proc();
+
+    /*
+     * Fetch each element with RARRAY_AREF() instead of holding a raw
+     * pointer from RARRAY_PTR_USE(): the block runs arbitrary Ruby code
+     * and may leave this loop non-locally (break, e.g. from #take, or an
+     * exception), which would skip the end of a RARRAY_PTR_USE() region.
+     */
+    for (i = 0; i < RARRAY_LEN(enums); i++) {
+        objptr->pos = i;  /* remembered so #rewind knows how far we got */
+        rb_block_call(RARRAY_AREF(enums, i), id_each, argc, argv, enum_chain_yield_block, block);
+    }
+
+    return obj;
+}
+
+/*
+ * call-seq:
+ *   obj.rewind -> obj
+ *
+ * Rewinds the enumerator chain by calling the "rewind" method on each
+ * enumerable that #each has reached so far, in reverse order. Each
+ * call is performed only if the enumerable responds to the method.
+ */
+static VALUE
+enum_chain_rewind(VALUE obj)
+{
+    struct enum_chain *objptr = enum_chain_ptr(obj);
+    VALUE enums = objptr->enums;
+    long i;
+
+    /*
+     * Walk back from the last position reached by #each, updating pos
+     * as we go so it ends at -1.  Elements are fetched with RARRAY_AREF()
+     * because #rewind calls back into Ruby, which must not happen inside
+     * a RARRAY_PTR_USE() region.
+     */
+    for (i = objptr->pos; 0 <= i && i < RARRAY_LEN(enums); objptr->pos = --i) {
+        rb_check_funcall(RARRAY_AREF(enums, i), id_rewind, 0, 0);
+    }
+
+    return obj;
+}
+
+/* rb_exec_recursive() worker for #inspect; recur guards against cycles. */
+static VALUE
+inspect_enum_chain(VALUE obj, VALUE dummy, int recur)
+{
+    VALUE klass = rb_obj_class(obj);
+    struct enum_chain *ptr;
+
+    TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr);
+
+    if (!ptr || ptr->enums == Qundef) {
+        return rb_sprintf("#<%"PRIsVALUE": uninitialized>", rb_class_path(klass));
+    }
+
+    if (recur) {
+        return rb_sprintf("#<%"PRIsVALUE": ...>", rb_class_path(klass));
+    }
+
+    return rb_sprintf("#<%"PRIsVALUE": %+"PRIsVALUE">", rb_class_path(klass), ptr->enums);
+}
+
+/*
+ * call-seq:
+ *   obj.inspect -> string
+ *
+ * Returns a printable version of the enumerator chain.
+ */
+static VALUE
+enum_chain_inspect(VALUE obj)
+{
+    return rb_exec_recursive(inspect_enum_chain, obj, 0);
+}
+
+/*
+ * call-seq:
+ *   e.chain(*enums) -> enumerator
+ *
+ * Returns an enumerator object generated from this enumerable and
+ * the given enumerables.
+ *
+ *   e = (1..3).chain([4, 5])
+ *   e.to_a #=> [1, 2, 3, 4, 5]
+ */
+static VALUE
+enum_chain(int argc, VALUE *argv, VALUE obj)
+{
+    VALUE enums = rb_ary_new_from_values(1, &obj);
+    rb_ary_cat(enums, argv, argc);
+
+    return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums);
+}
+
+/*
+ * call-seq:
+ *   e + enum -> enumerator
+ *
+ * Returns an enumerator object generated from this enumerator and a
+ * given enumerable.
+ * + * e = (1..3).each + [4, 5] + * e.to_a #=> [1, 2, 3, 4, 5] + */ +static VALUE +enumerator_plus(VALUE obj, VALUE eobj) +{ + VALUE enums = rb_ary_new_from_args(2, obj, eobj); /* obj stays a single element even if it is itself a Chain (no flattening), so iterating (a + b) + c still calls (a + b).each */ + + return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums); +} + +/* * Document-class: Enumerator::ArithmeticSequence * * Enumerator::ArithmeticSequence is a subclass of Enumerator, @@ -2907,6 +3209,8 @@ InitVM_Enumerator(void) https://github.com/ruby/ruby/blob/trunk/enumerator.c#L3209 rb_define_method(rb_cEnumerator, "rewind", enumerator_rewind, 0); rb_define_method(rb_cEnumerator, "inspect", enumerator_inspect, 0); rb_define_method(rb_cEnumerator, "size", enumerator_size, 0); + rb_define_method(rb_cEnumerator, "+", enumerator_plus, 1); + rb_define_method(rb_mEnumerable, "chain", enum_chain, -1); /* Lazy */ rb_cLazy = rb_define_class_under(rb_cEnumerator, "Lazy", rb_cEnumerator); @@ -2960,6 +3264,16 @@ InitVM_Enumerator(void) https://github.com/ruby/ruby/blob/trunk/enumerator.c#L3264 rb_define_method(rb_cYielder, "yield", yielder_yield, -2); rb_define_method(rb_cYielder, "<<", yielder_yield_push, 1); + /* Chain */ + rb_cEnumChain = rb_define_class_under(rb_cEnumerator, "Chain", rb_cEnumerator); + rb_define_alloc_func(rb_cEnumChain, enum_chain_allocate); + rb_define_method(rb_cEnumChain, "initialize", enum_chain_initialize, -2); /* arity -2: all arguments arrive as one Ruby Array (enums) */ + rb_define_method(rb_cEnumChain, "initialize_copy", enum_chain_init_copy, 1); + rb_define_method(rb_cEnumChain, "each", enum_chain_each, -1); + rb_define_method(rb_cEnumChain, "size", enum_chain_size, 0); + rb_define_method(rb_cEnumChain, "rewind", enum_chain_rewind, 0); + rb_define_method(rb_cEnumChain, "inspect", enum_chain_inspect, 0); + /* ArithmeticSequence */ rb_cArithSeq = rb_define_class_under(rb_cEnumerator, "ArithmeticSequence", rb_cEnumerator); rb_undef_alloc_func(rb_cArithSeq); Index: test/ruby/test_enumerator.rb =================================================================== --- test/ruby/test_enumerator.rb (revision 65948) +++ 
test/ruby/test_enumerator.rb (revision 65949)
@@ -670,5 +670,119 @@ class TestEnumerator < Test::Unit::TestC https://github.com/ruby/ruby/blob/trunk/test/ruby/test_enumerator.rb#L670
     assert_equal([0, 1], u.force)
     assert_equal([0, 1], u.force)
   end
-end
+  def test_enum_chain_and_plus
+    r = 1..5
+
+    e1 = r.chain()
+    assert_kind_of(Enumerator::Chain, e1)
+    assert_equal(5, e1.size)
+    ary = []
+    e1.each { |x| ary << x }
+    assert_equal([1, 2, 3, 4, 5], ary)
+
+    e2 = r.chain([6, 7, 8])
+    assert_kind_of(Enumerator::Chain, e2)
+    assert_equal(8, e2.size)
+    ary = []
+    e2.each { |x| ary << x }
+    assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary)
+
+    e3 = r.chain([6, 7], 8.step)
+    assert_kind_of(Enumerator::Chain, e3)
+    assert_equal(Float::INFINITY, e3.size)
+    ary = []
+    e3.take(10).each { |x| ary << x }
+    assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ary)
+
+    # `a + b + c` should not return `Enumerator::Chain.new(a, b, c)`
+    # because it is expected that `(a + b).each` be called.
+    e4 = e2.dup
+    class << e4
+      attr_reader :each_is_called
+      def each
+        super
+        @each_is_called = true
+      end
+    end
+    e5 = e4 + 9.step
+    assert_kind_of(Enumerator::Chain, e5)
+    assert_equal(Float::INFINITY, e5.size)
+    ary = []
+    e5.take(10).each { |x| ary << x }
+    assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ary)
+    assert_equal(true, e4.each_is_called)
+  end
+
+  def test_chained_enums
+    a = (1..5).each
+
+    e0 = Enumerator::Chain.new()
+    assert_kind_of(Enumerator::Chain, e0)
+    assert_equal(0, e0.size)
+    ary = []
+    e0.each { |x| ary << x }
+    assert_equal([], ary)
+
+    e1 = Enumerator::Chain.new(a)
+    assert_kind_of(Enumerator::Chain, e1)
+    assert_equal(5, e1.size)
+    ary = []
+    e1.each { |x| ary << x }
+    assert_equal([1, 2, 3, 4, 5], ary)
+
+    e2 = Enumerator::Chain.new(a, [6, 7, 8])
+    assert_kind_of(Enumerator::Chain, e2)
+    assert_equal(8, e2.size)
+    ary = []
+    e2.each { |x| ary << x }
+    assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary)
+
+    e3 = Enumerator::Chain.new(a, [6, 7], 8.step)
+    assert_kind_of(Enumerator::Chain, e3)
+    assert_equal(Float::INFINITY, e3.size)
+    ary = []
+    e3.take(10).each { |x| ary << x }
+    assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ary)
+
+    e4 = Enumerator::Chain.new(a, Enumerator.new { |y| y << 6 << 7 << 8 })
+    assert_kind_of(Enumerator::Chain, e4)
+    assert_equal(nil, e4.size)
+    ary = []
+    e4.each { |x| ary << x }
+    assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary)
+
+    e5 = Enumerator::Chain.new(e1, e2)
+    assert_kind_of(Enumerator::Chain, e5)
+    assert_equal(13, e5.size)
+    ary = []
+    e5.each { |x| ary << x }
+    assert_equal([1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8], ary)
+
+    # e5 has been fully iterated, so #rewind walks back over both
+    # chained enumerables, last one first.
+    # assert_equal takes (expected, actual).
+    rewound = []
+    e1.define_singleton_method(:rewind) { rewound << object_id }
+    e2.define_singleton_method(:rewind) { rewound << object_id }
+    e5.rewind
+    assert_equal([e2.object_id, e1.object_id], rewound)
+
+    # A chain that #each has never run on has nothing to rewind.
+    rewound = []
+    a = [1]
+    e6 = Enumerator::Chain.new(a)
+    a.define_singleton_method(:rewind) { rewound << object_id }
+    e6.rewind
+    assert_equal([], rewound)
+
+    assert_equal(
+      '#<Enumerator::Chain: [' +
+        '#<Enumerator::Chain: [' +
+          '#<Enumerator: 1..5:each>' +
+        ']>, ' +
+        '#<Enumerator::Chain: [' +
+          '#<Enumerator: 1..5:each>, ' +
+          '[6, 7, 8]' +
+        ']>' +
+      ']>',
+      e5.inspect
+    )
+  end
+end
--
ML: ruby-changes@q...
Info: http://www.atdot.net/~ko1/quickml/