bugfix: mixed up global and local index in unfixing

This commit is contained in:
mzwiessele 2014-04-17 15:01:43 +01:00
parent 483cb7ddc0
commit 8abc45c4ca
5 changed files with 49 additions and 19 deletions

View file

@ -184,7 +184,7 @@ class ParameterIndexOperationsView(object):
def remove(self, prop, indices): def remove(self, prop, indices):
removed = self._param_index_ops.remove(prop, numpy.array(indices)+self._offset) removed = self._param_index_ops.remove(prop, numpy.array(indices)+self._offset)
if removed.size > 0: if removed.size > 0:
return removed - self._size + 1 return removed-self._offset
return removed return removed

View file

@ -312,7 +312,8 @@ class Indexable(object):
This does not need to account for shaped parameters, as it This does not need to account for shaped parameters, as it
basically just sums up the parameter sizes which come before param. basically just sums up the parameter sizes which come before param.
""" """
raise NotImplementedError, "shouldnt happen, offset required from non parameterization object?" return 0
#raise NotImplementedError, "shouldnt happen, offset required from non parameterization object?"
def _raveled_index_for(self, param): def _raveled_index_for(self, param):
""" """
@ -320,7 +321,8 @@ class Indexable(object):
that is an int array, containing the indexes for the flattened that is an int array, containing the indexes for the flattened
param inside this parameterized logic. param inside this parameterized logic.
""" """
raise NotImplementedError, "shouldnt happen, raveld index transformation required from non parameterization object?" return param._raveled_index()
#raise NotImplementedError, "shouldnt happen, raveld index transformation required from non parameterization object?"
class Constrainable(Nameable, Indexable, Observable): class Constrainable(Nameable, Indexable, Observable):
@ -368,10 +370,10 @@ class Constrainable(Nameable, Indexable, Observable):
if value is not None: if value is not None:
self[:] = value self[:] = value
reconstrained = self.unconstrain() reconstrained = self.unconstrain()
self._add_to_index_operations(self.constraints, reconstrained, __fixed__, warning) index = self._add_to_index_operations(self.constraints, reconstrained, __fixed__, warning)
rav_i = self._highest_parent_._raveled_index_for(self) self._highest_parent_._set_fixed(self, index)
self._highest_parent_._set_fixed(rav_i)
self.notify_observers(self, None if trigger_parent else -np.inf) self.notify_observers(self, None if trigger_parent else -np.inf)
return index
fix = constrain_fixed fix = constrain_fixed
def unconstrain_fixed(self): def unconstrain_fixed(self):
@ -379,7 +381,8 @@ class Constrainable(Nameable, Indexable, Observable):
This parameter will no longer be fixed. This parameter will no longer be fixed.
""" """
unconstrained = self.unconstrain(__fixed__) unconstrained = self.unconstrain(__fixed__)
self._highest_parent_._set_unfixed(unconstrained) self._highest_parent_._set_unfixed(self, unconstrained)
return unconstrained
unfix = unconstrain_fixed unfix = unconstrain_fixed
def _ensure_fixes(self): def _ensure_fixes(self):
@ -388,14 +391,16 @@ class Constrainable(Nameable, Indexable, Observable):
# Param: ones(self._realsize_ # Param: ones(self._realsize_
if not self._has_fixes(): self._fixes_ = np.ones(self.size, dtype=bool) if not self._has_fixes(): self._fixes_ = np.ones(self.size, dtype=bool)
def _set_fixed(self, index): def _set_fixed(self, param, index):
self._ensure_fixes() self._ensure_fixes()
self._fixes_[index] = FIXED offset = self._offset_for(param)
self._fixes_[index+offset] = FIXED
if np.all(self._fixes_): self._fixes_ = None # ==UNFIXED if np.all(self._fixes_): self._fixes_ = None # ==UNFIXED
def _set_unfixed(self, index): def _set_unfixed(self, param, index):
self._ensure_fixes() self._ensure_fixes()
self._fixes_[index] = UNFIXED offset = self._offset_for(param)
self._fixes_[index+offset] = UNFIXED
if np.all(self._fixes_): self._fixes_ = None # ==UNFIXED if np.all(self._fixes_): self._fixes_ = None # ==UNFIXED
def _connect_fixes(self): def _connect_fixes(self):
@ -469,8 +474,9 @@ class Constrainable(Nameable, Indexable, Observable):
""" """
self.param_array[...] = transform.initialize(self.param_array) self.param_array[...] = transform.initialize(self.param_array)
reconstrained = self.unconstrain() reconstrained = self.unconstrain()
self._add_to_index_operations(self.constraints, reconstrained, transform, warning) added = self._add_to_index_operations(self.constraints, reconstrained, transform, warning)
self.notify_observers(self, None if trigger_parent else -np.inf) self.notify_observers(self, None if trigger_parent else -np.inf)
return added
def unconstrain(self, *transforms): def unconstrain(self, *transforms):
""" """
@ -549,7 +555,9 @@ class Constrainable(Nameable, Indexable, Observable):
if warning and reconstrained.size > 0: if warning and reconstrained.size > 0:
# TODO: figure out which parameters have changed and only print those # TODO: figure out which parameters have changed and only print those
print "WARNING: reconstraining parameters {}".format(self.parameter_names() or self.name) print "WARNING: reconstraining parameters {}".format(self.parameter_names() or self.name)
which.add(what, self._raveled_index()) index = self._raveled_index()
which.add(what, index)
return index
def _remove_from_index_operations(self, which, transforms): def _remove_from_index_operations(self, which, transforms):
""" """
@ -561,9 +569,10 @@ class Constrainable(Nameable, Indexable, Observable):
removed = np.empty((0,), dtype=int) removed = np.empty((0,), dtype=int)
for t in transforms: for t in transforms:
unconstrained = which.remove(t, self._raveled_index()) unconstrained = which.remove(t, self._raveled_index())
print unconstrained
removed = np.union1d(removed, unconstrained) removed = np.union1d(removed, unconstrained)
if t is __fixed__: if t is __fixed__:
self._highest_parent_._set_unfixed(unconstrained) self._highest_parent_._set_unfixed(self, unconstrained)
return removed return removed

