bugfix: mixed up global and local index in unfixing

This commit is contained in:
mzwiessele 2014-04-17 15:01:43 +01:00
parent 483cb7ddc0
commit 8abc45c4ca
5 changed files with 49 additions and 19 deletions

View file

@ -184,7 +184,7 @@ class ParameterIndexOperationsView(object):
def remove(self, prop, indices):
removed = self._param_index_ops.remove(prop, numpy.array(indices)+self._offset)
if removed.size > 0:
return removed - self._size + 1
return removed-self._offset
return removed

View file

@ -312,7 +312,8 @@ class Indexable(object):
This does not need to account for shaped parameters, as it
basically just sums up the parameter sizes which come before param.
"""
raise NotImplementedError, "shouldnt happen, offset required from non parameterization object?"
return 0
#raise NotImplementedError, "shouldnt happen, offset required from non parameterization object?"
def _raveled_index_for(self, param):
"""
@ -320,7 +321,8 @@ class Indexable(object):
that is an int array, containing the indexes for the flattened
param inside this parameterized logic.
"""
raise NotImplementedError, "shouldnt happen, raveld index transformation required from non parameterization object?"
return param._raveled_index()
#raise NotImplementedError, "shouldnt happen, raveld index transformation required from non parameterization object?"
class Constrainable(Nameable, Indexable, Observable):
@ -368,10 +370,10 @@ class Constrainable(Nameable, Indexable, Observable):
if value is not None:
self[:] = value
reconstrained = self.unconstrain()
self._add_to_index_operations(self.constraints, reconstrained, __fixed__, warning)
rav_i = self._highest_parent_._raveled_index_for(self)
self._highest_parent_._set_fixed(rav_i)
index = self._add_to_index_operations(self.constraints, reconstrained, __fixed__, warning)
self._highest_parent_._set_fixed(self, index)
self.notify_observers(self, None if trigger_parent else -np.inf)
return index
fix = constrain_fixed
def unconstrain_fixed(self):
@ -379,7 +381,8 @@ class Constrainable(Nameable, Indexable, Observable):
This parameter will no longer be fixed.
"""
unconstrained = self.unconstrain(__fixed__)
self._highest_parent_._set_unfixed(unconstrained)
self._highest_parent_._set_unfixed(self, unconstrained)
return unconstrained
unfix = unconstrain_fixed
def _ensure_fixes(self):
@ -388,14 +391,16 @@ class Constrainable(Nameable, Indexable, Observable):
# Param: ones(self._realsize_
if not self._has_fixes(): self._fixes_ = np.ones(self.size, dtype=bool)
def _set_fixed(self, index):
def _set_fixed(self, param, index):
self._ensure_fixes()
self._fixes_[index] = FIXED
offset = self._offset_for(param)
self._fixes_[index+offset] = FIXED
if np.all(self._fixes_): self._fixes_ = None # ==UNFIXED
def _set_unfixed(self, index):
def _set_unfixed(self, param, index):
self._ensure_fixes()
self._fixes_[index] = UNFIXED
offset = self._offset_for(param)
self._fixes_[index+offset] = UNFIXED
if np.all(self._fixes_): self._fixes_ = None # ==UNFIXED
def _connect_fixes(self):
@ -469,8 +474,9 @@ class Constrainable(Nameable, Indexable, Observable):
"""
self.param_array[...] = transform.initialize(self.param_array)
reconstrained = self.unconstrain()
self._add_to_index_operations(self.constraints, reconstrained, transform, warning)
added = self._add_to_index_operations(self.constraints, reconstrained, transform, warning)
self.notify_observers(self, None if trigger_parent else -np.inf)
return added
def unconstrain(self, *transforms):
"""
@ -549,7 +555,9 @@ class Constrainable(Nameable, Indexable, Observable):
if warning and reconstrained.size > 0:
# TODO: figure out which parameters have changed and only print those
print "WARNING: reconstraining parameters {}".format(self.parameter_names() or self.name)
which.add(what, self._raveled_index())
index = self._raveled_index()
which.add(what, index)
return index
def _remove_from_index_operations(self, which, transforms):
"""
@ -561,9 +569,10 @@ class Constrainable(Nameable, Indexable, Observable):
removed = np.empty((0,), dtype=int)
for t in transforms:
unconstrained = which.remove(t, self._raveled_index())
print unconstrained
removed = np.union1d(removed, unconstrained)
if t is __fixed__:
self._highest_parent_._set_unfixed(unconstrained)
self._highest_parent_._set_unfixed(self, unconstrained)
return removed

View file

@ -20,6 +20,6 @@ except ImportError:
if sympy_available:
from _src.symbolic import Symbolic
from _src.heat_eqinit import Heat_eqinit
from _src.ode1_eq_lfm import Ode1_eq_lfm
#from _src.heat_eqinit import Heat_eqinit
#from _src.ode1_eq_lfm import Ode1_eq_lfm

View file

@ -24,12 +24,14 @@ class Test(unittest.TestCase):
self.assertDictEqual(self.param_index._properties, {})
def test_remove(self):
self.param_index.remove(three, np.r_[3:10])
removed = self.param_index.remove(three, np.r_[3:10])
self.assertListEqual(removed.tolist(), [4, 7])
self.assertListEqual(self.param_index[three].tolist(), [2])
self.param_index.remove(one, [1])
removed = self.param_index.remove(one, [1])
self.assertListEqual(removed.tolist(), [])
self.assertListEqual(self.param_index[one].tolist(), [3])
self.assertListEqual(self.param_index.remove('not in there', []).tolist(), [])
self.param_index.remove(one, [9])
removed = self.param_index.remove(one, [9])
self.assertListEqual(self.param_index[one].tolist(), [3])
self.assertListEqual(self.param_index.remove('not in there', [2,3,4]).tolist(), [])
@ -78,6 +80,13 @@ class Test(unittest.TestCase):
self.assertEqual(i, i2)
self.assertTrue(np.all(v == v2))
def test_indexview_remove(self):
removed = self.view.remove(two, [3])
self.assertListEqual(removed.tolist(), [3])
removed = self.view.remove(three, np.r_[:5])
self.assertListEqual(removed.tolist(), [0, 2])
def test_misc(self):
for k,v in self.param_index.copy()._properties.iteritems():
self.assertListEqual(self.param_index[k].tolist(), v.tolist())

View file

@ -153,6 +153,18 @@ class ParameterizedTest(unittest.TestCase):
self.testmodel.randomize()
np.testing.assert_equal(variances, self.testmodel['.*var'].values())
def test_fix_unfix(self):
fixed = self.testmodel.kern.lengthscale.fix()
self.assertListEqual(fixed.tolist(), [0])
unfixed = self.testmodel.kern.lengthscale.unfix()
self.testmodel.kern.lengthscale.constrain_positive()
self.assertListEqual(unfixed.tolist(), [0])
fixed = self.testmodel.kern.fix()
self.assertListEqual(fixed.tolist(), [0,1])
unfixed = self.testmodel.kern.unfix()
self.assertListEqual(unfixed.tolist(), [0,1])
def test_printing(self):
print self.test1
print self.param