From 2d9ecaa042a79589dcaad4155582995f093f0a96 Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Wed, 8 May 2013 15:36:38 +0100 Subject: [PATCH] conjugate gradient optimizer without callback (no c.join) --- GPy/inference/conjugate_gradient_descent.py | 9 ++++---- GPy/testing/cgd_tests.py | 25 +++++++++++---------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/GPy/inference/conjugate_gradient_descent.py b/GPy/inference/conjugate_gradient_descent.py index 2fcd9ba0..0f6603e5 100644 --- a/GPy/inference/conjugate_gradient_descent.py +++ b/GPy/inference/conjugate_gradient_descent.py @@ -3,7 +3,8 @@ Created on 24 Apr 2013 @author: maxz ''' -from GPy.inference.gradient_descent_update_rules import FletcherReeves +from GPy.inference.gradient_descent_update_rules import FletcherReeves, \ + PolakRibiere from Queue import Empty from multiprocessing import Value from multiprocessing.queues import Queue @@ -173,7 +174,7 @@ class Async_Optimize(object): except Empty: pass - def opt_async(self, f, df, x0, callback, update_rule=FletcherReeves, + def opt_async(self, f, df, x0, callback, update_rule=PolakRibiere, messages=0, maxiter=5e3, max_f_eval=15e3, gtol=1e-6, report_every=10, *args, **kwargs): self.runsignal.set() @@ -199,12 +200,12 @@ class Async_Optimize(object): while self.runsignal.is_set(): try: p.join(1) - c.join(1) + if c: c.join(1) except KeyboardInterrupt: # print "^C" self.runsignal.clear() p.join() - c.join() + if c: c.join() if c and c.is_alive(): # self.runsignal.set() # while self.runsignal.is_set(): diff --git a/GPy/testing/cgd_tests.py b/GPy/testing/cgd_tests.py index 19c1c21b..d999c6fc 100644 --- a/GPy/testing/cgd_tests.py +++ b/GPy/testing/cgd_tests.py @@ -26,12 +26,12 @@ class Test(unittest.TestCase): restarts = 10 for _ in range(restarts): try: - x0 = numpy.random.randn(N) * 300 - res = opt.opt(f, df, x0, messages=0, - maxiter=1000, gtol=1e-10) - assert numpy.allclose(res[0], 0, atol=1e-3) + x0 = numpy.random.randn(N) * 10 + res = opt.opt(f, df, x0, messages=0, maxiter=1000, gtol=1e-15) + assert numpy.allclose(res[0], 0, atol=1e-5) break - except: + except AssertionError: + import ipdb;ipdb.set_trace() # RESTART pass else: @@ -47,9 +47,9 @@ class Test(unittest.TestCase): restarts = 10 for _ in range(restarts): try: - x0 = numpy.random.randn(N) * .5 + x0 = (numpy.random.randn(N) * .5) + numpy.ones(N) res = opt.opt(f, df, x0, messages=0, - maxiter=5e2, gtol=1e-2) + maxiter=1e3, gtol=1e-12) assert numpy.allclose(res[0], 1, atol=.1) break except: @@ -68,10 +68,10 @@ if __name__ == "__main__": N = 2 A = numpy.random.rand(N) * numpy.eye(N) b = numpy.random.rand(N) * 0 -# f = lambda x: numpy.dot(x.T.dot(A), x) - numpy.dot(x.T, b) -# df = lambda x: numpy.dot(A, x) - b - f = rosen - df = rosen_der + f = lambda x: numpy.dot(x.T.dot(A), x) - numpy.dot(x.T, b) + df = lambda x: numpy.dot(A, x) - b +# f = rosen +# df = rosen_der x0 = (numpy.random.randn(N) * .5) + numpy.ones(N) print x0 @@ -86,7 +86,8 @@ if __name__ == "__main__": ax = fig.add_subplot(111, projection='3d') interpolation = 40 - x, y = numpy.linspace(.5, 1.5, interpolation)[:, None], numpy.linspace(.5, 1.5, interpolation)[:, None] +# x, y = numpy.linspace(.5, 1.5, interpolation)[:, None], numpy.linspace(.5, 1.5, interpolation)[:, None] + x, y = numpy.linspace(-1, 1, interpolation)[:, None], numpy.linspace(-1, 1, interpolation)[:, None] X, Y = numpy.meshgrid(x, y) fXY = numpy.array([f(numpy.array([x, y])) for x, y in zip(X.flatten(), Y.flatten())]).reshape(interpolation, interpolation)