last opt updates and tests

This commit is contained in:
Max Zwiessele 2013-05-03 14:38:42 +01:00
parent 5321bfc8c9
commit f4b997beb8
2 changed files with 11 additions and 10 deletions

View file

@ -63,14 +63,15 @@ class _Async_Optimization(Thread):
return f_w return f_w
def callback(self, *a): def callback(self, *a):
self.outq.put(a) if self.outq is not None:
self.outq.put(a)
# self.parent and self.parent.callback(*a, **kw) # self.parent and self.parent.callback(*a, **kw)
pass pass
# print "callback done" # print "callback done"
def callback_return(self, *a): def callback_return(self, *a):
self.callback(*a) self.callback(*a)
self.outq.put(self.SENTINEL) self.callback(self.SENTINEL)
self.runsignal.clear() self.runsignal.clear()
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
@ -170,16 +171,17 @@ class Async_Optimize(object):
messages=0, maxiter=5e3, max_f_eval=15e3, gtol=1e-6, messages=0, maxiter=5e3, max_f_eval=15e3, gtol=1e-6,
report_every=10, *args, **kwargs): report_every=10, *args, **kwargs):
self.runsignal.set() self.runsignal.set()
outqueue = Queue()
c = None c = None
outqueue = None
if callback: if callback:
outqueue = Queue()
self.callback = callback self.callback = callback
c = Thread(target=self.async_callback_collect, args=(outqueue,)) c = Thread(target=self.async_callback_collect, args=(outqueue,))
c.start() c.start()
p = _CGDAsync(f, df, x0, update_rule, self.runsignal, self.SENTINEL, p = _CGDAsync(f, df, x0, update_rule, self.runsignal, self.SENTINEL,
report_every=report_every, messages=messages, maxiter=maxiter, report_every=report_every, messages=messages, maxiter=maxiter,
max_f_eval=max_f_eval, gtol=gtol, outqueue=outqueue, *args, **kwargs) max_f_eval=max_f_eval, gtol=gtol, outqueue=outqueue, *args, **kwargs)
p.run() p.start()
return p, c return p, c
def opt(self, f, df, x0, callback=None, update_rule=FletcherReeves, def opt(self, f, df, x0, callback=None, update_rule=FletcherReeves,

View file

@ -14,7 +14,7 @@ from scipy.optimize.optimize import rosen, rosen_der
class Test(unittest.TestCase): class Test(unittest.TestCase):
def testMinimizeSquare(self): def testMinimizeSquare(self):
N = 2 N = 100
A = numpy.random.rand(N) * numpy.eye(N) A = numpy.random.rand(N) * numpy.eye(N)
b = numpy.random.rand(N) * 0 b = numpy.random.rand(N) * 0
f = lambda x: numpy.dot(x.T.dot(A), x) - numpy.dot(x.T, b) f = lambda x: numpy.dot(x.T.dot(A), x) - numpy.dot(x.T, b)
@ -25,7 +25,7 @@ class Test(unittest.TestCase):
restarts = 10 restarts = 10
for _ in range(restarts): for _ in range(restarts):
try: try:
x0 = numpy.random.randn(N) * .5 x0 = numpy.random.randn(N) * 300
res = opt.opt(f, df, x0, messages=0, res = opt.opt(f, df, x0, messages=0,
maxiter=1000, gtol=1e-10) maxiter=1000, gtol=1e-10)
assert numpy.allclose(res[0], 0, atol=1e-3) assert numpy.allclose(res[0], 0, atol=1e-3)
@ -37,10 +37,9 @@ class Test(unittest.TestCase):
raise AssertionError("Test failed for {} restarts".format(restarts)) raise AssertionError("Test failed for {} restarts".format(restarts))
def testRosen(self): def testRosen(self):
N = 2 N = 20
f = rosen f = rosen
df = rosen_der df = rosen_der
x0 = numpy.random.randn(N) * .5
opt = CGD() opt = CGD()
@ -49,8 +48,8 @@ class Test(unittest.TestCase):
try: try:
x0 = numpy.random.randn(N) * .5 x0 = numpy.random.randn(N) * .5
res = opt.opt(f, df, x0, messages=0, res = opt.opt(f, df, x0, messages=0,
maxiter=1000, gtol=1e-2) maxiter=5e2, gtol=1e-2)
assert numpy.allclose(res[0], 1, atol=.01) assert numpy.allclose(res[0], 1, atol=.1)
break break
except: except:
# RESTART # RESTART