last opt updates and tests

This commit is contained in:
Max Zwiessele 2013-05-03 14:38:42 +01:00
parent 5321bfc8c9
commit f4b997beb8
2 changed files with 11 additions and 10 deletions

View file

@ -63,6 +63,7 @@ class _Async_Optimization(Thread):
return f_w
def callback(self, *a):
if self.outq is not None:
self.outq.put(a)
# self.parent and self.parent.callback(*a, **kw)
pass
@ -70,7 +71,7 @@ class _Async_Optimization(Thread):
def callback_return(self, *a):
self.callback(*a)
self.outq.put(self.SENTINEL)
self.callback(self.SENTINEL)
self.runsignal.clear()
def run(self, *args, **kwargs):
@ -170,16 +171,17 @@ class Async_Optimize(object):
messages=0, maxiter=5e3, max_f_eval=15e3, gtol=1e-6,
report_every=10, *args, **kwargs):
self.runsignal.set()
outqueue = Queue()
c = None
outqueue = None
if callback:
outqueue = Queue()
self.callback = callback
c = Thread(target=self.async_callback_collect, args=(outqueue,))
c.start()
p = _CGDAsync(f, df, x0, update_rule, self.runsignal, self.SENTINEL,
report_every=report_every, messages=messages, maxiter=maxiter,
max_f_eval=max_f_eval, gtol=gtol, outqueue=outqueue, *args, **kwargs)
p.run()
p.start()
return p, c
def opt(self, f, df, x0, callback=None, update_rule=FletcherReeves,

View file

@ -14,7 +14,7 @@ from scipy.optimize.optimize import rosen, rosen_der
class Test(unittest.TestCase):
def testMinimizeSquare(self):
N = 2
N = 100
A = numpy.random.rand(N) * numpy.eye(N)
b = numpy.random.rand(N) * 0
f = lambda x: numpy.dot(x.T.dot(A), x) - numpy.dot(x.T, b)
@ -25,7 +25,7 @@ class Test(unittest.TestCase):
restarts = 10
for _ in range(restarts):
try:
x0 = numpy.random.randn(N) * .5
x0 = numpy.random.randn(N) * 300
res = opt.opt(f, df, x0, messages=0,
maxiter=1000, gtol=1e-10)
assert numpy.allclose(res[0], 0, atol=1e-3)
@ -37,10 +37,9 @@ class Test(unittest.TestCase):
raise AssertionError("Test failed for {} restarts".format(restarts))
def testRosen(self):
N = 2
N = 20
f = rosen
df = rosen_der
x0 = numpy.random.randn(N) * .5
opt = CGD()
@ -49,8 +48,8 @@ class Test(unittest.TestCase):
try:
x0 = numpy.random.randn(N) * .5
res = opt.opt(f, df, x0, messages=0,
maxiter=1000, gtol=1e-2)
assert numpy.allclose(res[0], 1, atol=.01)
maxiter=5e2, gtol=1e-2)
assert numpy.allclose(res[0], 1, atol=.1)
break
except:
# RESTART