mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-02 08:12:39 +02:00
conjugate gradient optimizer without callback (no c.join)
This commit is contained in:
parent
bc0bd59874
commit
2d9ecaa042
2 changed files with 18 additions and 16 deletions
|
|
@ -3,7 +3,8 @@ Created on 24 Apr 2013
|
||||||
|
|
||||||
@author: maxz
|
@author: maxz
|
||||||
'''
|
'''
|
||||||
from GPy.inference.gradient_descent_update_rules import FletcherReeves
|
from GPy.inference.gradient_descent_update_rules import FletcherReeves, \
|
||||||
|
PolakRibiere
|
||||||
from Queue import Empty
|
from Queue import Empty
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
from multiprocessing.queues import Queue
|
from multiprocessing.queues import Queue
|
||||||
|
|
@ -173,7 +174,7 @@ class Async_Optimize(object):
|
||||||
except Empty:
|
except Empty:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def opt_async(self, f, df, x0, callback, update_rule=FletcherReeves,
|
def opt_async(self, f, df, x0, callback, update_rule=PolakRibiere,
|
||||||
messages=0, maxiter=5e3, max_f_eval=15e3, gtol=1e-6,
|
messages=0, maxiter=5e3, max_f_eval=15e3, gtol=1e-6,
|
||||||
report_every=10, *args, **kwargs):
|
report_every=10, *args, **kwargs):
|
||||||
self.runsignal.set()
|
self.runsignal.set()
|
||||||
|
|
@ -199,12 +200,12 @@ class Async_Optimize(object):
|
||||||
while self.runsignal.is_set():
|
while self.runsignal.is_set():
|
||||||
try:
|
try:
|
||||||
p.join(1)
|
p.join(1)
|
||||||
c.join(1)
|
if c: c.join(1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
# print "^C"
|
# print "^C"
|
||||||
self.runsignal.clear()
|
self.runsignal.clear()
|
||||||
p.join()
|
p.join()
|
||||||
c.join()
|
if c: c.join()
|
||||||
if c and c.is_alive():
|
if c and c.is_alive():
|
||||||
# self.runsignal.set()
|
# self.runsignal.set()
|
||||||
# while self.runsignal.is_set():
|
# while self.runsignal.is_set():
|
||||||
|
|
|
||||||
|
|
@ -26,12 +26,12 @@ class Test(unittest.TestCase):
|
||||||
restarts = 10
|
restarts = 10
|
||||||
for _ in range(restarts):
|
for _ in range(restarts):
|
||||||
try:
|
try:
|
||||||
x0 = numpy.random.randn(N) * 300
|
x0 = numpy.random.randn(N) * 10
|
||||||
res = opt.opt(f, df, x0, messages=0,
|
res = opt.opt(f, df, x0, messages=0, maxiter=1000, gtol=1e-15)
|
||||||
maxiter=1000, gtol=1e-10)
|
assert numpy.allclose(res[0], 0, atol=1e-5)
|
||||||
assert numpy.allclose(res[0], 0, atol=1e-3)
|
|
||||||
break
|
break
|
||||||
except:
|
except AssertionError:
|
||||||
|
import ipdb;ipdb.set_trace()
|
||||||
# RESTART
|
# RESTART
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
|
@ -47,9 +47,9 @@ class Test(unittest.TestCase):
|
||||||
restarts = 10
|
restarts = 10
|
||||||
for _ in range(restarts):
|
for _ in range(restarts):
|
||||||
try:
|
try:
|
||||||
x0 = numpy.random.randn(N) * .5
|
x0 = (numpy.random.randn(N) * .5) + numpy.ones(N)
|
||||||
res = opt.opt(f, df, x0, messages=0,
|
res = opt.opt(f, df, x0, messages=0,
|
||||||
maxiter=5e2, gtol=1e-2)
|
maxiter=1e3, gtol=1e-12)
|
||||||
assert numpy.allclose(res[0], 1, atol=.1)
|
assert numpy.allclose(res[0], 1, atol=.1)
|
||||||
break
|
break
|
||||||
except:
|
except:
|
||||||
|
|
@ -68,10 +68,10 @@ if __name__ == "__main__":
|
||||||
N = 2
|
N = 2
|
||||||
A = numpy.random.rand(N) * numpy.eye(N)
|
A = numpy.random.rand(N) * numpy.eye(N)
|
||||||
b = numpy.random.rand(N) * 0
|
b = numpy.random.rand(N) * 0
|
||||||
# f = lambda x: numpy.dot(x.T.dot(A), x) - numpy.dot(x.T, b)
|
f = lambda x: numpy.dot(x.T.dot(A), x) - numpy.dot(x.T, b)
|
||||||
# df = lambda x: numpy.dot(A, x) - b
|
df = lambda x: numpy.dot(A, x) - b
|
||||||
f = rosen
|
# f = rosen
|
||||||
df = rosen_der
|
# df = rosen_der
|
||||||
x0 = (numpy.random.randn(N) * .5) + numpy.ones(N)
|
x0 = (numpy.random.randn(N) * .5) + numpy.ones(N)
|
||||||
print x0
|
print x0
|
||||||
|
|
||||||
|
|
@ -86,7 +86,8 @@ if __name__ == "__main__":
|
||||||
ax = fig.add_subplot(111, projection='3d')
|
ax = fig.add_subplot(111, projection='3d')
|
||||||
|
|
||||||
interpolation = 40
|
interpolation = 40
|
||||||
x, y = numpy.linspace(.5, 1.5, interpolation)[:, None], numpy.linspace(.5, 1.5, interpolation)[:, None]
|
# x, y = numpy.linspace(.5, 1.5, interpolation)[:, None], numpy.linspace(.5, 1.5, interpolation)[:, None]
|
||||||
|
x, y = numpy.linspace(-1, 1, interpolation)[:, None], numpy.linspace(-1, 1, interpolation)[:, None]
|
||||||
X, Y = numpy.meshgrid(x, y)
|
X, Y = numpy.meshgrid(x, y)
|
||||||
fXY = numpy.array([f(numpy.array([x, y])) for x, y in zip(X.flatten(), Y.flatten())]).reshape(interpolation, interpolation)
|
fXY = numpy.array([f(numpy.array([x, y])) for x, y in zip(X.flatten(), Y.flatten())]).reshape(interpolation, interpolation)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue