mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 20:12:38 +02:00
async optimize working
This commit is contained in:
parent
96a97ce790
commit
f3f6226287
4 changed files with 145 additions and 131 deletions
|
|
@ -173,7 +173,7 @@ def bgplvm_simulation_matlab_compare():
|
||||||
from GPy.models import mrd
|
from GPy.models import mrd
|
||||||
from GPy import kern
|
from GPy import kern
|
||||||
reload(mrd); reload(kern)
|
reload(mrd); reload(kern)
|
||||||
k = kern.rbf(Q, ARD=True) + kern.bias(Q, np.exp(-2)) + kern.white(Q, np.exp(-2))
|
k = kern.linear(Q, ARD=True) + kern.bias(Q, np.exp(-2)) + kern.white(Q, np.exp(-2))
|
||||||
m = Bayesian_GPLVM(Y, Q, init="PCA", M=M, kernel=k,
|
m = Bayesian_GPLVM(Y, Q, init="PCA", M=M, kernel=k,
|
||||||
# X=mu,
|
# X=mu,
|
||||||
# X_variance=S,
|
# X_variance=S,
|
||||||
|
|
|
||||||
|
|
@ -3,16 +3,15 @@ Created on 24 Apr 2013
|
||||||
|
|
||||||
@author: maxz
|
@author: maxz
|
||||||
'''
|
'''
|
||||||
from multiprocessing.process import Process
|
|
||||||
from GPy.inference.gradient_descent_update_rules import FletcherReeves
|
from GPy.inference.gradient_descent_update_rules import FletcherReeves
|
||||||
import numpy
|
import numpy
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
from scipy.optimize.linesearch import line_search_wolfe1, line_search_wolfe2
|
from scipy.optimize.linesearch import line_search_wolfe1, line_search_wolfe2
|
||||||
from multiprocessing.synchronize import Lock, Event
|
from multiprocessing.synchronize import Event
|
||||||
from copy import deepcopy
|
|
||||||
from multiprocessing.queues import Queue
|
from multiprocessing.queues import Queue
|
||||||
from Queue import Empty
|
from Queue import Empty
|
||||||
import sys
|
import sys
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
RUNNING = "running"
|
RUNNING = "running"
|
||||||
CONVERGED = "converged"
|
CONVERGED = "converged"
|
||||||
|
|
@ -21,7 +20,9 @@ MAX_F_EVAL = "maximum number of function calls reached"
|
||||||
LINE_SEARCH = "line search failed"
|
LINE_SEARCH = "line search failed"
|
||||||
KBINTERRUPT = "interrupted"
|
KBINTERRUPT = "interrupted"
|
||||||
|
|
||||||
class _Async_Optimization(Process):
|
SENTINEL = None
|
||||||
|
|
||||||
|
class _Async_Optimization(Thread):
|
||||||
def __init__(self, f, df, x0, update_rule, runsignal,
|
def __init__(self, f, df, x0, update_rule, runsignal,
|
||||||
report_every=10, messages=0, maxiter=5e3, max_f_eval=15e3,
|
report_every=10, messages=0, maxiter=5e3, max_f_eval=15e3,
|
||||||
gtol=1e-6, outqueue=None, *args, **kw):
|
gtol=1e-6, outqueue=None, *args, **kw):
|
||||||
|
|
@ -67,6 +68,11 @@ class _Async_Optimization(Process):
|
||||||
pass
|
pass
|
||||||
# print "callback done"
|
# print "callback done"
|
||||||
|
|
||||||
|
def callback_return(self, *a):
|
||||||
|
self.callback(*a)
|
||||||
|
self.outq.put(SENTINEL)
|
||||||
|
self.runsignal.clear()
|
||||||
|
|
||||||
def run(self, *args, **kwargs):
|
def run(self, *args, **kwargs):
|
||||||
raise NotImplementedError("Overwrite this with optimization (for async use)")
|
raise NotImplementedError("Overwrite this with optimization (for async use)")
|
||||||
pass
|
pass
|
||||||
|
|
@ -91,7 +97,6 @@ class _CGDAsync(_Async_Optimization):
|
||||||
it = 0
|
it = 0
|
||||||
|
|
||||||
while it < self.maxiter:
|
while it < self.maxiter:
|
||||||
print self.runsignal.is_set()
|
|
||||||
if not self.runsignal.is_set():
|
if not self.runsignal.is_set():
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
@ -117,7 +122,7 @@ class _CGDAsync(_Async_Optimization):
|
||||||
xi,
|
xi,
|
||||||
si, gi,
|
si, gi,
|
||||||
fi, fi_old)
|
fi, fi_old)
|
||||||
if alphai is not None:
|
if alphai is not None and fi2 < fi:
|
||||||
fi, fi_old = fi2, fi_old2
|
fi, fi_old = fi2, fi_old2
|
||||||
else:
|
else:
|
||||||
alphai, _, _, fi, fi_old, gfi = \
|
alphai, _, _, fi, fi_old, gfi = \
|
||||||
|
|
@ -130,11 +135,15 @@ class _CGDAsync(_Async_Optimization):
|
||||||
break
|
break
|
||||||
if gfi is not None:
|
if gfi is not None:
|
||||||
gi = gfi
|
gi = gfi
|
||||||
|
|
||||||
|
if fi_old > fi:
|
||||||
|
gi, ur, si = self.reset(xi, *a, **kw)
|
||||||
|
else:
|
||||||
xi += numpy.dot(alphai, si)
|
xi += numpy.dot(alphai, si)
|
||||||
if self.messages:
|
if self.messages:
|
||||||
sys.stdout.write("\r")
|
sys.stdout.write("\r")
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
sys.stdout.write("iteration: {0:> 6g} f: {1:> 12F} g: {2:> 12F}".format(it, fi, gi))
|
sys.stdout.write("iteration: {0:> 6g} f:{1:> 12e} |g|:{2:> 12e}".format(it, fi, numpy.dot(gi.T, gi)))
|
||||||
|
|
||||||
if it % self.report_every == 0:
|
if it % self.report_every == 0:
|
||||||
self.callback(xi, fi, it, self.f_call.value, self.df_call.value, status)
|
self.callback(xi, fi, it, self.f_call.value, self.df_call.value, status)
|
||||||
|
|
@ -142,18 +151,16 @@ class _CGDAsync(_Async_Optimization):
|
||||||
else:
|
else:
|
||||||
status = MAXITER
|
status = MAXITER
|
||||||
# self.result = [xi, fi, it, self.f_call.value, self.df_call.value, status]
|
# self.result = [xi, fi, it, self.f_call.value, self.df_call.value, status]
|
||||||
self.callback(xi, fi, it, self.f_call.value, self.df_call.value, status)
|
self.callback_return(xi, fi, it, self.f_call.value, self.df_call.value, status)
|
||||||
return
|
|
||||||
|
|
||||||
class Async_Optimize(object):
|
class Async_Optimize(object):
|
||||||
callback = None
|
callback = lambda *x: None
|
||||||
SENTINEL = object()
|
|
||||||
runsignal = Event()
|
runsignal = Event()
|
||||||
|
|
||||||
def async_callback_collect(self, q):
|
def async_callback_collect(self, q):
|
||||||
while self.runsignal.is_set():
|
while self.runsignal.is_set():
|
||||||
try:
|
try:
|
||||||
for ret in iter(lambda: q.get(timeout=1), self.SENTINEL):
|
for ret in iter(lambda: q.get(timeout=1), SENTINEL):
|
||||||
self.callback(*ret)
|
self.callback(*ret)
|
||||||
except Empty:
|
except Empty:
|
||||||
pass
|
pass
|
||||||
|
|
@ -162,30 +169,32 @@ class Async_Optimize(object):
|
||||||
messages=0, maxiter=5e3, max_f_eval=15e3, gtol=1e-6,
|
messages=0, maxiter=5e3, max_f_eval=15e3, gtol=1e-6,
|
||||||
report_every=10, *args, **kwargs):
|
report_every=10, *args, **kwargs):
|
||||||
self.runsignal.set()
|
self.runsignal.set()
|
||||||
outqueue = Queue()
|
outqueue = Queue(5)
|
||||||
if callback:
|
if callback:
|
||||||
self.callback = callback
|
self.callback = callback
|
||||||
collector = Process(target=self.async_callback_collect, args=(outqueue,))
|
c = Thread(target=self.async_callback_collect, args=(outqueue,))
|
||||||
collector.start()
|
c.start()
|
||||||
p = _CGDAsync(f, df, x0, update_rule, self.runsignal,
|
p = _CGDAsync(f, df, x0, update_rule, self.runsignal,
|
||||||
report_every=report_every, messages=messages, maxiter=maxiter,
|
report_every=report_every, messages=messages, maxiter=maxiter,
|
||||||
max_f_eval=max_f_eval, gtol=gtol, outqueue=outqueue, *args, **kwargs)
|
max_f_eval=max_f_eval, gtol=gtol, outqueue=outqueue, *args, **kwargs)
|
||||||
p.start()
|
p.run()
|
||||||
return p
|
return p, c
|
||||||
|
|
||||||
def fmin(self, f, df, x0, callback=None, update_rule=FletcherReeves,
|
def fmin(self, f, df, x0, callback=None, update_rule=FletcherReeves,
|
||||||
messages=0, maxiter=5e3, max_f_eval=15e3, gtol=1e-6,
|
messages=0, maxiter=5e3, max_f_eval=15e3, gtol=1e-6,
|
||||||
report_every=10, *args, **kwargs):
|
report_every=10, *args, **kwargs):
|
||||||
p = self.fmin_async(f, df, x0, callback, update_rule, messages,
|
p, c = self.fmin_async(f, df, x0, callback, update_rule, messages,
|
||||||
maxiter, max_f_eval, gtol,
|
maxiter, max_f_eval, gtol,
|
||||||
report_every, *args, **kwargs)
|
report_every, *args, **kwargs)
|
||||||
while self.runsignal.is_set():
|
while self.runsignal.is_set():
|
||||||
try:
|
try:
|
||||||
p.join(1)
|
p.join(1)
|
||||||
|
c.join(1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print "^C"
|
# print "^C"
|
||||||
self.runsignal.clear()
|
self.runsignal.clear()
|
||||||
p.join()
|
p.join()
|
||||||
|
c.join()
|
||||||
|
|
||||||
class CGD(Async_Optimize):
|
class CGD(Async_Optimize):
|
||||||
'''
|
'''
|
||||||
|
|
|
||||||
|
|
@ -103,10 +103,10 @@ class sparse_GP(GP):
|
||||||
self.C = linalg.lapack.flapack.dtrtrs(self.Lm, np.asfortranarray(tmp.T), lower=1, trans=1)[0]
|
self.C = linalg.lapack.flapack.dtrtrs(self.Lm, np.asfortranarray(tmp.T), lower=1, trans=1)[0]
|
||||||
|
|
||||||
# self.Cpsi1V = np.dot(self.C,self.psi1V)
|
# self.Cpsi1V = np.dot(self.C,self.psi1V)
|
||||||
#back substutue C into psi1V
|
# back substitute C into psi1V
|
||||||
tmp,info1 = linalg.lapack.flapack.dtrtrs(self.Lm,np.asfortranarray(self.psi1V),lower=1,trans=0)
|
tmp, _ = linalg.lapack.flapack.dtrtrs(self.Lm, np.asfortranarray(self.psi1V), lower=1, trans=0)
|
||||||
tmp,info2 = linalg.lapack.flapack.dpotrs(self.LB,tmp,lower=1)
|
tmp, _ = linalg.lapack.flapack.dpotrs(self.LB, tmp, lower=1)
|
||||||
self.Cpsi1V,info3 = linalg.lapack.flapack.dtrtrs(self.Lm,tmp,lower=1,trans=1)
|
self.Cpsi1V, _ = linalg.lapack.flapack.dtrtrs(self.Lm, tmp, lower=1, trans=1)
|
||||||
|
|
||||||
self.Cpsi1VVpsi1 = np.dot(self.Cpsi1V, self.psi1V.T) # TODO: stabilize?
|
self.Cpsi1VVpsi1 = np.dot(self.Cpsi1V, self.psi1V.T) # TODO: stabilize?
|
||||||
self.E = tdot(self.Cpsi1V / sf)
|
self.E = tdot(self.Cpsi1V / sf)
|
||||||
|
|
|
||||||
|
|
@ -47,10 +47,15 @@ if __name__ == "__main__":
|
||||||
xopts = [x0.copy()]
|
xopts = [x0.copy()]
|
||||||
optplts, = ax.plot3D([x0[0]], [x0[1]], zs=f(x0), marker='o', color='r')
|
optplts, = ax.plot3D([x0[0]], [x0[1]], zs=f(x0), marker='o', color='r')
|
||||||
|
|
||||||
|
raw_input("enter to start optimize")
|
||||||
|
|
||||||
def callback(x, *a, **kw):
|
def callback(x, *a, **kw):
|
||||||
xopts.append(x.copy())
|
xopts.append(x.copy())
|
||||||
time.sleep(.3)
|
# time.sleep(.3)
|
||||||
optplts._verts3d = [numpy.array(xopts)[:, 0], numpy.array(xopts)[:, 1], [f(xs) for xs in xopts]]
|
optplts._verts3d = [numpy.array(xopts)[:, 0], numpy.array(xopts)[:, 1], [f(xs) for xs in xopts]]
|
||||||
fig.canvas.draw()
|
fig.canvas.draw()
|
||||||
|
|
||||||
res = opt.fmin(f, df, x0, callback, messages=True, report_every=1)
|
res = opt.fmin(f, df, x0, callback, messages=True, maxiter=1000, report_every=1)
|
||||||
|
|
||||||
|
pylab.ion()
|
||||||
|
pylab.show()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue