[前][次][番号順一覧][スレッド一覧]

ruby-changes:53733

From: knu <ko1@a...>
Date: Sat, 24 Nov 2018 17:38:40 +0900 (JST)
Subject: [ruby-changes:53733] knu:r65949 (trunk): Implement Enumerator#+ and Enumerable#chain [Feature #15144]

knu	2018-11-24 17:38:35 +0900 (Sat, 24 Nov 2018)

  New Revision: 65949

  https://svn.ruby-lang.org/cgi-bin/viewvc.cgi?view=revision&revision=65949

  Log:
    Implement Enumerator#+ and Enumerable#chain [Feature #15144]
    
    They return an Enumerator::Chain object.  Enumerator::Chain is a
    subclass of Enumerator that represents a chain of enumerables working
    as a single enumerator.
    
    ```ruby
    e = (1..3).chain([4, 5])
    e.to_a #=> [1, 2, 3, 4, 5]
    
    e = (1..3).each + [4, 5]
    e.to_a #=> [1, 2, 3, 4, 5]
    ```

  Modified files:
    trunk/enumerator.c
    trunk/test/ruby/test_enumerator.rb
Index: enumerator.c
===================================================================
--- enumerator.c	(revision 65948)
+++ enumerator.c	(revision 65949)
@@ -12,6 +12,7 @@ https://github.com/ruby/ruby/blob/trunk/enumerator.c#L12
 
 ************************************************/
 
+#include "ruby/ruby.h"
 #include "internal.h"
 #include "id.h"
 
