mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-15 06:52:39 +02:00
added conjugate gradient descent async
This commit is contained in:
parent
0332fa14f8
commit
2218eeece2
3 changed files with 358 additions and 0 deletions
259
GPy/inference/conjugate_gradient_descent.py
Normal file
259
GPy/inference/conjugate_gradient_descent.py
Normal file
|
|
@ -0,0 +1,259 @@
|
|||
'''
|
||||
Created on 24 Apr 2013
|
||||
|
||||
@author: maxz
|
||||
'''
|
||||
from multiprocessing.process import Process
|
||||
from GPy.inference.gradient_descent_update_rules import FletcherReeves
|
||||
import numpy
|
||||
from multiprocessing import Value
|
||||
from scipy.optimize.linesearch import line_search_wolfe1, line_search_wolfe2
|
||||
from multiprocessing.synchronize import Lock, Event
|
||||
from copy import deepcopy
|
||||
from multiprocessing.queues import Queue
|
||||
from Queue import Empty
|
||||
import sys
|
||||
|
||||
# Optimizer status messages: reported to the callback so the caller can
# tell whether the optimization is still running or why it terminated.
RUNNING = "running"
CONVERGED = "converged"
MAXITER = "maximum number of iterations reached"
MAX_F_EVAL = "maximum number of function calls reached"
LINE_SEARCH = "line search failed"
KBINTERRUPT = "interrupted"
|
||||
|
||||
class _Async_Optimization(Process):
    """Helper ``Process`` base class for asynchronous optimization.

    Wraps the objective ``f`` and its gradient ``df`` so that every call is
    counted in the process-safe counters ``f_call`` and ``df_call``
    (``multiprocessing.Value('i')`` instances). Subclasses implement
    :meth:`run` with the actual optimization loop.

    :param f: objective, called as ``f(x, *args, **kwargs)``
    :param df: gradient of ``f``, called as ``df(x, *args, **kwargs)``
    :param x0: starting point of the optimization
    :param update_rule: class producing the conjugate update factor gamma
    :param runsignal: ``Event``; clearing it requests the loop to stop
    :param report_every: invoke the callback every this many iterations
    :param messages: if truthy, print progress to stdout
    :param maxiter: maximum number of iterations
    :param max_f_eval: maximum number of objective evaluations
    :param gtol: convergence threshold on the squared gradient norm
    :param outqueue: ``Queue`` on which report tuples are put for collection
    """

    def __init__(self, f, df, x0, update_rule, runsignal,
                 report_every=10, messages=0, maxiter=5e3, max_f_eval=15e3,
                 gtol=1e-6, outqueue=None, *args, **kw):
        self.f_call = Value('i', 0)
        self.df_call = Value('i', 0)
        # wrap f and df so every evaluation is counted across processes
        self.f = self.f_wrapper(f, self.f_call)
        self.df = self.f_wrapper(df, self.df_call)
        self.x0 = x0
        self.update_rule = update_rule
        self.report_every = report_every
        self.messages = messages
        self.maxiter = maxiter
        self.max_f_eval = max_f_eval
        self.gtol = gtol
        self.runsignal = runsignal
        self.outq = outqueue
        super(_Async_Optimization, self).__init__(target=self.run,
                                                  name="CG Optimization",
                                                  *args, **kw)

    def f_wrapper(self, f, counter):
        """Return ``f`` wrapped so that every call increments ``counter``."""
        def f_w(*a, **kw):
            counter.value += 1
            return f(*a, **kw)
        return f_w

    def callback(self, *a):
        """Forward a report tuple to the collector via the output queue."""
        self.outq.put(a)

    def run(self, *args, **kwargs):
        """Optimization main loop; must be overridden by subclasses."""
        raise NotImplementedError("Overwrite this with optimization (for async use)")
