[param] hierarchy traversal easier now

This commit is contained in:
mzwiessele 2014-05-14 08:53:56 +01:00
parent cff37293d9
commit 8d6eed6010
4 changed files with 57 additions and 33 deletions

View file

@ -88,19 +88,17 @@ class ObservablesList(object):
def __iter__(self): def __iter__(self):
self.flush() self.flush()
for p, o, c in self._poc: for p, o, c in self._poc:
if o() is not None: yield p, o(), c
yield p, o(), c
def __len__(self): def __len__(self):
self.flush() self.flush()
return self._poc.__len__() return self._poc.__len__()
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
self.flush()
s = ObservablesList() s = ObservablesList()
for p,o,c in self._poc: for p,o,c in self:
import copy import copy
s.add(p, copy.deepcopy(o(), memo), copy.deepcopy(c, memo)) s.add(p, copy.deepcopy(o, memo), copy.deepcopy(c, memo))
s.flush() s.flush()
return s return s

View file

@ -169,8 +169,29 @@ class Param(OptimizationHandlable, ObsAr):
# parameterizable # parameterizable
#=========================================================================== #===========================================================================
def traverse(self, visit, *args, **kwargs): def traverse(self, visit, *args, **kwargs):
"""
Traverse the hierarchy performing visit(self, *args, **kwargs) at every node passed by.
See "visitor pattern" in literature. This is implemented in pre-order fashion.
This will function will just call visit on self, as Param are leaf nodes.
"""
visit(self, *args, **kwargs) visit(self, *args, **kwargs)
def traverse_parents(self, visit, *args, **kwargs):
"""
Traverse the hierarchy upwards, visiting all parents and their children, except self.
See "visitor pattern" in literature. This is implemented in pre-order fashion.
Example:
parents = []
self.traverse_parents(parents.append)
print parents
"""
if self.has_parent():
self.__visited = True
self._parent_._traverse_parents(visit, *args, **kwargs)
self.__visited = False
#=========================================================================== #===========================================================================
# Convenience # Convenience

View file