@@ -161,6 +162,13 @@ struct proc_entry { https://github.com/ruby/ruby/blob/trunk/enumerator.c#L162
 static VALUE generator_allocate(VALUE klass);
 static VALUE generator_init(VALUE obj, VALUE proc);
 
+/* The Enumerator::Chain class object, created in InitVM_Enumerator. */
+static VALUE rb_cEnumChain;
+
+/* Per-instance state of an Enumerator::Chain. */
+struct enum_chain {
+    VALUE enums;  /* frozen Array of chained enumerables; Qundef until #initialize runs */
+    long pos;     /* index of the member #each most recently reached; -1 if none */
+};
+
 static VALUE rb_cArithSeq;
 
 /*
@@ -2412,6 +2420,300 @@ stop_result(VALUE self) https://github.com/ruby/ruby/blob/trunk/enumerator.c#L2420
 }
 
 /*
+ * Document-class: Enumerator::Chain
+ *
+ * Enumerator::Chain is a subclass of Enumerator that represents a
+ * chain of enumerables working together as a single enumerator.
+ *
+ * Objects of this type can be created with Enumerable#chain,
+ * Enumerator#+, or Enumerator::Chain.new.
+ */
+
+/* GC mark function: keeps the array of chained enumerables alive. */
+static void
+enum_chain_mark(void *p)
+{
+    const struct enum_chain *chain = (const struct enum_chain *)p;
+
+    rb_gc_mark(chain->enums);
+}
+
+/*
+ * Free function for enum_chain_data_type.  The struct holds only a VALUE
+ * and a long (no separately allocated buffers), so the default
+ * deallocator is sufficient.
+ */
+#define enum_chain_free RUBY_TYPED_DEFAULT_FREE
+
+/* Size function for enum_chain_data_type: the struct itself is all there is. */
+static size_t
+enum_chain_memsize(const void *p)
+{
+    return sizeof(struct enum_chain);
+}
+
+/*
+ * TypedData descriptor shared by every Enumerator::Chain instance.  The
+ * callbacks are given positionally (mark, free, size), so their order
+ * here must match the rb_data_type_t declaration.
+ */
+static const rb_data_type_t enum_chain_data_type = {
+    "chain",
+    {
+        enum_chain_mark,
+        enum_chain_free,
+        enum_chain_memsize,
+    },
+    /* NOTE(review): presumably parent type, user data, flags -- confirm against rb_data_type_t */
+    0, 0, RUBY_TYPED_FREE_IMMEDIATELY
+};
+
+/*
+ * Fetches the enum_chain struct wrapped by +obj+, raising ArgumentError
+ * when the chain has not been initialized yet.
+ */
+static struct enum_chain *
+enum_chain_ptr(VALUE obj)
+{
+    struct enum_chain *chain;
+
+    TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, chain);
+    if (chain == NULL || chain->enums == Qundef) {
+        rb_raise(rb_eArgError, "uninitialized chain");
+    }
+
+    return chain;
+}
+
+/* :nodoc: */
+static VALUE
+enum_chain_allocate(VALUE klass)
+{
+    struct enum_chain *chain;
+    VALUE self = TypedData_Make_Struct(klass, struct enum_chain,
+                                       &enum_chain_data_type, chain);
+
+    /* Qundef marks the instance as "not yet initialized"; -1 means nothing iterated. */
+    chain->enums = Qundef;
+    chain->pos = -1;
+
+    return self;
+}
+
+/*
+ * call-seq:
+ *   Enumerator::Chain.new(*enums) -> enum
+ *
+ * Creates an enumerator that walks through the elements of each given
+ * enumerable object, one after another.
+ *
+ *   e = Enumerator::Chain.new(1..3, [4, 5])
+ *   e.to_a #=> [1, 2, 3, 4, 5]
+ *   e.size #=> 5
+ */
+static VALUE
+enum_chain_initialize(VALUE obj, VALUE enums)
+{
+    struct enum_chain *chain;
+
+    rb_check_frozen(obj);
+    TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, chain);
+    if (chain == NULL) {
+        rb_raise(rb_eArgError, "unallocated chain");
+    }
+
+    /* The member list is frozen, so it cannot be altered after construction. */
+    chain->enums = rb_obj_freeze(enums);
+    /* Nothing has been iterated yet, so #rewind has nothing to undo. */
+    chain->pos = -1;
+
+    return obj;
+}
+
+/* :nodoc: */
+static VALUE
+enum_chain_init_copy(VALUE obj, VALUE orig)
+{
+    struct enum_chain *src, *dst;
+
+    if (!OBJ_INIT_COPY(obj, orig)) {
+        return obj;
+    }
+    src = enum_chain_ptr(orig);   /* raises if orig is uninitialized */
+
+    TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, dst);
+    if (dst == NULL) {
+        rb_raise(rb_eArgError, "unallocated chain");
+    }
+
+    /* Share the frozen member list and carry over the iteration position. */
+    dst->enums = src->enums;
+    dst->pos = src->pos;
+
+    return obj;
+}
+
+/*
+ * Sums the sizes reported by each enumerable in +enums+.
+ *
+ * Returns early with the reported value when a member's size is nil or
+ * an infinite Float, returns nil when a member reports any other
+ * non-Integer size, and otherwise returns the Integer total.
+ */
+static VALUE
+enum_chain_total_size(VALUE enums)
+{
+    VALUE total = INT2FIX(0);
+    long i;
+
+    /*
+     * enum_size() presumably dispatches to the member's #size, which can be
+     * arbitrary Ruby code.  Do not hold a raw RARRAY_PTR_USE pointer across
+     * that call (nor return from inside the macro); fetch each element with
+     * RARRAY_AREF instead.
+     */
+    for (i = 0; i < RARRAY_LEN(enums); i++) {
+        VALUE size = enum_size(RARRAY_AREF(enums, i));
+
+        if (NIL_P(size) || (RB_TYPE_P(size, T_FLOAT) && isinf(NUM2DBL(size)))) {
+            return size;
+        }
+        if (!RB_INTEGER_TYPE_P(size)) {
+            return Qnil;
+        }
+
+        /* Integer#+ handles Fixnum overflow into Bignum. */
+        total = rb_funcall(total, '+', 1, size);
+    }
+
+    return total;
+}
+
+/*
+ * call-seq:
+ *   obj.size -> integer, Float::INFINITY or nil
+ *
+ * Returns the total size of the enumerator chain, i.e. the sum of the
+ * sizes of all chained enumerables.  If any enumerable reports its size
+ * as nil or Float::INFINITY, that value is returned instead; any other
+ * non-Integer size makes the result nil.
+ */
+static VALUE
+enum_chain_size(VALUE obj)
+{
+    struct enum_chain *chain = enum_chain_ptr(obj);
+
+    return enum_chain_total_size(chain->enums);
+}
+
+/*
+ * Size function for the enumerator returned by a block-less #each called
+ * without arguments; +args+ and +eobj+ are not needed.
+ */
+static VALUE
+enum_chain_enum_size(VALUE obj, VALUE args, VALUE eobj)
+{
+    VALUE total = enum_chain_size(obj);
+
+    return total;
+}
+
+/*
+ * rb_block_call callback used by #each: forwards the yielded values to
+ * the caller's block, which is passed in as a Proc via +block+.
+ */
+static VALUE
+enum_chain_yield_block(VALUE arg, VALUE block, int argc, VALUE *argv)
+{
+    ID mid = rb_intern("call");
+
+    return rb_funcallv(block, mid, argc, argv);
+}
+
+/*
+ * Size function for the enumerator returned by a block-less #each that
+ * was given arguments: the size is reported as unknown (nil).
+ */
+static VALUE
+enum_chain_enum_no_size(VALUE obj, VALUE args, VALUE eobj)
+{
+    (void)obj;
+    (void)args;
+    (void)eobj;
+    return Qnil;
+}
+
+/*
+ * call-seq:
+ *   obj.each(*args) { |...| ... } -> obj
+ *   obj.each(*args) -> enumerator
+ *
+ * Iterates over the elements of the first enumerable by calling the
+ * "each" method on it with the given arguments, then proceeds to the
+ * following enumerables in sequence until all of the enumerables are
+ * exhausted.
+ *
+ * If no block is given, returns an enumerator.
+ */
+static VALUE
+enum_chain_each(int argc, VALUE *argv, VALUE obj)
+{
+    VALUE enums, block;
+    struct enum_chain *objptr;
+    long i;
+
+    /* With extra arguments the size cannot be predicted, so report nil. */
+    RETURN_SIZED_ENUMERATOR(obj, argc, argv, argc > 0 ? enum_chain_enum_no_size : enum_chain_enum_size);
+
+    objptr = enum_chain_ptr(obj);
+    enums = objptr->enums;
+    block = rb_block_proc();
+
+    /*
+     * Every iteration runs arbitrary Ruby code (the member's #each and the
+     * user's block), so fetch each element with RARRAY_AREF rather than
+     * keeping a raw RARRAY_PTR_USE pointer alive across those calls.
+     */
+    for (i = 0; i < RARRAY_LEN(enums); i++) {
+        /* Record progress so #rewind knows which members were reached. */
+        objptr->pos = i;
+        rb_block_call(RARRAY_AREF(enums, i), id_each, argc, argv, enum_chain_yield_block, block);
+    }
+
+    return obj;
+}
+
+/*
+ * call-seq:
+ *   obj.rewind -> obj
+ *
+ * Rewinds the enumerator chain by calling the "rewind" method on each
+ * enumerable in reverse order.  Each call is performed only if the
+ * enumerable responds to the method.
+ */
+static VALUE
+enum_chain_rewind(VALUE obj)
+{
+    struct enum_chain *objptr = enum_chain_ptr(obj);
+    VALUE enums = objptr->enums;
+    long i;
+
+    /*
+     * Walk backwards from the last member #each reached (pos == -1 means
+     * none).  pos is decremented only after a member's #rewind returns, so
+     * if one raises, a later call resumes from that member.  #rewind may be
+     * user-defined Ruby code, so elements are fetched with RARRAY_AREF
+     * instead of through a raw RARRAY_PTR_USE pointer.
+     */
+    for (i = objptr->pos; 0 <= i && i < RARRAY_LEN(enums); objptr->pos = --i) {
+        rb_check_funcall(RARRAY_AREF(enums, i), id_rewind, 0, 0);
+    }
+
+    return obj;
+}
+
+/*
+ * rb_exec_recursive callback that builds the #inspect string;
+ * +recur+ is non-zero when the chain is already being inspected.
+ */
+static VALUE
+inspect_enum_chain(VALUE obj, VALUE dummy, int recur)
+{
+    VALUE cls = rb_obj_class(obj);
+    struct enum_chain *chain;
+
+    TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, chain);
+
+    if (chain == NULL || chain->enums == Qundef) {
+        return rb_sprintf("#<%"PRIsVALUE": uninitialized>", rb_class_path(cls));
+    }
+    if (recur) {
+        return rb_sprintf("#<%"PRIsVALUE": ...>", rb_class_path(cls));
+    }
+    return rb_sprintf("#<%"PRIsVALUE": %+"PRIsVALUE">", rb_class_path(cls), chain->enums);
+}
+
+/*
+ * call-seq:
+ *   obj.inspect -> string
+ *
+ * Returns a human-readable description of the enumerator chain.
+ *
+ *   (1..3).chain([4, 5]).inspect #=> "#<Enumerator::Chain: [1..3, [4, 5]]>"
+ */
+static VALUE
+enum_chain_inspect(VALUE obj)
+{
+    /* Recursion guard: a chain that contains itself prints as "...". */
+    VALUE str = rb_exec_recursive(inspect_enum_chain, obj, 0);
+
+    return str;
+}
+
+/*
+ * call-seq:
+ *   enum.chain(*enums) -> enumerator
+ *
+ * Returns an enumerator object generated from this enumerable and the
+ * given enumerables.  The result iterates over the receiver first and
+ * then over each argument, in order.
+ *
+ *   e = (1..3).chain([4, 5])
+ *   e.to_a #=> [1, 2, 3, 4, 5]
+ */
+static VALUE
+enum_chain(int argc, VALUE *argv, VALUE obj)
+{
+    /* Member list: the receiver followed by the arguments. */
+    VALUE enums = rb_ary_new_from_values(1, &obj);
+    rb_ary_cat(enums, argv, argc);
+
+    return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums);
+}
+
+/*
+ * call-seq:
+ *   e + enum -> enumerator
+ *
+ * Returns an enumerator that first iterates over this enumerator and
+ * then over the given enumerable.
+ *
+ *   e = (1..3).each + [4, 5]
+ *   e.to_a #=> [1, 2, 3, 4, 5]
+ */
+static VALUE
+enumerator_plus(VALUE obj, VALUE eobj)
+{
+    /*
+     * The receiver is kept as a single member (it is not flattened even
+     * if it is itself a Chain), so its own #each is what gets called.
+     */
+    VALUE pair = rb_ary_new_from_args(2, obj, eobj);
+    VALUE chain = enum_chain_allocate(rb_cEnumChain);
+
+    return enum_chain_initialize(chain, pair);
+}
+
+/*
  * Document-class: Enumerator::ArithmeticSequence
  *
  * Enumerator::ArithmeticSequence is a subclass of Enumerator,
@@ -2907,6 +3209,8 @@ InitVM_Enumerator(void) https://github.com/ruby/ruby/blob/trunk/enumerator.c#L3209
     rb_define_method(rb_cEnumerator, "rewind", enumerator_rewind, 0);
     rb_define_method(rb_cEnumerator, "inspect", enumerator_inspect, 0);
     rb_define_method(rb_cEnumerator, "size", enumerator_size, 0);
+    rb_define_method(rb_cEnumerator, "+", enumerator_plus, 1);
+    rb_define_method(rb_mEnumerable, "chain", enum_chain, -1);
 
     /* Lazy */
     rb_cLazy = rb_define_class_under(rb_cEnumerator, "Lazy", rb_cEnumerator);