|
||||
|
||||
class _CGDAsync(_Async_Optimization):
    """Conjugate gradient descent, running asynchronously in its own process."""

    def reset(self, xi, *a, **kw):
        """(Re)start the conjugate directions at ``xi``.

        :returns: ``(gi, ur, si)`` -- the negative gradient at ``xi``, a fresh
            update rule initialized with it, and the initial search
            direction (equal to ``gi``).
        """
        gi = -self.df(xi, *a, **kw)
        si = gi
        ur = self.update_rule(gi)
        return gi, ur, si

    def run(self, *a, **kw):
        """Conjugate gradient descent loop with Wolfe line searches.

        Reports ``(xi, fi, iteration, f_calls, df_calls, status)`` through
        :meth:`callback` every ``report_every`` iterations and once more on
        termination.
        """
        status = RUNNING

        fi = self.f(self.x0)
        fi_old = fi + 5000  # line search needs some "previous" value > fi

        gi, ur, si = self.reset(self.x0, *a, **kw)
        xi = self.x0
        xi_old = numpy.nan
        it = 0

        while it < self.maxiter:
            # externally requested stop (runsignal cleared by the driver)
            if not self.runsignal.is_set():
                break

            if self.f_call.value > self.max_f_eval:
                status = MAX_F_EVAL
                break  # FIX: was missing, so the evaluation budget was ignored

            gi = -self.df(xi, *a, **kw)
            if numpy.dot(gi.T, gi) < self.gtol:
                status = CONVERGED
                break
            if numpy.isnan(numpy.dot(gi.T, gi)):
                if numpy.any(numpy.isnan(xi_old)):
                    # already restarted once and still NaN: give up
                    status = CONVERGED
                    break
                # FIX: original discarded the reset() return values, so the
                # NaN restart had no effect on the search state
                gi, ur, si = self.reset(xi_old)

            gammai = ur(gi)
            if gammai < 1e-6 or it % xi.shape[0] == 0:
                # periodic restart of the conjugate directions / tiny gamma
                gi, ur, si = self.reset(xi, *a, **kw)
            si = gi + gammai * si
            alphai, _, _, fi2, fi_old2, gfi = line_search_wolfe1(self.f,
                                                                 self.df,
                                                                 xi,
                                                                 si, gi,
                                                                 fi, fi_old)
            if alphai is not None:
                fi, fi_old = fi2, fi_old2
            else:
                # first line search failed: fall back to the more robust one
                alphai, _, _, fi, fi_old, gfi = \
                    line_search_wolfe2(self.f, self.df,
                                       xi, si, gi,
                                       fi, fi_old)
                if alphai is None:
                    # This line search also failed to find a better solution.
                    status = LINE_SEARCH
                    break
            if gfi is not None:
                gi = gfi
            xi += numpy.dot(alphai, si)
            if self.messages:
                sys.stdout.write("\r")
                sys.stdout.flush()
                # FIX: print a scalar (squared gradient norm); formatting the
                # raw gradient array with the "F" spec raised a TypeError
                sys.stdout.write("iteration: {0:> 6g} f: {1:> 12F} g: {2:> 12F}".format(
                    it, fi, float(numpy.dot(gi.T, gi))))

            if it % self.report_every == 0:
                self.callback(xi, fi, it, self.f_call.value, self.df_call.value, status)
            it += 1
        else:
            # loop ran to completion without break
            status = MAXITER
        # final report carrying the termination status
        self.callback(xi, fi, it, self.f_call.value, self.df_call.value, status)
        return
|
||||
|
||||
class Async_Optimize(object):
    """Driver for an :class:`_Async_Optimization` worker process.

    Offers a synchronous entry point (:meth:`fmin`) and an asynchronous one
    (:meth:`fmin_async`), plus a collector process forwarding report tuples
    from the worker's queue to the user callback.
    """
    callback = None
    # unique marker that terminates the collector loop when queued
    SENTINEL = object()

    def __init__(self):
        # FIX: per-instance run flag. Previously this was a class attribute
        # shared by ALL optimizers, so stopping one stopped every other one.
        self.runsignal = Event()

    def async_callback_collect(self, q):
        """Drain ``q``, forwarding each report tuple to ``self.callback``.

        Runs until the run signal is cleared or the sentinel is received.
        """
        while self.runsignal.is_set():
            try:
                for ret in iter(lambda: q.get(timeout=1), self.SENTINEL):
                    self.callback(*ret)
            except Empty:
                pass

    def fmin_async(self, f, df, x0, callback, update_rule=FletcherReeves,
                   messages=0, maxiter=5e3, max_f_eval=15e3, gtol=1e-6,
                   report_every=10, *args, **kwargs):
        """Start the optimization in a separate process and return it."""
        self.runsignal.set()
        outqueue = Queue()
        if callback:
            self.callback = callback
            collector = Process(target=self.async_callback_collect, args=(outqueue,))
            collector.start()
        p = _CGDAsync(f, df, x0, update_rule, self.runsignal,
                      report_every=report_every, messages=messages, maxiter=maxiter,
                      max_f_eval=max_f_eval, gtol=gtol, outqueue=outqueue, *args, **kwargs)
        p.start()
        return p

    def fmin(self, f, df, x0, callback=None, update_rule=FletcherReeves,
             messages=0, maxiter=5e3, max_f_eval=15e3, gtol=1e-6,
             report_every=10, *args, **kwargs):
        """Run the optimization, blocking until it finishes or is interrupted."""
        p = self.fmin_async(f, df, x0, callback, update_rule, messages,
                            maxiter, max_f_eval, gtol,
                            report_every, *args, **kwargs)
        while self.runsignal.is_set():
            try:
                p.join(1)
                # FIX: the worker never clears the run signal itself, so the
                # original loop spun forever after the optimization finished
                if not p.is_alive():
                    self.runsignal.clear()
            except KeyboardInterrupt:
                sys.stdout.write("^C\n")  # was a Python-2-only print statement
                self.runsignal.clear()
        p.join()
|
||||
|
||||
class CGD(Async_Optimize):
    '''
    Conjugate gradient descent minimizer.

    Minimizes a function f (with gradient df), starting at x0, driving the
    search directions with the chosen update rule.

    If df returns a tuple (grad, natgrad), optimization follows natural
    gradient rules.
    '''
    name = "Conjugate Gradient Descent"

    def fmin_async(self, *a, **kw):
        """
        fmin_async(f, df, x0, callback, update_rule=FletcherReeves,
                   messages=0, maxiter=5e3, max_f_eval=15e3, gtol=1e-6,
                   report_every=10, *args, **kwargs)

        Start minimizing asynchronously and return the started `Process`.

        Every `report_every` iterations -- and once more at the end of the
        optimization -- the callback is invoked as

            callback(xi, fi, iteration, function_calls, gradient_calls,
                     status_message)

        f and df are evaluated as f(xi, *args, **kwargs) and
        df(xi, *args, **kwargs); if df returns a tuple (grad, natgrad),
        natural gradient rules are used.
        """
        return super(CGD, self).fmin_async(*a, **kw)

    def fmin(self, *a, **kw):
        """
        fmin(f, df, x0, callback=None, update_rule=FletcherReeves,
             messages=0, maxiter=5e3, max_f_eval=15e3, gtol=1e-6,
             report_every=10, *args, **kwargs)

        Minimize f, blocking until the optimization terminates.

        Every `report_every` iterations the callback is invoked as

            callback(xi, fi, iteration, function_calls, gradient_calls,
                     status_message)

        f and df are evaluated as f(xi, *args, **kwargs) and
        df(xi, *args, **kwargs); if df returns a tuple (grad, natgrad),
        natural gradient rules are used.
        """
        return super(CGD, self).fmin(*a, **kw)
|
||||
|
||||
43
GPy/inference/gradient_descent_update_rules.py
Normal file
43
GPy/inference/gradient_descent_update_rules.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
'''
|
||||
Created on 24 Apr 2013
|
||||
|
||||
@author: maxz
|
||||
'''
|
||||
import numpy
|
||||
|
||||
class GDUpdateRule(object):
    """Base class for gradient descent update rules.

    Keeps the current and previous (natural) gradients; subclasses implement
    :meth:`_gamma` to compute the conjugate update factor from them.
    """
    # previous natural gradient state (set on the first __call__)
    _gradnat = None
    _gradnatold = None

    def __init__(self, initgrad, initgradnat=None):
        """
        :param initgrad: initial gradient
        :param initgradnat: optional initial natural gradient; defaults to
            ``initgrad``
        """
        self.grad = initgrad
        # FIX: compare against None -- truth-testing a multi-element numpy
        # array (the expected gradient type) raises a ValueError
        if initgradnat is not None:
            self.gradnat = initgradnat
        else:
            self.gradnat = initgrad

    def _gamma(self):
        # FIX: NotImplemented is a constant, not an exception type; calling
        # it raised a TypeError instead of the intended NotImplementedError
        raise NotImplementedError("""Implement gamma update rule here,
you can use self.grad and self.gradold for parameters, as well as
self.gradnat and self.gradnatold for natural gradients.""")

    def __call__(self, grad, gradnat=None, si=None, *args, **kw):
        """
        Return gamma for given gradients and optional natural gradients.
        """
        # FIX: `is None` instead of truth-testing a numpy array (see above)
        if gradnat is None:
            gradnat = grad
        self.gradold = self.grad
        self.gradnatold = self.gradnat
        self.grad = grad
        self.gradnat = gradnat
        self.si = si
        return self._gamma(*args, **kw)
