From 5b8b3b2256c3fc1dd3404f8e2aace92b5525ba6c Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Mon, 12 May 2014 11:42:53 +0100 Subject: [PATCH] [copy] handled hierarchy error for copying --- GPy/core/parameterization/lists_and_dicts.py | 21 ++++--- GPy/core/parameterization/param.py | 7 +++ GPy/core/parameterization/parameter_core.py | 64 +++++++++++++++----- 3 files changed, 68 insertions(+), 24 deletions(-) diff --git a/GPy/core/parameterization/lists_and_dicts.py b/GPy/core/parameterization/lists_and_dicts.py index 604d0a01..13547c94 100644 --- a/GPy/core/parameterization/lists_and_dicts.py +++ b/GPy/core/parameterization/lists_and_dicts.py @@ -59,13 +59,14 @@ class ObservablesList(object): return self._poc.__repr__() def add(self, priority, observable, callble): - ins = 0 - for pr, _, _ in self: - if priority > pr: - break - ins += 1 - self._poc.insert(ins, (priority, weakref.ref(observable), callble)) - + if observable is not None: + ins = 0 + for pr, _, _ in self: + if priority > pr: + break + ins += 1 + self._poc.insert(ins, (priority, weakref.ref(observable), callble)) + def __str__(self): ret = [] curr_p = None @@ -96,8 +97,10 @@ class ObservablesList(object): def __deepcopy__(self, memo): self.flush() s = ObservablesList() - import copy - s._poc = copy.deepcopy(self._poc, memo) + for p,o,c in self._poc: + import copy + s.add(p, copy.deepcopy(o(), memo), copy.deepcopy(c, memo)) + s.flush() return s def __getstate__(self): diff --git a/GPy/core/parameterization/param.py b/GPy/core/parameterization/param.py index 7055838a..1c67b9d9 100644 --- a/GPy/core/parameterization/param.py +++ b/GPy/core/parameterization/param.py @@ -156,6 +156,13 @@ class Param(OptimizationHandlable, ObsAr): def _ensure_fixes(self): if not self._has_fixes(): self._fixes_ = numpy.ones(self._realsize_, dtype=bool) + #=========================================================================== + # parameterizable + #=========================================================================== + def traverse(self, visit, *args, **kwargs): + visit(self, *args, **kwargs) + + #=========================================================================== # Convenience #=========================================================================== diff --git a/GPy/core/parameterization/parameter_core.py b/GPy/core/parameterization/parameter_core.py index 68140763..93924678 100644 --- a/GPy/core/parameterization/parameter_core.py +++ b/GPy/core/parameterization/parameter_core.py @@ -17,7 +17,7 @@ from transformations import Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, import numpy as np import re -__updated__ = '2014-04-16' +__updated__ = '2014-05-12' class HierarchyError(Exception): """ @@ -124,7 +124,7 @@ class Parentable(object): """ Disconnect this object from its parent """ - raise NotImplementedError, "Abstaract superclass" + raise NotImplementedError, "Abstract superclass" @property def _highest_parent_(self): @@ -162,14 +162,13 @@ class Pickleable(object): :param protocol: pickling protocol to use, python-pickle for details. """ import cPickle as pickle - import pickle #TODO: cPickle if isinstance(f, str): with open(f, 'w') as f: pickle.dump(self, f, protocol) else: pickle.dump(self, f, protocol) - #=========================================================================== + #=========================================================================== # copy and pickling #=========================================================================== def copy(self): @@ -177,7 +176,12 @@ class Pickleable(object): #raise NotImplementedError, "Copy is not yet implemented, TODO: Observable hierarchy" import copy memo = {} - memo[id(self._parent_)] = None + parents = [] + self.traverse_parents(parents.append) + # remove self, which is the first arguments + parents = [p for p in parents if p is not self] + for p in parents: + memo[id(p)] = None memo[id(self.gradient)] = None memo[id(self.param_array)] = None memo[id(self._fixes_)] = None @@ -202,9 +206,6 @@ class Pickleable(object): dc = dict() for k,v in self.__dict__.iteritems(): if k not in ignore_list: - #if hasattr(v, "__getstate__"): - #dc[k] = v.__getstate__() - #else: dc[k] = v return dc @@ -212,12 +213,6 @@ class Pickleable(object): self.__dict__.update(state) return self - #def __getstate__(self, memo): - # raise NotImplementedError, "get state must be implemented to be able to pickle objects" - - #def __setstate__(self, memo): - # raise NotImplementedError, "set state must be implemented to be able to pickle objects" - class Gradcheckable(Pickleable, Parentable): """ Adds the functionality for an object to be gradcheckable. @@ -644,6 +639,7 @@ class OptimizationHandlable(Constrainable): else: names = [adjust(x.name) for x in self._parameters_] if add_self: names = map(lambda x: adjust(self.name) + "." + x, names) return names + def _get_param_names(self): n = np.array([p.hierarchy_name() + '[' + str(i) + ']' for p in self.flattened_parameters for i in p._indices()]) return n @@ -710,12 +706,14 @@ class Parameterizable(OptimizationHandlable): super(Parameterizable, self).__init__(*args, **kwargs) from GPy.core.parameterization.lists_and_dicts import ArrayList self._parameters_ = ArrayList() + self._param_array_ = None self.size = 0 self._added_names_ = set() + self.__visited = False # for traversing in reverse order we need to know if we were here already @property def param_array(self): - if not hasattr(self, '_param_array_'): + if self._param_array_ is None: self._param_array_ = np.empty(self.size, dtype=np.float64) return self._param_array_ @@ -723,6 +721,42 @@ class Parameterizable(OptimizationHandlable): def param_array(self, arr): self._param_array_ = arr + def traverse(self, visit, *args, **kwargs): + """ + Traverse the hierarchy performing visit(self, *args, **kwargs) at every node passed by. + See "visitor pattern" in literature. This is implemented in pre-order fashion. + + Example: + Collect all children: + + children = [] + self.traverse(children.append) + print children + """ + if not self.__visited: + visit(self, *args, **kwargs) + self.__visited = True + for c in self._parameters_: + c.traverse(visit, *args, **kwargs) + + def traverse_parents(self, visit, *args, **kwargs): + """ + Traverse the hierarchy upwards, visiting all parents and their children. + See "visitor pattern" in literature. This is implemented in pre-order fashion. + + Example: + + parents = [] + self.traverse_parents(parents.append) + print parents + """ + if not self.__visited: + visit(self, *args, **kwargs) + self.__visited = True + if self.has_parent(): + self._parent_.traverse_parents(visit, *args, **kwargs) + self._parent_.traverse(visit, *args, **kwargs) + self.__visited = False #========================================================================= # Gradient handling #=========================================================================