@@ -2960,6 +3264,16 @@ InitVM_Enumerator(void) https://github.com/ruby/ruby/blob/trunk/enumerator.c#L3264
     rb_define_method(rb_cYielder, "yield", yielder_yield, -2);
     rb_define_method(rb_cYielder, "<<", yielder_yield_push, 1);
 
+    /* Chain */
+    rb_cEnumChain = rb_define_class_under(rb_cEnumerator, "Chain", rb_cEnumerator);
+    rb_define_alloc_func(rb_cEnumChain, enum_chain_allocate);
+    rb_define_method(rb_cEnumChain, "initialize", enum_chain_initialize, -2);
+    rb_define_method(rb_cEnumChain, "initialize_copy", enum_chain_init_copy, 1);
+    rb_define_method(rb_cEnumChain, "each", enum_chain_each, -1);
+    rb_define_method(rb_cEnumChain, "size", enum_chain_size, 0);
+    rb_define_method(rb_cEnumChain, "rewind", enum_chain_rewind, 0);
+    rb_define_method(rb_cEnumChain, "inspect", enum_chain_inspect, 0);
+
     /* ArithmeticSequence */
     rb_cArithSeq = rb_define_class_under(rb_cEnumerator, "ArithmeticSequence", rb_cEnumerator);
     rb_undef_alloc_func(rb_cArithSeq);