|
||||
|
||||
class FletcherReeves(GDUpdateRule):
    '''
    Fletcher-Reeves update rule for the conjugate factor gamma.
    '''
    def _gamma(self, *a, **kw):
        # gamma = <grad, gradnat> / <gradold, gradnatold>; a vanishing
        # numerator is returned as-is so we never divide in that case
        numerator = numpy.dot(self.grad.T, self.gradnat)
        if not numerator:
            return numerator
        denominator = numpy.dot(self.gradold.T, self.gradnatold)
        return numerator / denominator
|
||||
56
GPy/testing/cgd_tests.py
Normal file
56
GPy/testing/cgd_tests.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
'''
|
||||
Created on 26 Apr 2013
|
||||
|
||||
@author: maxz
|
||||
'''
|
||||
import unittest
|
||||
import numpy
|
||||
from GPy.inference.conjugate_gradient_descent import CGD
|
||||
import pylab
|
||||
import time
|
||||
from scipy.optimize.optimize import rosen, rosen_der
|
||||
|
||||
|
||||
class Test(unittest.TestCase):

    def testMinimizeSquare(self):
        # NOTE(review): stub -- defines the objective but never runs the
        # optimizer nor asserts anything; TODO(owner): complete this test
        f = lambda x: x ** 2 + 2 * x - 2
|
||||
|
||||
if __name__ == "__main__":
#    import sys;sys.argv = ['', 'Test.testMinimizeSquare']
#    unittest.main()
    # Interactive demo: minimize the 2D Rosenbrock function with CGD while
    # animating the iterates on a 3D wireframe of the objective.
    N = 2
    A = numpy.random.rand(N) * numpy.eye(N)
    b = numpy.random.rand(N)
#    f = lambda x: numpy.dot(x.T.dot(A), x) + numpy.dot(x.T, b)
#    df = lambda x: numpy.dot(A, x) - b

    # objective and gradient: the classic Rosenbrock banana function
    f = rosen
    df = rosen_der
    # random starting point near the origin
    x0 = numpy.random.randn(N) * .5

    opt = CGD()

    # reuse the figure's axes if the window already exists, else make a 3D one
    fig = pylab.figure("cgd optimize")
    if fig.axes:
        ax = fig.axes[0]
        ax.cla()
    else:
        ax = fig.add_subplot(111, projection='3d')

    # evaluate f on an interpolation x interpolation grid over [-1, 1]^2
    interpolation = 40
    x, y = numpy.linspace(-1, 1, interpolation)[:, None], numpy.linspace(-1, 1, interpolation)[:, None]
    X, Y = numpy.meshgrid(x, y)
    fXY = numpy.array([f(numpy.array([x, y])) for x, y in zip(X.flatten(), Y.flatten())]).reshape(interpolation, interpolation)

    ax.plot_wireframe(X, Y, fXY)
    # optimization path so far, seeded with the starting point
    xopts = [x0.copy()]
    optplts, = ax.plot3D([x0[0]], [x0[1]], zs=f(x0), marker='o', color='r')

    def callback(x, *a, **kw):
        # append the new iterate and redraw the path (sleep slows the
        # animation down so it is visible)
        xopts.append(x.copy())
        time.sleep(.3)
        # NOTE(review): pokes the private _verts3d of the Line3D artist to
        # update the plotted path in place -- matplotlib-version dependent
        optplts._verts3d = [numpy.array(xopts)[:, 0], numpy.array(xopts)[:, 1], [f(xs) for xs in xopts]]
        fig.canvas.draw()

    res = opt.fmin(f, df, x0, callback, messages=True, report_every=1)
|
||||
Loading…
Add table
Add a link
Reference in a new issue