mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-13 14:03:20 +02:00
[param] hierarchy traversal easier now
This commit is contained in:
parent
cff37293d9
commit
8d6eed6010
4 changed files with 57 additions and 33 deletions
|
|
@ -88,19 +88,17 @@ class ObservablesList(object):
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
self.flush()
|
self.flush()
|
||||||
for p, o, c in self._poc:
|
for p, o, c in self._poc:
|
||||||
if o() is not None:
|
yield p, o(), c
|
||||||
yield p, o(), c
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
self.flush()
|
self.flush()
|
||||||
return self._poc.__len__()
|
return self._poc.__len__()
|
||||||
|
|
||||||
def __deepcopy__(self, memo):
|
def __deepcopy__(self, memo):
|
||||||
self.flush()
|
|
||||||
s = ObservablesList()
|
s = ObservablesList()
|
||||||
for p,o,c in self._poc:
|
for p,o,c in self:
|
||||||
import copy
|
import copy
|
||||||
s.add(p, copy.deepcopy(o(), memo), copy.deepcopy(c, memo))
|
s.add(p, copy.deepcopy(o, memo), copy.deepcopy(c, memo))
|
||||||
s.flush()
|
s.flush()
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -169,8 +169,29 @@ class Param(OptimizationHandlable, ObsAr):
|
||||||
# parameterizable
|
# parameterizable
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
def traverse(self, visit, *args, **kwargs):
|
def traverse(self, visit, *args, **kwargs):
|
||||||
visit(self, *args, **kwargs)
|
"""
|
||||||
|
Traverse the hierarchy performing visit(self, *args, **kwargs) at every node passed by.
|
||||||
|
See "visitor pattern" in literature. This is implemented in pre-order fashion.
|
||||||
|
|
||||||
|
This will function will just call visit on self, as Param are leaf nodes.
|
||||||
|
"""
|
||||||
|
visit(self, *args, **kwargs)
|
||||||
|
|
||||||
|
def traverse_parents(self, visit, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Traverse the hierarchy upwards, visiting all parents and their children, except self.
|
||||||
|
See "visitor pattern" in literature. This is implemented in pre-order fashion.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
parents = []
|
||||||
|
self.traverse_parents(parents.append)
|
||||||
|
print parents
|
||||||
|
"""
|
||||||
|
if self.has_parent():
|
||||||
|
self.__visited = True
|
||||||
|
self._parent_._traverse_parents(visit, *args, **kwargs)
|
||||||
|
self.__visited = False
|
||||||
|
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# Convenience
|
# Convenience
|
||||||
|
|
|
||||||
|
|
@ -176,24 +176,23 @@ class Pickleable(object):
|
||||||
#raise NotImplementedError, "Copy is not yet implemented, TODO: Observable hierarchy"
|
#raise NotImplementedError, "Copy is not yet implemented, TODO: Observable hierarchy"
|
||||||
import copy
|
import copy
|
||||||
memo = {}
|
memo = {}
|
||||||
|
# the next part makes sure that we do not include parents in any form:
|
||||||
parents = []
|
parents = []
|
||||||
self.traverse_parents(parents.append)
|
self.traverse_parents(parents.append) # collect parents
|
||||||
# remove self, which is the first arguments
|
|
||||||
parents = [p for p in parents if p is not self]
|
|
||||||
for p in parents:
|
for p in parents:
|
||||||
memo[id(p)] = None
|
memo[id(p)] = None # set all parents to be None, so they will not be copied
|
||||||
memo[id(self.gradient)] = None
|
memo[id(self.gradient)] = None # reset the gradient
|
||||||
memo[id(self.param_array)] = None
|
memo[id(self.param_array)] = None # and param_array
|
||||||
memo[id(self._fixes_)] = None
|
memo[id(self._fixes_)] = None # fixes have to be reset, as this is now highest parent
|
||||||
c = copy.deepcopy(self, memo)
|
c = copy.deepcopy(self, memo) # and start the copy
|
||||||
c._parent_index_ = None
|
c._parent_index_ = None
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def __deepcopy__(self, memo):
|
def __deepcopy__(self, memo):
|
||||||
s = self.__new__(self.__class__)
|
s = self.__new__(self.__class__) # fresh instance
|
||||||
memo[id(self)] = s
|
memo[id(self)] = s # be sure to break all cycles --> self is already done
|
||||||
import copy
|
import copy
|
||||||
s.__dict__.update(copy.deepcopy(self.__dict__, memo))
|
s.__dict__.update(copy.deepcopy(self.__dict__, memo)) # standard copy
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
|
|
@ -580,12 +579,6 @@ class OptimizationHandlable(Constrainable):
|
||||||
def __init__(self, name, default_constraint=None, *a, **kw):
|
def __init__(self, name, default_constraint=None, *a, **kw):
|
||||||
super(OptimizationHandlable, self).__init__(name, default_constraint=default_constraint, *a, **kw)
|
super(OptimizationHandlable, self).__init__(name, default_constraint=default_constraint, *a, **kw)
|
||||||
|
|
||||||
def transform(self):
|
|
||||||
[np.put(self.param_array, ind, c.finv(self.param_array.flat[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__]
|
|
||||||
|
|
||||||
def untransform(self):
|
|
||||||
[np.put(self.param_array, ind, c.f(self.param_array.flat[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__]
|
|
||||||
|
|
||||||
def _get_params_transformed(self):
|
def _get_params_transformed(self):
|
||||||
# transformed parameters (apply transformation rules)
|
# transformed parameters (apply transformation rules)
|
||||||
p = self.param_array.copy()
|
p = self.param_array.copy()
|
||||||
|
|
@ -606,7 +599,8 @@ class OptimizationHandlable(Constrainable):
|
||||||
self.param_array.flat[fixes] = p
|
self.param_array.flat[fixes] = p
|
||||||
elif self._has_fixes(): self.param_array.flat[self._fixes_] = p
|
elif self._has_fixes(): self.param_array.flat[self._fixes_] = p
|
||||||
else: self.param_array.flat = p
|
else: self.param_array.flat = p
|
||||||
self.untransform()
|
[np.put(self.param_array, ind, c.f(self.param_array.flat[ind]))
|
||||||
|
for c, ind in self.constraints.iteritems() if c != __fixed__]
|
||||||
self._trigger_params_changed()
|
self._trigger_params_changed()
|
||||||
|
|
||||||
def _trigger_params_changed(self, trigger_parent=True):
|
def _trigger_params_changed(self, trigger_parent=True):
|
||||||
|
|
@ -726,7 +720,9 @@ class Parameterizable(OptimizationHandlable):
|
||||||
|
|
||||||
def traverse(self, visit, *args, **kwargs):
|
def traverse(self, visit, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Traverse the hierarchy performing visit(self, *args, **kwargs) at every node passed by.
|
Traverse the hierarchy performing visit(self, *args, **kwargs)
|
||||||
|
at every node passed by downwards. This function includes self!
|
||||||
|
|
||||||
See "visitor pattern" in literature. This is implemented in pre-order fashion.
|
See "visitor pattern" in literature. This is implemented in pre-order fashion.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
@ -745,7 +741,7 @@ class Parameterizable(OptimizationHandlable):
|
||||||
|
|
||||||
def traverse_parents(self, visit, *args, **kwargs):
|
def traverse_parents(self, visit, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Traverse the hierarchy upwards, visiting all parents and their children.
|
Traverse the hierarchy upwards, visiting all parents and their children except self.
|
||||||
See "visitor pattern" in literature. This is implemented in pre-order fashion.
|
See "visitor pattern" in literature. This is implemented in pre-order fashion.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
@ -754,13 +750,20 @@ class Parameterizable(OptimizationHandlable):
|
||||||
self.traverse_parents(parents.append)
|
self.traverse_parents(parents.append)
|
||||||
print parents
|
print parents
|
||||||
"""
|
"""
|
||||||
if not self.__visited:
|
if self.has_parent():
|
||||||
visit(self, *args, **kwargs)
|
|
||||||
self.__visited = True
|
self.__visited = True
|
||||||
|
self._parent_._traverse_parents(visit, *args, **kwargs)
|
||||||
|
self.__visited = False
|
||||||
|
|
||||||
|
def _traverse_parents(self, visit, *args, **kwargs):
|
||||||
|
if not self.__visited:
|
||||||
|
self.__visited = True
|
||||||
|
visit(self, *args, **kwargs)
|
||||||
if self.has_parent():
|
if self.has_parent():
|
||||||
self._parent_.traverse_parents(visit, *args, **kwargs)
|
self._parent_._traverse_parents(visit, *args, **kwargs)
|
||||||
self._parent_.traverse(visit, *args, **kwargs)
|
self._parent_.traverse(visit, *args, **kwargs)
|
||||||
self.__visited = False
|
self.__visited = False
|
||||||
|
|
||||||
#=========================================================================
|
#=========================================================================
|
||||||
# Gradient handling
|
# Gradient handling
|
||||||
#=========================================================================
|
#=========================================================================
|
||||||
|
|
@ -827,11 +830,10 @@ class Parameterizable(OptimizationHandlable):
|
||||||
# raise HierarchyError, "parameter {} already in another model ({}), create new object (or copy) for adding".format(param._short(), param._highest_parent_._short())
|
# raise HierarchyError, "parameter {} already in another model ({}), create new object (or copy) for adding".format(param._short(), param._highest_parent_._short())
|
||||||
elif param not in self._parameters_:
|
elif param not in self._parameters_:
|
||||||
if param.has_parent():
|
if param.has_parent():
|
||||||
parent = param._parent_
|
def visit(parent, self):
|
||||||
while parent is not None:
|
|
||||||
if parent is self:
|
if parent is self:
|
||||||
raise HierarchyError, "You cannot add a parameter twice into the hierarchy"
|
raise HierarchyError, "You cannot add a parameter twice into the hierarchy"
|
||||||
parent = parent._parent_
|
param.traverse_parents(visit, self)
|
||||||
param._parent_.remove_parameter(param)
|
param._parent_.remove_parameter(param)
|
||||||
# make sure the size is set
|
# make sure the size is set
|
||||||
if index is None:
|
if index is None:
|
||||||
|
|
@ -875,7 +877,7 @@ class Parameterizable(OptimizationHandlable):
|
||||||
:param param: param object to remove from being a parameter of this parameterized object.
|
:param param: param object to remove from being a parameter of this parameterized object.
|
||||||
"""
|
"""
|
||||||
if not param in self._parameters_:
|
if not param in self._parameters_:
|
||||||
raise RuntimeError, "Parameter {} does not belong to this object, remove parameters directly from their respective parents".format(param._short())
|
raise RuntimeError, "Parameter {} does not belong to this object {}, remove parameters directly from their respective parents".format(param._short(), self.name)
|
||||||
|
|
||||||
start = sum([p.size for p in self._parameters_[:param._parent_index_]])
|
start = sum([p.size for p in self._parameters_[:param._parent_index_]])
|
||||||
self._remove_parameter_name(param)
|
self._remove_parameter_name(param)
|
||||||
|
|
|
||||||
|
|
@ -132,6 +132,9 @@ class Test(ListDictTestCase):
|
||||||
self.assertIsNot(par.full_gradient, pcopy.full_gradient)
|
self.assertIsNot(par.full_gradient, pcopy.full_gradient)
|
||||||
self.assertTrue(pcopy.checkgrad())
|
self.assertTrue(pcopy.checkgrad())
|
||||||
self.assert_(np.any(pcopy.gradient!=0.0))
|
self.assert_(np.any(pcopy.gradient!=0.0))
|
||||||
|
pcopy.optimize('bfgs')
|
||||||
|
par.optimize('bfgs')
|
||||||
|
np.testing.assert_allclose(pcopy.param_array, par.param_array, atol=.001)
|
||||||
with tempfile.TemporaryFile('w+b') as f:
|
with tempfile.TemporaryFile('w+b') as f:
|
||||||
par.pickle(f)
|
par.pickle(f)
|
||||||
f.seek(0)
|
f.seek(0)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue