[前][次][番号順一覧][スレッド一覧]

ruby-changes:17108

From: nobu <ko1@a...>
Date: Fri, 27 Aug 2010 07:57:46 +0900 (JST)
Subject: [ruby-changes:17108] Ruby:r29108 (trunk): * array.c (rb_ary_shuffle_bang): bail out from modification during

nobu	2010-08-27 07:57:39 +0900 (Fri, 27 Aug 2010)

  New Revision: 29108

  http://svn.ruby-lang.org/cgi-bin/viewvc.cgi?view=rev&revision=29108

  Log:
    * array.c (rb_ary_shuffle_bang): bail out if the array is modified
      during shuffle.
    
    * array.c (rb_ary_sample): ditto.

  Modified files:
    trunk/ChangeLog
    trunk/array.c
    trunk/test/ruby/test_array.rb

Index: array.c
===================================================================
--- array.c	(revision 29107)
+++ array.c	(revision 29108)
@@ -20,6 +20,8 @@
 #endif
 #include <assert.h>
 
+#define numberof(array) (int)(sizeof(array) / sizeof((array)[0]))
+
 VALUE rb_cArray;
 
 static ID id_cmp;
@@ -3748,8 +3750,8 @@
 static VALUE
 rb_ary_shuffle_bang(int argc, VALUE *argv, VALUE ary)
 {
-    VALUE *ptr, opts, randgen = rb_cRandom;
-    long i = RARRAY_LEN(ary);
+    VALUE *ptr, opts, *snap_ptr, randgen = rb_cRandom;
+    long i, snap_len;
 
     if (OPTHASH_GIVEN_P(opts)) {
 	randgen = rb_hash_lookup2(opts, sym_random, randgen);
@@ -3758,10 +3760,17 @@
 	rb_raise(rb_eArgError, "wrong number of arguments (%d for 0)", argc);
     }
     rb_ary_modify(ary);
+    i = RARRAY_LEN(ary);
     ptr = RARRAY_PTR(ary);
+    snap_len = i;
+    snap_ptr = ptr;
     while (i) {
 	long j = RAND_UPTO(i);
-	VALUE tmp = ptr[--i];
+	VALUE tmp;
+	if (snap_len != RARRAY_LEN(ary) || snap_ptr != RARRAY_PTR(ary)) {
+	    rb_raise(rb_eRuntimeError, "modified during shuffle");
+	}
+	tmp = ptr[--i];
 	ptr[i] = ptr[j];
 	ptr[j] = tmp;
     }
@@ -3814,37 +3823,54 @@
 rb_ary_sample(int argc, VALUE *argv, VALUE ary)
 {
     VALUE nv, result, *ptr;
-    VALUE opts, randgen = rb_cRandom;
+    VALUE opts, snap, randgen = rb_cRandom;
     long n, len, i, j, k, idx[10];
+    double rnds[numberof(idx)];
 
-    len = RARRAY_LEN(ary);
     if (OPTHASH_GIVEN_P(opts)) {
 	randgen = rb_hash_lookup2(opts, sym_random, randgen);
     }
+    ptr = RARRAY_PTR(ary);
+    len = RARRAY_LEN(ary);
     if (argc == 0) {
 	if (len == 0) return Qnil;
-	i = len == 1 ? 0 : RAND_UPTO(len);
+	if (len == 1) {
+	    i = 0;
+	}
+	else {
+	    double x = rb_random_real(randgen);
+	    if ((len = RARRAY_LEN(ary)) == 0) return Qnil;
+	    i = (long)(x * len);
+	}
 	return RARRAY_PTR(ary)[i];
     }
     rb_scan_args(argc, argv, "1", &nv);
     n = NUM2LONG(nv);
     if (n < 0) rb_raise(rb_eArgError, "negative sample number");
+    if (n > len) n = len;
+    if (n <= numberof(idx)) {
+	for (i = 0; i < n; ++i) {
+	    rnds[i] = rb_random_real(randgen);
+	}
+    }
+    len = RARRAY_LEN(ary);
     ptr = RARRAY_PTR(ary);
-    len = RARRAY_LEN(ary);
     if (n > len) n = len;
     switch (n) {
-      case 0: return rb_ary_new2(0);
+      case 0:
+	return rb_ary_new2(0);
       case 1:
-	return rb_ary_new4(1, &ptr[RAND_UPTO(len)]);
+	i = (long)(rnds[0] * len);
+	return rb_ary_new4(1, &ptr[i]);
       case 2:
-	i = RAND_UPTO(len);
-	j = RAND_UPTO(len-1);
+	i = (long)(rnds[0] * len);
+	j = (long)(rnds[1] * (len-1));
 	if (j >= i) j++;
 	return rb_ary_new3(2, ptr[i], ptr[j]);
       case 3:
-	i = RAND_UPTO(len);
-	j = RAND_UPTO(len-1);
-	k = RAND_UPTO(len-2);
+	i = (long)(rnds[0] * len);
+	j = (long)(rnds[1] * (len-1));
+	k = (long)(rnds[2] * (len-2));
 	{
 	    long l = j, g = i;
 	    if (j >= i) l = i, g = ++j;
@@ -3852,12 +3878,12 @@
 	}
 	return rb_ary_new3(3, ptr[i], ptr[j], ptr[k]);
     }
-    if ((size_t)n < sizeof(idx)/sizeof(idx[0])) {
+    if (n <= numberof(idx)) {
 	VALUE *ptr_result;
-	long sorted[sizeof(idx)/sizeof(idx[0])];
-	sorted[0] = idx[0] = RAND_UPTO(len);
+	long sorted[numberof(idx)];
+	sorted[0] = idx[0] = (long)(rnds[0] * len);
 	for (i=1; i<n; i++) {
-	    k = RAND_UPTO(--len);
+	    k = (long)(rnds[i] * --len);
 	    for (j = 0; j < i; ++j) {
 		if (k < sorted[j]) break;
 		++k;
@@ -3874,6 +3900,7 @@
     else {
 	VALUE *ptr_result;
 	result = rb_ary_new4(len, ptr);
+	RBASIC(result)->klass = 0;
 	ptr_result = RARRAY_PTR(result);
 	RB_GC_GUARD(ary);
 	for (i=0; i<n; i++) {
@@ -3882,6 +3909,7 @@
 	    ptr_result[j] = ptr_result[i];
 	    ptr_result[i] = nv;
 	}
+	RBASIC(result)->klass = rb_cArray;
     }
     ARY_SET_LEN(result, n);
 
Index: ChangeLog
===================================================================
--- ChangeLog	(revision 29107)
+++ ChangeLog	(revision 29108)
@@ -1,3 +1,10 @@
+Fri Aug 27 07:57:34 2010  Nobuyoshi Nakada  <nobu@r...>
+
+	* array.c (rb_ary_shuffle_bang): bail out from modification during
+	  shuffle.
+
+	* array.c (rb_ary_sample): ditto.
+
 Fri Aug 27 05:11:51 2010  Tanaka Akira  <akr@f...>
 
 	* ext/pathname/pathname.c (path_sysopen): Pathname#sysopen translated
Index: test/ruby/test_array.rb
===================================================================
--- test/ruby/test_array.rb	(revision 29107)
+++ test/ruby/test_array.rb	(revision 29108)
@@ -1901,7 +1901,6 @@
   end
 
   def test_shuffle_random
-    cc = nil
     gen = proc do
       10000000
     end
@@ -1911,6 +1910,16 @@
     assert_raise(RangeError) {
       [*0..2].shuffle(random: gen)
     }
+
+    ary = (0...10000).to_a
+    gen = proc do
+      ary.replace([])
+      0.5
+    end
+    class << gen
+      alias rand call
+    end
+    assert_raise(RuntimeError) {ary.shuffle!(random: gen)}
   end
 
   def test_sample
@@ -1951,6 +1960,51 @@
     end
   end
 
+  def test_sample_random
+    ary = (0...10000).to_a
+    assert_raise(ArgumentError) {ary.sample(1, 2, random: nil)}
+    gen0 = proc do
+      0.5
+    end
+    class << gen0
+      alias rand call
+    end
+    gen1 = proc do
+      ary.replace([])
+      0.5
+    end
+    class << gen1
+      alias rand call
+    end
+    assert_equal(5000, ary.sample(random: gen0))
+    assert_nil(ary.sample(random: gen1))
+    assert_equal([], ary)
+    ary = (0...10000).to_a
+    assert_equal([5000], ary.sample(1, random: gen0))
+    assert_equal([], ary.sample(1, random: gen1))
+    assert_equal([], ary)
+    ary = (0...10000).to_a
+    assert_equal([5000, 4999], ary.sample(2, random: gen0))
+    assert_equal([], ary.sample(2, random: gen1))
+    assert_equal([], ary)
+    ary = (0...10000).to_a
+    assert_equal([5000, 4999, 5001], ary.sample(3, random: gen0))
+    assert_equal([], ary.sample(3, random: gen1))
+    assert_equal([], ary)
+    ary = (0...10000).to_a
+    assert_equal([5000, 4999, 5001, 4998], ary.sample(4, random: gen0))
+    assert_equal([], ary.sample(4, random: gen1))
+    assert_equal([], ary)
+    ary = (0...10000).to_a
+    assert_equal([5000, 4999, 5001, 4998, 5002, 4997, 5003, 4996, 5004, 4995], ary.sample(10, random: gen0))
+    assert_equal([], ary.sample(10, random: gen1))
+    assert_equal([], ary)
+    ary = (0...10000).to_a
+    assert_equal([5000, 0, 5001, 2, 5002, 4, 5003, 6, 5004, 8, 5005], ary.sample(11, random: gen0))
+    ary.sample(11, random: gen1) # implementation detail, may change in the future
+    assert_equal([], ary)
+  end
+
   def test_cycle
     a = []
     [0, 1, 2].cycle do |i|

--
ML: ruby-changes@q...
Info: http://www.atdot.net/~ko1/quickml/

[前][次][番号順一覧][スレッド一覧]