[前][次][番号順一覧][スレッド一覧]

ruby-changes:63953

From: Marc-Andre <ko1@a...>
Date: Sat, 5 Dec 2020 14:57:21 +0900 (JST)
Subject: [ruby-changes:63953] a83a51932d (master): [ruby/matrix] Optimize **

https://git.ruby-lang.org/ruby.git/commit/?id=a83a51932d

From a83a51932dbc31b549e11b9da8967f2f52a8b07c Mon Sep 17 00:00:00 2001
From: Marc-Andre Lafortune <github@m...>
Date: Fri, 4 Dec 2020 01:57:40 -0500
Subject: [ruby/matrix] Optimize **

Avoiding the recursive call would require iterating over the bits starting
from the most significant one, which is not easy to do efficiently.
Any saving would be dwarfed by the cost of the multiplications anyway.
[Feature #15233]

diff --git a/lib/matrix.rb b/lib/matrix.rb
index 336a928..c6193eb 100644
--- a/lib/matrix.rb
+++ b/lib/matrix.rb
@@ -1233,26 +1233,49 @@ class Matrix https://github.com/ruby/ruby/blob/trunk/lib/matrix.rb#L1233
   #   #  => 67 96
   #   #     48 99
   #
-  def **(other)
-    case other
+  def **(exp)
+    case exp
     when Integer
-      x = self
-      if other <= 0
-        x = self.inverse
-        return self.class.identity(self.column_count) if other == 0
-        other = -other
-      end
-      z = nil
-      loop do
-        z = z ? z * x : x if other[0] == 1
-        return z if (other >>= 1).zero?
-        x *= x
+      case
+      when exp == 0
+        _make_sure_it_is_invertible = inverse
+        self.class.identity(column_count)
+      when exp < 0
+        inverse.power_int(-exp)
+      else
+        power_int(exp)
       end
     when Numeric
       v, d, v_inv = eigensystem
-      v * self.class.diagonal(*d.each(:diagonal).map{|e| e ** other}) * v_inv
+      v * self.class.diagonal(*d.each(:diagonal).map{|e| e ** exp}) * v_inv
+    else
+      raise ErrOperationNotDefined, ["**", self.class, exp.class]
+    end
+  end
+
+  protected def power_int(exp)
+    # assumes `exp` is an Integer > 0
+    #
+    # Previous algorithm:
+    #   build M**2, M**4 = (M**2)**2, M**8, ... and multiply those you need
+    #   e.g. M**0b1011 = M**11 = M * M**2 * M**8
+    #                              ^  ^
+    #   (highlighted the 2 out of 5 multiplications involving `M * x`)
+    #
+    # Current algorithm has same number of multiplications but with lower exponents:
+    #    M**11 = M * (M * M**4)**2
+    #              ^    ^  ^
+    #   (highlighted the 3 out of 5 multiplications involving `M * x`)
+    #
+    # This should be faster for all (non-nilpotent) matrices.
+    case
+    when exp == 1
+      self
+    when exp.odd?
+      self * power_int(exp - 1)
     else
-      raise ErrOperationNotDefined, ["**", self.class, other.class]
+      sqrt = power_int(exp / 2)
+      sqrt * sqrt
     end
   end
 
diff --git a/test/matrix/test_matrix.rb b/test/matrix/test_matrix.rb
index b134bfb..8125fb2 100644
--- a/test/matrix/test_matrix.rb
+++ b/test/matrix/test_matrix.rb
@@ -448,6 +448,12 @@ class TestMatrix < Test::Unit::TestCase https://github.com/ruby/ruby/blob/trunk/test/matrix/test_matrix.rb#L448
     assert_equal(Matrix[[67,96],[48,99]], Matrix[[7,6],[3,9]] ** 2)
     assert_equal(Matrix.I(5), Matrix.I(5) ** -1)
     assert_raise(Matrix::ErrOperationNotDefined) { Matrix.I(5) ** Object.new }
+
+    m = Matrix[[0,2],[1,0]]
+    exp = 0b11101000
+    assert_equal(Matrix.scalar(2, 1 << (exp/2)), m ** exp)
+    exp = 0b11101001
+    assert_equal(Matrix[[0, 2 << (exp/2)], [1 << (exp/2), 0]], m ** exp)
   end
 
   def test_det
-- 
cgit v0.10.2


--
ML: ruby-changes@q...
Info: http://www.atdot.net/~ko1/quickml/

[前][次][番号順一覧][スレッド一覧]