mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-12 13:32:39 +02:00
new termination rule for scg
This commit is contained in:
parent
0e8bc5662a
commit
803a1d99ed
6 changed files with 84 additions and 50 deletions
|
|
@ -9,6 +9,7 @@ from GPy.inference.conjugate_gradient_descent import CGD, RUNNING
|
|||
import pylab
|
||||
import time
|
||||
from scipy.optimize.optimize import rosen, rosen_der
|
||||
from GPy.inference.gradient_descent_update_rules import PolakRibiere
|
||||
|
||||
|
||||
class Test(unittest.TestCase):
|
||||
|
|
@ -71,10 +72,12 @@ if __name__ == "__main__":
|
|||
# df = lambda x: numpy.dot(A, x) - b
|
||||
f = rosen
|
||||
df = rosen_der
|
||||
x0 = numpy.random.randn(N) * .5
|
||||
x0 = (numpy.random.randn(N) * .5) + numpy.ones(N)
|
||||
print x0
|
||||
|
||||
opt = CGD()
|
||||
|
||||
pylab.ion()
|
||||
fig = pylab.figure("cgd optimize")
|
||||
if fig.axes:
|
||||
ax = fig.axes[0]
|
||||
|
|
@ -83,13 +86,13 @@ if __name__ == "__main__":
|
|||
ax = fig.add_subplot(111, projection='3d')
|
||||
|
||||
interpolation = 40
|
||||
x, y = numpy.linspace(-1, 1, interpolation)[:, None], numpy.linspace(-1, 1, interpolation)[:, None]
|
||||
x, y = numpy.linspace(.5, 1.5, interpolation)[:, None], numpy.linspace(.5, 1.5, interpolation)[:, None]
|
||||
X, Y = numpy.meshgrid(x, y)
|
||||
fXY = numpy.array([f(numpy.array([x, y])) for x, y in zip(X.flatten(), Y.flatten())]).reshape(interpolation, interpolation)
|
||||
|
||||
ax.plot_wireframe(X, Y, fXY)
|
||||
xopts = [x0.copy()]
|
||||
optplts, = ax.plot3D([x0[0]], [x0[1]], zs=f(x0), marker='o', color='r')
|
||||
optplts, = ax.plot3D([x0[0]], [x0[1]], zs=f(x0), marker='', color='r')
|
||||
|
||||
raw_input("enter to start optimize")
|
||||
res = [0]
|
||||
|
|
@ -102,11 +105,7 @@ if __name__ == "__main__":
|
|||
if r[-1] != RUNNING:
|
||||
res[0] = r
|
||||
|
||||
p, c = opt.opt_async(f, df, x0.copy(), callback, messages=True, maxiter=1000,
|
||||
report_every=20, gtol=1e-12)
|
||||
|
||||
|
||||
pylab.ion()
|
||||
pylab.show()
|
||||
res[0] = opt.opt(f, df, x0.copy(), callback, messages=True, maxiter=1000,
|
||||
report_every=7, gtol=1e-12, update_rule=PolakRibiere)
|
||||
|
||||
pass
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue