[param] hierarchy traversal easier now

This commit is contained in:
mzwiessele 2014-05-14 08:53:56 +01:00
parent cff37293d9
commit 8d6eed6010
4 changed files with 57 additions and 33 deletions

View file

@ -88,19 +88,17 @@ class ObservablesList(object):
def __iter__(self):
self.flush()
for p, o, c in self._poc:
if o() is not None:
yield p, o(), c
yield p, o(), c
def __len__(self):
self.flush()
return self._poc.__len__()
def __deepcopy__(self, memo):
self.flush()
s = ObservablesList()
for p,o,c in self._poc:
for p,o,c in self:
import copy
s.add(p, copy.deepcopy(o(), memo), copy.deepcopy(c, memo))
s.add(p, copy.deepcopy(o, memo), copy.deepcopy(c, memo))
s.flush()
return s

View file

@ -169,8 +169,29 @@ class Param(OptimizationHandlable, ObsAr):
# parameterizable
#===========================================================================
def traverse(self, visit, *args, **kwargs):
visit(self, *args, **kwargs)
"""
Traverse the hierarchy performing visit(self, *args, **kwargs) at every node passed by.
See "visitor pattern" in literature. This is implemented in pre-order fashion.
This will function will just call visit on self, as Param are leaf nodes.
"""
visit(self, *args, **kwargs)
def traverse_parents(self, visit, *args, **kwargs):
"""
Traverse the hierarchy upwards, visiting all parents and their children, except self.
See "visitor pattern" in literature. This is implemented in pre-order fashion.
Example:
parents = []
self.traverse_parents(parents.append)
print parents
"""
if self.has_parent():
self.__visited = True
self._parent_._traverse_parents(visit, *args, **kwargs)
self.__visited = False
#===========================================================================
# Convenience

View file

@ -176,24 +176,23 @@ class Pickleable(object):
#raise NotImplementedError, "Copy is not yet implemented, TODO: Observable hierarchy"
import copy
memo = {}
# the next part makes sure that we do not include parents in any form:
parents = []
self.traverse_parents(parents.append)
# remove self, which is the first arguments
parents = [p for p in parents if p is not self]
self.traverse_parents(parents.append) # collect parents
for p in parents:
memo[id(p)] = None
memo[id(self.gradient)] = None
memo[id(self.param_array)] = None
memo[id(self._fixes_)] = None
c = copy.deepcopy(self, memo)
memo[id(p)] = None # set all parents to be None, so they will not be copied
memo[id(self.gradient)] = None # reset the gradient
memo[id(self.param_array)] = None # and param_array
memo[id(self._fixes_)] = None # fixes have to be reset, as this is now highest parent
c = copy.deepcopy(self, memo) # and start the copy
c._parent_index_ = None
return c
def __deepcopy__(self, memo):
s = self.__new__(self.__class__)
memo[id(self)] = s
s = self.__new__(self.__class__) # fresh instance
memo[id(self)] = s # be sure to break all cycles --> self is already done
import copy
s.__dict__.update(copy.deepcopy(self.__dict__, memo))
s.__dict__.update(copy.deepcopy(self.__dict__, memo)) # standard copy
return s
def __getstate__(self):
@ -580,12 +579,6 @@ class OptimizationHandlable(Constrainable):
def __init__(self, name, default_constraint=None, *a, **kw):
super(OptimizationHandlable, self).__init__(name, default_constraint=default_constraint, *a, **kw)
def transform(self):
[np.put(self.param_array, ind, c.finv(self.param_array.flat[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__]
def untransform(self):
[np.put(self.param_array, ind, c.f(self.param_array.flat[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__]
def _get_params_transformed(self):
# transformed parameters (apply transformation rules)
p = self.param_array.copy()
@ -606,7 +599,8 @@ class OptimizationHandlable(Constrainable):
self.param_array.flat[fixes] = p
elif self._has_fixes(): self.param_array.flat[self._fixes_] = p
else: self.param_array.flat = p
self.untransform()
[np.put(self.param_array, ind, c.f(self.param_array.flat[ind]))
for c, ind in self.constraints.iteritems() if c != __fixed__]
self._trigger_params_changed()
def _trigger_params_changed(self, trigger_parent=True):
@ -726,7 +720,9 @@ class Parameterizable(OptimizationHandlable):
def traverse(self, visit, *args, **kwargs):
"""
Traverse the hierarchy performing visit(self, *args, **kwargs) at every node passed by.
Traverse the hierarchy performing visit(self, *args, **kwargs)
at every node passed by downwards. This function includes self!
See "visitor pattern" in literature. This is implemented in pre-order fashion.
Example:
@ -745,7 +741,7 @@ class Parameterizable(OptimizationHandlable):
def traverse_parents(self, visit, *args, **kwargs):
"""
Traverse the hierarchy upwards, visiting all parents and their children.
Traverse the hierarchy upwards, visiting all parents and their children except self.
See "visitor pattern" in literature. This is implemented in pre-order fashion.
Example:
@ -754,13 +750,20 @@ class Parameterizable(OptimizationHandlable):
self.traverse_parents(parents.append)
print parents
"""
if not self.__visited:
visit(self, *args, **kwargs)
if self.has_parent():
self.__visited = True
self._parent_._traverse_parents(visit, *args, **kwargs)
self.__visited = False
def _traverse_parents(self, visit, *args, **kwargs):
if not self.__visited:
self.__visited = True
visit(self, *args, **kwargs)
if self.has_parent():
self._parent_.traverse_parents(visit, *args, **kwargs)
self._parent_._traverse_parents(visit, *args, **kwargs)
self._parent_.traverse(visit, *args, **kwargs)
self.__visited = False
#=========================================================================
# Gradient handling
#=========================================================================
@ -827,11 +830,10 @@ class Parameterizable(OptimizationHandlable):
# raise HierarchyError, "parameter {} already in another model ({}), create new object (or copy) for adding".format(param._short(), param._highest_parent_._short())
elif param not in self._parameters_:
if param.has_parent():
parent = param._parent_
while parent is not None:
def visit(parent, self):
if parent is self:
raise HierarchyError, "You cannot add a parameter twice into the hierarchy"
parent = parent._parent_
param.traverse_parents(visit, self)
param._parent_.remove_parameter(param)
# make sure the size is set
if index is None:
@ -875,7 +877,7 @@ class Parameterizable(OptimizationHandlable):
:param param: param object to remove from being a parameter of this parameterized object.
"""
if not param in self._parameters_:
raise RuntimeError, "Parameter {} does not belong to this object, remove parameters directly from their respective parents".format(param._short())
raise RuntimeError, "Parameter {} does not belong to this object {}, remove parameters directly from their respective parents".format(param._short(), self.name)
start = sum([p.size for p in self._parameters_[:param._parent_index_]])
self._remove_parameter_name(param)