conjugate gradient optimizer without callback (no c.join)

This commit is contained in:
Max Zwiessele 2013-05-08 15:36:38 +01:00
parent bc0bd59874
commit 2d9ecaa042
2 changed files with 18 additions and 16 deletions

View file

@ -3,7 +3,8 @@ Created on 24 Apr 2013
@author: maxz @author: maxz
''' '''
from GPy.inference.gradient_descent_update_rules import FletcherReeves from GPy.inference.gradient_descent_update_rules import FletcherReeves, \
PolakRibiere
from Queue import Empty from Queue import Empty
from multiprocessing import Value from multiprocessing import Value
from multiprocessing.queues import Queue from multiprocessing.queues import Queue
@ -173,7 +174,7 @@ class Async_Optimize(object):
except Empty: except Empty:
pass pass
def opt_async(self, f, df, x0, callback, update_rule=FletcherReeves, def opt_async(self, f, df, x0, callback, update_rule=PolakRibiere,
messages=0, maxiter=5e3, max_f_eval=15e3, gtol=1e-6, messages=0, maxiter=5e3, max_f_eval=15e3, gtol=1e-6,
report_every=10, *args, **kwargs): report_every=10, *args, **kwargs):
self.runsignal.set() self.runsignal.set()
@ -199,12 +200,12 @@ class Async_Optimize(object):
while self.runsignal.is_set(): while self.runsignal.is_set():
try: try:
p.join(1) p.join(1)
c.join(1) if c: c.join(1)
except KeyboardInterrupt: except KeyboardInterrupt:
# print "^C" # print "^C"
self.runsignal.clear() self.runsignal.clear()
p.join() p.join()
c.join() if c: c.join()
if c and c.is_alive(): if c and c.is_alive():
# self.runsignal.set() # self.runsignal.set()
# while self.runsignal.is_set(): # while self.runsignal.is_set():

View file

@ -26,12 +26,12 @@ class Test(unittest.TestCase):
restarts = 10 restarts = 10
for _ in range(restarts): for _ in range(restarts):
try: try:
x0 = numpy.random.randn(N) * 300 x0 = numpy.random.randn(N) * 10
res = opt.opt(f, df, x0, messages=0, res = opt.opt(f, df, x0, messages=0, maxiter=1000, gtol=1e-15)
maxiter=1000, gtol=1e-10) assert numpy.allclose(res[0], 0, atol=1e-5)
assert numpy.allclose(res[0], 0, atol=1e-3)
break break
except: except AssertionError:
import ipdb;ipdb.set_trace()
# RESTART # RESTART
pass pass
else: else:
@ -47,9 +47,9 @@ class Test(unittest.TestCase):
restarts = 10 restarts = 10
for _ in range(restarts): for _ in range(restarts):
try: try:
x0 = numpy.random.randn(N) * .5 x0 = (numpy.random.randn(N) * .5) + numpy.ones(N)
res = opt.opt(f, df, x0, messages=0, res = opt.opt(f, df, x0, messages=0,
maxiter=5e2, gtol=1e-2) maxiter=1e3, gtol=1e-12)
assert numpy.allclose(res[0], 1, atol=.1) assert numpy.allclose(res[0], 1, atol=.1)
break break
except: except:
@ -68,10 +68,10 @@ if __name__ == "__main__":
N = 2 N = 2
A = numpy.random.rand(N) * numpy.eye(N) A = numpy.random.rand(N) * numpy.eye(N)
b = numpy.random.rand(N) * 0 b = numpy.random.rand(N) * 0
# f = lambda x: numpy.dot(x.T.dot(A), x) - numpy.dot(x.T, b) f = lambda x: numpy.dot(x.T.dot(A), x) - numpy.dot(x.T, b)
# df = lambda x: numpy.dot(A, x) - b df = lambda x: numpy.dot(A, x) - b
f = rosen # f = rosen
df = rosen_der # df = rosen_der
x0 = (numpy.random.randn(N) * .5) + numpy.ones(N) x0 = (numpy.random.randn(N) * .5) + numpy.ones(N)
print x0 print x0
@ -86,7 +86,8 @@ if __name__ == "__main__":
ax = fig.add_subplot(111, projection='3d') ax = fig.add_subplot(111, projection='3d')
interpolation = 40 interpolation = 40
x, y = numpy.linspace(.5, 1.5, interpolation)[:, None], numpy.linspace(.5, 1.5, interpolation)[:, None] # x, y = numpy.linspace(.5, 1.5, interpolation)[:, None], numpy.linspace(.5, 1.5, interpolation)[:, None]
x, y = numpy.linspace(-1, 1, interpolation)[:, None], numpy.linspace(-1, 1, interpolation)[:, None]
X, Y = numpy.meshgrid(x, y) X, Y = numpy.meshgrid(x, y)
fXY = numpy.array([f(numpy.array([x, y])) for x, y in zip(X.flatten(), Y.flatten())]).reshape(interpolation, interpolation) fXY = numpy.array([f(numpy.array([x, y])) for x, y in zip(X.flatten(), Y.flatten())]).reshape(interpolation, interpolation)