@ -176,24 +176,23 @@ class Pickleable(object):
#raise NotImplementedError, "Copy is not yet implemented, TODO: Observable hierarchy" #raise NotImplementedError, "Copy is not yet implemented, TODO: Observable hierarchy"
import copy import copy
memo = {} memo = {}
# the next part makes sure that we do not include parents in any form:
parents = [] parents = []
self.traverse_parents(parents.append) self.traverse_parents(parents.append) # collect parents
# remove self, which is the first arguments
parents = [p for p in parents if p is not self]
for p in parents: for p in parents:
memo[id(p)] = None memo[id(p)] = None # set all parents to be None, so they will not be copied
memo[id(self.gradient)] = None memo[id(self.gradient)] = None # reset the gradient
memo[id(self.param_array)] = None memo[id(self.param_array)] = None # and param_array
memo[id(self._fixes_)] = None memo[id(self._fixes_)] = None # fixes have to be reset, as this is now highest parent
c = copy.deepcopy(self, memo) c = copy.deepcopy(self, memo) # and start the copy
c._parent_index_ = None c._parent_index_ = None
return c return c
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
s = self.__new__(self.__class__) s = self.__new__(self.__class__) # fresh instance
memo[id(self)] = s memo[id(self)] = s # be sure to break all cycles --> self is already done
import copy import copy
s.__dict__.update(copy.deepcopy(self.__dict__, memo)) s.__dict__.update(copy.deepcopy(self.__dict__, memo)) # standard copy
return s return s
def __getstate__(self): def __getstate__(self):
@ -580,12 +579,6 @@ class OptimizationHandlable(Constrainable):
def __init__(self, name, default_constraint=None, *a, **kw): def __init__(self, name, default_constraint=None, *a, **kw):
super(OptimizationHandlable, self).__init__(name, default_constraint=default_constraint, *a, **kw) super(OptimizationHandlable, self).__init__(name, default_constraint=default_constraint, *a, **kw)
def transform(self):
[np.put(self.param_array, ind, c.finv(self.param_array.flat[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__]
def untransform(self):
[np.put(self.param_array, ind, c.f(self.param_array.flat[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__]
def _get_params_transformed(self): def _get_params_transformed(self):
# transformed parameters (apply transformation rules) # transformed parameters (apply transformation rules)
p = self.param_array.copy() p = self.param_array.copy()
@ -606,7 +599,8 @@ class OptimizationHandlable(Constrainable):
self.param_array.flat[fixes] = p self.param_array.flat[fixes] = p
elif self._has_fixes(): self.param_array.flat[self._fixes_] = p elif self._has_fixes(): self.param_array.flat[self._fixes_] = p
else: self.param_array.flat = p else: self.param_array.flat = p
self.untransform() [np.put(self.param_array, ind, c.f(self.param_array.flat[ind]))
for c, ind in self.constraints.iteritems() if c != __fixed__]
self._trigger_params_changed() self._trigger_params_changed()
def _trigger_params_changed(self, trigger_parent=True): def _trigger_params_changed(self, trigger_parent=True):
@ -726,7 +720,9 @@ class Parameterizable(OptimizationHandlable):
def traverse(self, visit, *args, **kwargs): def traverse(self, visit, *args, **kwargs):
""" """
Traverse the hierarchy performing visit(self, *args, **kwargs) at every node passed by. Traverse the hierarchy performing visit(self, *args, **kwargs)
at every node passed by downwards. This function includes self!
See "visitor pattern" in literature. This is implemented in pre-order fashion. See "visitor pattern" in literature. This is implemented in pre-order fashion.
Example: Example:
@ -745,7 +741,7 @@ class Parameterizable(OptimizationHandlable):
def traverse_parents(self, visit, *args, **kwargs): def traverse_parents(self, visit, *args, **kwargs):
""" """
Traverse the hierarchy upwards, visiting all parents and their children. Traverse the hierarchy upwards, visiting all parents and their children except self.
See "visitor pattern" in literature. This is implemented in pre-order fashion. See "visitor pattern" in literature. This is implemented in pre-order fashion.
Example: Example:
@ -754,13 +750,20 @@ class Parameterizable(OptimizationHandlable):
self.traverse_parents(parents.append) self.traverse_parents(parents.append)
print parents print parents
""" """
if not self.__visited: if self.has_parent():
visit(self, *args, **kwargs)
self.__visited = True self.__visited = True
self._parent_._traverse_parents(visit, *args, **kwargs)
self.__visited = False
def _traverse_parents(self, visit, *args, **kwargs):
if not self.__visited:
self.__visited = True
visit(self, *args, **kwargs)
if self.has_parent(): if self.has_parent():
self._parent_.traverse_parents(visit, *args, **kwargs) self._parent_._traverse_parents(visit, *args, **kwargs)
self._parent_.traverse(visit, *args, **kwargs) self._parent_.traverse(visit, *args, **kwargs)
self.__visited = False self.__visited = False
#========================================================================= #=========================================================================
# Gradient handling # Gradient handling
#========================================================================= #=========================================================================
@ -827,11 +830,10 @@ class Parameterizable(OptimizationHandlable):
# raise HierarchyError, "parameter {} already in another model ({}), create new object (or copy) for adding".format(param._short(), param._highest_parent_._short()) # raise HierarchyError, "parameter {} already in another model ({}), create new object (or copy) for adding".format(param._short(), param._highest_parent_._short())
elif param not in self._parameters_: elif param not in self._parameters_:
if param.has_parent(): if param.has_parent():
parent = param._parent_ def visit(parent, self):
while parent is not None:
if parent is self: if parent is self:
raise HierarchyError, "You cannot add a parameter twice into the hierarchy" raise HierarchyError, "You cannot add a parameter twice into the hierarchy"
parent = parent._parent_ param.traverse_parents(visit, self)
param._parent_.remove_parameter(param) param._parent_.remove_parameter(param)
# make sure the size is set # make sure the size is set
if index is None: if index is None:
@ -875,7 +877,7 @@ class Parameterizable(OptimizationHandlable):
:param param: param object to remove from being a parameter of this parameterized object. :param param: param object to remove from being a parameter of this parameterized object.
""" """
if not param in self._parameters_: if not param in self._parameters_:
raise RuntimeError, "Parameter {} does not belong to this object, remove parameters directly from their respective parents".format(param._short()) raise RuntimeError, "Parameter {} does not belong to this object {}, remove parameters directly from their respective parents".format(param._short(), self.name)
start = sum([p.size for p in self._parameters_[:param._parent_index_]]) start = sum([p.size for p in self._parameters_[:param._parent_index_]])
self._remove_parameter_name(param) self._remove_parameter_name(param)

View file

@ -132,6 +132,9 @@ class Test(ListDictTestCase):
self.assertIsNot(par.full_gradient, pcopy.full_gradient) self.assertIsNot(par.full_gradient, pcopy.full_gradient)
self.assertTrue(pcopy.checkgrad()) self.assertTrue(pcopy.checkgrad())
self.assert_(np.any(pcopy.gradient!=0.0)) self.assert_(np.any(pcopy.gradient!=0.0))
pcopy.optimize('bfgs')
par.optimize('bfgs')
np.testing.assert_allclose(pcopy.param_array, par.param_array, atol=.001)
with tempfile.TemporaryFile('w+b') as f: with tempfile.TemporaryFile('w+b') as f:
par.pickle(f) par.pickle(f)
f.seek(0) f.seek(0)