View file

@ -20,6 +20,6 @@ except ImportError:
if sympy_available: if sympy_available:
from _src.symbolic import Symbolic from _src.symbolic import Symbolic
from _src.heat_eqinit import Heat_eqinit #from _src.heat_eqinit import Heat_eqinit
from _src.ode1_eq_lfm import Ode1_eq_lfm #from _src.ode1_eq_lfm import Ode1_eq_lfm

View file

@ -24,12 +24,14 @@ class Test(unittest.TestCase):
self.assertDictEqual(self.param_index._properties, {}) self.assertDictEqual(self.param_index._properties, {})
def test_remove(self): def test_remove(self):
self.param_index.remove(three, np.r_[3:10]) removed = self.param_index.remove(three, np.r_[3:10])
self.assertListEqual(removed.tolist(), [4, 7])
self.assertListEqual(self.param_index[three].tolist(), [2]) self.assertListEqual(self.param_index[three].tolist(), [2])
self.param_index.remove(one, [1]) removed = self.param_index.remove(one, [1])
self.assertListEqual(removed.tolist(), [])
self.assertListEqual(self.param_index[one].tolist(), [3]) self.assertListEqual(self.param_index[one].tolist(), [3])
self.assertListEqual(self.param_index.remove('not in there', []).tolist(), []) self.assertListEqual(self.param_index.remove('not in there', []).tolist(), [])
self.param_index.remove(one, [9]) removed = self.param_index.remove(one, [9])
self.assertListEqual(self.param_index[one].tolist(), [3]) self.assertListEqual(self.param_index[one].tolist(), [3])
self.assertListEqual(self.param_index.remove('not in there', [2,3,4]).tolist(), []) self.assertListEqual(self.param_index.remove('not in there', [2,3,4]).tolist(), [])
@ -78,6 +80,13 @@ class Test(unittest.TestCase):
self.assertEqual(i, i2) self.assertEqual(i, i2)
self.assertTrue(np.all(v == v2)) self.assertTrue(np.all(v == v2))
def test_indexview_remove(self):
removed = self.view.remove(two, [3])
self.assertListEqual(removed.tolist(), [3])
removed = self.view.remove(three, np.r_[:5])
self.assertListEqual(removed.tolist(), [0, 2])
def test_misc(self): def test_misc(self):
for k,v in self.param_index.copy()._properties.iteritems(): for k,v in self.param_index.copy()._properties.iteritems():
self.assertListEqual(self.param_index[k].tolist(), v.tolist()) self.assertListEqual(self.param_index[k].tolist(), v.tolist())

View file

@ -153,6 +153,18 @@ class ParameterizedTest(unittest.TestCase):
self.testmodel.randomize() self.testmodel.randomize()
np.testing.assert_equal(variances, self.testmodel['.*var'].values()) np.testing.assert_equal(variances, self.testmodel['.*var'].values())
def test_fix_unfix(self):
fixed = self.testmodel.kern.lengthscale.fix()
self.assertListEqual(fixed.tolist(), [0])
unfixed = self.testmodel.kern.lengthscale.unfix()
self.testmodel.kern.lengthscale.constrain_positive()
self.assertListEqual(unfixed.tolist(), [0])
fixed = self.testmodel.kern.fix()
self.assertListEqual(fixed.tolist(), [0,1])
unfixed = self.testmodel.kern.unfix()
self.assertListEqual(unfixed.tolist(), [0,1])
def test_printing(self): def test_printing(self):
print self.test1 print self.test1
print self.param print self.param