Index: test/ruby/test_enumerator.rb
===================================================================
--- test/ruby/test_enumerator.rb	(revision 65948)
+++ test/ruby/test_enumerator.rb	(revision 65949)
@@ -670,5 +670,119 @@ class TestEnumerator < Test::Unit::TestC https://github.com/ruby/ruby/blob/trunk/test/ruby/test_enumerator.rb#L670
     assert_equal([0, 1], u.force)
     assert_equal([0, 1], u.force)
   end
-end
 
+  def test_enum_chain_and_plus
+    r = 1..5
+
+    e1 = r.chain()
+    assert_kind_of(Enumerator::Chain, e1)
+    assert_equal(5, e1.size)
+    ary = []
+    e1.each { |x| ary << x }
+    assert_equal([1, 2, 3, 4, 5], ary)
+
+    e2 = r.chain([6, 7, 8])
+    assert_kind_of(Enumerator::Chain, e2)
+    assert_equal(8, e2.size)
+    ary = []
+    e2.each { |x| ary << x }
+    assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary)
+
+    e3 = r.chain([6, 7], 8.step)
+    assert_kind_of(Enumerator::Chain, e3)
+    assert_equal(Float::INFINITY, e3.size)
+    ary = []
+    e3.take(10).each { |x| ary << x }
+    assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ary)
+
+    # `a + b + c` should not return `Enumerator::Chain.new(a, b, c)`
+    # because it is expected that `(a + b).each` be called.
+    e4 = e2.dup
+    class << e4
+      attr_reader :each_is_called
+      def each
+        super
+        @each_is_called = true
+      end
+    end
+    e5 = e4 + 9.step
+    assert_kind_of(Enumerator::Chain, e5)
+    assert_equal(Float::INFINITY, e5.size)
+    ary = []
+    e5.take(10).each { |x| ary << x }
+    assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ary)
+    assert_equal(true, e4.each_is_called)
+  end
+
+  def test_chained_enums
+    a = (1..5).each
+
+    e0 = Enumerator::Chain.new()
+    assert_kind_of(Enumerator::Chain, e0)
+    assert_equal(0, e0.size)
+    ary = []
+    e0.each { |x| ary << x }
+    assert_equal([], ary)
+
+    e1 = Enumerator::Chain.new(a)
+    assert_kind_of(Enumerator::Chain, e1)
+    assert_equal(5, e1.size)
+    ary = []
+    e1.each { |x| ary << x }
+    assert_equal([1, 2, 3, 4, 5], ary)
+
+    e2 = Enumerator::Chain.new(a, [6, 7, 8])
+    assert_kind_of(Enumerator::Chain, e2)
+    assert_equal(8, e2.size)
+    ary = []
+    e2.each { |x| ary << x }
+    assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary)
+
+    e3 = Enumerator::Chain.new(a, [6, 7], 8.step)
+    assert_kind_of(Enumerator::Chain, e3)
+    assert_equal(Float::INFINITY, e3.size)
+    ary = []
+    e3.take(10).each { |x| ary << x }
+    assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ary)
+
+    e4 = Enumerator::Chain.new(a, Enumerator.new { |y| y << 6 << 7 << 8 })
+    assert_kind_of(Enumerator::Chain, e4)
+    assert_equal(nil, e4.size)
+    ary = []
+    e4.each { |x| ary << x }
+    assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary)
+
+    e5 = Enumerator::Chain.new(e1, e2)
+    assert_kind_of(Enumerator::Chain, e5)
+    assert_equal(13, e5.size)
+    ary = []
+    e5.each { |x| ary << x }
+    assert_equal([1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8], ary)
+
+    rewound = []
+    e1.define_singleton_method(:rewind) { rewound << object_id }
+    e2.define_singleton_method(:rewind) { rewound << object_id }
+    e5.rewind
+    assert_equal(rewound, [e2.object_id, e1.object_id])
+
+    rewound = []
+    a = [1]
+    e6 = Enumerator::Chain.new(a)
+    a.define_singleton_method(:rewind) { rewound << object_id }
+    e6.rewind
+    assert_equal(rewound, [])
+
+    assert_equal(
+      '#<Enumerator::Chain: [' +
+        '#<Enumerator::Chain: [' +
+          '#<Enumerator: 1..5:each>' +
+        ']>, ' +
+        '#<Enumerator::Chain: [' +
+          '#<Enumerator: 1..5:each>, ' +
+          '[6, 7, 8]' +
+        ']>' +
+      ']>',
+      e5.inspect
+    )
+  end
+end

--
ML: ruby-changes@q...
Info: http://www.atdot.net/~ko1/quickml/

[前][次][番号順一覧][スレッド一覧]