new termination rule for scg

This commit is contained in:
Max Zwiessele 2013-05-08 09:44:09 +01:00
parent 0e8bc5662a
commit 803a1d99ed
6 changed files with 84 additions and 50 deletions

View file

@ -9,6 +9,7 @@ from GPy.inference.conjugate_gradient_descent import CGD, RUNNING
import pylab
import time
from scipy.optimize.optimize import rosen, rosen_der
from GPy.inference.gradient_descent_update_rules import PolakRibiere
class Test(unittest.TestCase):
@ -71,10 +72,12 @@ if __name__ == "__main__":
# df = lambda x: numpy.dot(A, x) - b
f = rosen
df = rosen_der
x0 = numpy.random.randn(N) * .5
x0 = (numpy.random.randn(N) * .5) + numpy.ones(N)
print x0
opt = CGD()
pylab.ion()
fig = pylab.figure("cgd optimize")
if fig.axes:
ax = fig.axes[0]
@ -83,13 +86,13 @@ if __name__ == "__main__":
ax = fig.add_subplot(111, projection='3d')
interpolation = 40
x, y = numpy.linspace(-1, 1, interpolation)[:, None], numpy.linspace(-1, 1, interpolation)[:, None]
x, y = numpy.linspace(.5, 1.5, interpolation)[:, None], numpy.linspace(.5, 1.5, interpolation)[:, None]
X, Y = numpy.meshgrid(x, y)
fXY = numpy.array([f(numpy.array([x, y])) for x, y in zip(X.flatten(), Y.flatten())]).reshape(interpolation, interpolation)
ax.plot_wireframe(X, Y, fXY)
xopts = [x0.copy()]
optplts, = ax.plot3D([x0[0]], [x0[1]], zs=f(x0), marker='o', color='r')
optplts, = ax.plot3D([x0[0]], [x0[1]], zs=f(x0), marker='', color='r')
raw_input("enter to start optimize")
res = [0]
@ -102,11 +105,7 @@ if __name__ == "__main__":
if r[-1] != RUNNING:
res[0] = r
p, c = opt.opt_async(f, df, x0.copy(), callback, messages=True, maxiter=1000,
report_every=20, gtol=1e-12)
pylab.ion()
pylab.show()
res[0] = opt.opt(f, df, x0.copy(), callback, messages=True, maxiter=1000,
report_every=7, gtol=1e-12, update_rule=PolakRibiere)
pass