mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-20 15:38:08 +02:00
LinearCF Psi Stat not working yet, strange bug in psi computations
This commit is contained in:
parent
c502b66ea3
commit
42474f0044
8 changed files with 353 additions and 244 deletions
|
|
@ -4,14 +4,14 @@ Created on 24 Apr 2013
|
|||
@author: maxz
|
||||
'''
|
||||
from GPy.inference.gradient_descent_update_rules import FletcherReeves
|
||||
import numpy
|
||||
from multiprocessing import Value
|
||||
from scipy.optimize.linesearch import line_search_wolfe1, line_search_wolfe2
|
||||
from multiprocessing.synchronize import Event
|
||||
from multiprocessing.queues import Queue
|
||||
from Queue import Empty
|
||||
import sys
|
||||
from multiprocessing import Value
|
||||
from multiprocessing.queues import Queue
|
||||
from multiprocessing.synchronize import Event
|
||||
from scipy.optimize.linesearch import line_search_wolfe1, line_search_wolfe2
|
||||
from threading import Thread
|
||||
import numpy
|
||||
import sys
|
||||
|
||||
RUNNING = "running"
|
||||
CONVERGED = "converged"
|
||||
|
|
@ -20,10 +20,9 @@ MAX_F_EVAL = "maximum number of function calls reached"
|
|||
LINE_SEARCH = "line search failed"
|
||||
KBINTERRUPT = "interrupted"
|
||||
|
||||
SENTINEL = None
|
||||
|
||||
class _Async_Optimization(Thread):
|
||||
def __init__(self, f, df, x0, update_rule, runsignal,
|
||||
|
||||
def __init__(self, f, df, x0, update_rule, runsignal, SENTINEL,
|
||||
report_every=10, messages=0, maxiter=5e3, max_f_eval=15e3,
|
||||
gtol=1e-6, outqueue=None, *args, **kw):
|
||||
"""
|
||||
|
|
@ -42,6 +41,7 @@ class _Async_Optimization(Thread):
|
|||
self.maxiter = maxiter
|
||||
self.max_f_eval = max_f_eval
|
||||
self.gtol = gtol
|
||||
self.SENTINEL = SENTINEL
|
||||
self.runsignal = runsignal
|
||||
# self.parent = parent
|
||||
# self.result = None
|
||||
|
|
@ -70,7 +70,7 @@ class _Async_Optimization(Thread):
|
|||
|
||||
def callback_return(self, *a):
|
||||
self.callback(*a)
|
||||
self.outq.put(SENTINEL)
|
||||
self.outq.put(self.SENTINEL)
|
||||
self.runsignal.clear()
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
|
|
@ -136,7 +136,7 @@ class _CGDAsync(_Async_Optimization):
|
|||
if gfi is not None:
|
||||
gi = gfi
|
||||
|
||||
if fi_old > fi:
|
||||
if numpy.isnan(fi) or fi_old < fi:
|
||||
gi, ur, si = self.reset(xi, *a, **kw)
|
||||
else:
|
||||
xi += numpy.dot(alphai, si)
|
||||
|
|
@ -145,22 +145,23 @@ class _CGDAsync(_Async_Optimization):
|
|||
sys.stdout.flush()
|
||||
sys.stdout.write("iteration: {0:> 6g} f:{1:> 12e} |g|:{2:> 12e}".format(it, fi, numpy.dot(gi.T, gi)))
|
||||
|
||||
if it % self.report_every == 0:
|
||||
self.callback(xi, fi, it, self.f_call.value, self.df_call.value, status)
|
||||
if it % self.report_every == 0:
|
||||
self.callback(xi, fi, gi, it, self.f_call.value, self.df_call.value, status)
|
||||
it += 1
|
||||
else:
|
||||
status = MAXITER
|
||||
# self.result = [xi, fi, it, self.f_call.value, self.df_call.value, status]
|
||||
self.callback_return(xi, fi, it, self.f_call.value, self.df_call.value, status)
|
||||
self.callback_return(xi, fi, gi, it, self.f_call.value, self.df_call.value, status)
|
||||
self.result = [xi, fi, gi, it, self.f_call.value, self.df_call.value, status]
|
||||
|
||||
class Async_Optimize(object):
|
||||
callback = lambda *x: None
|
||||
runsignal = Event()
|
||||
SENTINEL = "SENTINEL"
|
||||
|
||||
def async_callback_collect(self, q):
|
||||
while self.runsignal.is_set():
|
||||
try:
|
||||
for ret in iter(lambda: q.get(timeout=1), SENTINEL):
|
||||
for ret in iter(lambda: q.get(timeout=1), self.SENTINEL):
|
||||
self.callback(*ret)
|
||||
except Empty:
|
||||
pass
|
||||
|
|
@ -169,12 +170,12 @@ class Async_Optimize(object):
|
|||
messages=0, maxiter=5e3, max_f_eval=15e3, gtol=1e-6,
|
||||
report_every=10, *args, **kwargs):
|
||||
self.runsignal.set()
|
||||
outqueue = Queue(5)
|
||||
outqueue = Queue()
|
||||
if callback:
|
||||
self.callback = callback
|
||||
c = Thread(target=self.async_callback_collect, args=(outqueue,))
|
||||
c.start()
|
||||
p = _CGDAsync(f, df, x0, update_rule, self.runsignal,
|
||||
p = _CGDAsync(f, df, x0, update_rule, self.runsignal, self.SENTINEL,
|
||||
report_every=report_every, messages=messages, maxiter=maxiter,
|
||||
max_f_eval=max_f_eval, gtol=gtol, outqueue=outqueue, *args, **kwargs)
|
||||
p.run()
|
||||
|
|
@ -189,12 +190,14 @@ class Async_Optimize(object):
|
|||
while self.runsignal.is_set():
|
||||
try:
|
||||
p.join(1)
|
||||
c.join(1)
|
||||
# c.join(1)
|
||||
except KeyboardInterrupt:
|
||||
# print "^C"
|
||||
self.runsignal.clear()
|
||||
p.join()
|
||||
c.join()
|
||||
if c.is_alive():
|
||||
print "WARNING: callback still running, optimisation done!"
|
||||
return p.result
|
||||
|
||||
class CGD(Async_Optimize):
|
||||
'''
|
||||
|
|
@ -215,7 +218,7 @@ class CGD(Async_Optimize):
|
|||
|
||||
callback gets called every `report_every` iterations
|
||||
|
||||
callback(xi, fi, iteration, function_calls, gradient_calls, status_message)
|
||||
callback(xi, fi, gi, iteration, function_calls, gradient_calls, status_message)
|
||||
|
||||
if df returns tuple (grad, natgrad) it will optimize according
|
||||
to natural gradient rules
|
||||
|
|
@ -233,7 +236,7 @@ class CGD(Async_Optimize):
|
|||
**calls**
|
||||
---------
|
||||
|
||||
callback(x_opt, f_opt, iteration, function_calls, gradient_calls, status_message)
|
||||
callback(x_opt, f_opt, g_opt, iteration, function_calls, gradient_calls, status_message)
|
||||
|
||||
at end of optimization!
|
||||
"""
|
||||
|
|
@ -247,7 +250,7 @@ class CGD(Async_Optimize):
|
|||
|
||||
Minimize f, calling callback every `report_every` iterations with following syntax:
|
||||
|
||||
callback(xi, fi, iteration, function_calls, gradient_calls, status_message)
|
||||
callback(xi, fi, gi, iteration, function_calls, gradient_calls, status_message)
|
||||
|
||||
if df returns tuple (grad, natgrad) it will optimize according
|
||||
to natural gradient rules
|
||||
|
|
@ -260,7 +263,7 @@ class CGD(Async_Optimize):
|
|||
**returns**
|
||||
---------
|
||||
|
||||
x_opt, f_opt, iteration, function_calls, gradient_calls, status_message
|
||||
x_opt, f_opt, g_opt, iteration, function_calls, gradient_calls, status_message
|
||||
|
||||
at end of optimization
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue