[copy] handled hierarchy error for copying

This commit is contained in:
Max Zwiessele 2014-05-12 11:42:53 +01:00
parent a163bf985e
commit 5b8b3b2256
3 changed files with 68 additions and 24 deletions

View file

@ -59,13 +59,14 @@ class ObservablesList(object):
return self._poc.__repr__() return self._poc.__repr__()
def add(self, priority, observable, callble): def add(self, priority, observable, callble):
ins = 0 if observable is not None:
for pr, _, _ in self: ins = 0
if priority > pr: for pr, _, _ in self:
break if priority > pr:
ins += 1 break
self._poc.insert(ins, (priority, weakref.ref(observable), callble)) ins += 1
self._poc.insert(ins, (priority, weakref.ref(observable), callble))
def __str__(self): def __str__(self):
ret = [] ret = []
curr_p = None curr_p = None
@ -96,8 +97,10 @@ class ObservablesList(object):
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
self.flush() self.flush()
s = ObservablesList() s = ObservablesList()
import copy for p,o,c in self._poc:
s._poc = copy.deepcopy(self._poc, memo) import copy
s.add(p, copy.deepcopy(o(), memo), copy.deepcopy(c, memo))
s.flush()
return s return s
def __getstate__(self): def __getstate__(self):

View file

@ -156,6 +156,13 @@ class Param(OptimizationHandlable, ObsAr):
def _ensure_fixes(self): def _ensure_fixes(self):
if not self._has_fixes(): self._fixes_ = numpy.ones(self._realsize_, dtype=bool) if not self._has_fixes(): self._fixes_ = numpy.ones(self._realsize_, dtype=bool)
#===========================================================================
# parameterizable
#===========================================================================
def traverse(self, visit, *args, **kwargs):
visit(self, *args, **kwargs)
#=========================================================================== #===========================================================================
# Convenience # Convenience
#=========================================================================== #===========================================================================

View file

@ -17,7 +17,7 @@ from transformations import Logexp, NegativeLogexp, Logistic, __fixed__, FIXED,
import numpy as np import numpy as np
import re import re
__updated__ = '2014-04-16' __updated__ = '2014-05-12'
class HierarchyError(Exception): class HierarchyError(Exception):
""" """
@ -124,7 +124,7 @@ class Parentable(object):
""" """
Disconnect this object from its parent Disconnect this object from its parent
""" """
raise NotImplementedError, "Abstaract superclass" raise NotImplementedError, "Abstract superclass"
@property @property
def _highest_parent_(self): def _highest_parent_(self):
@ -162,14 +162,13 @@ class Pickleable(object):
:param protocol: pickling protocol to use, python-pickle for details. :param protocol: pickling protocol to use, python-pickle for details.
""" """
import cPickle as pickle import cPickle as pickle
import pickle #TODO: cPickle
if isinstance(f, str): if isinstance(f, str):
with open(f, 'w') as f: with open(f, 'w') as f:
pickle.dump(self, f, protocol) pickle.dump(self, f, protocol)
else: else:
pickle.dump(self, f, protocol) pickle.dump(self, f, protocol)
#=========================================================================== #===========================================================================
# copy and pickling # copy and pickling
#=========================================================================== #===========================================================================
def copy(self): def copy(self):
@ -177,7 +176,12 @@ class Pickleable(object):
#raise NotImplementedError, "Copy is not yet implemented, TODO: Observable hierarchy" #raise NotImplementedError, "Copy is not yet implemented, TODO: Observable hierarchy"
import copy import copy
memo = {} memo = {}
memo[id(self._parent_)] = None parents = []
self.traverse_parents(parents.append)
# remove self, which is the first arguments
parents = [p for p in parents if p is not self]
for p in parents:
memo[id(p)] = None
memo[id(self.gradient)] = None memo[id(self.gradient)] = None
memo[id(self.param_array)] = None memo[id(self.param_array)] = None
memo[id(self._fixes_)] = None memo[id(self._fixes_)] = None
@ -202,9 +206,6 @@ class Pickleable(object):
dc = dict() dc = dict()
for k,v in self.__dict__.iteritems(): for k,v in self.__dict__.iteritems():
if k not in ignore_list: if k not in ignore_list:
#if hasattr(v, "__getstate__"):
#dc[k] = v.__getstate__()
#else:
dc[k] = v dc[k] = v
return dc return dc
@ -212,12 +213,6 @@ class Pickleable(object):
self.__dict__.update(state) self.__dict__.update(state)
return self return self
#def __getstate__(self, memo):
# raise NotImplementedError, "get state must be implemented to be able to pickle objects"
#def __setstate__(self, memo):
# raise NotImplementedError, "set state must be implemented to be able to pickle objects"
class Gradcheckable(Pickleable, Parentable): class Gradcheckable(Pickleable, Parentable):
""" """
Adds the functionality for an object to be gradcheckable. Adds the functionality for an object to be gradcheckable.
@ -644,6 +639,7 @@ class OptimizationHandlable(Constrainable):
else: names = [adjust(x.name) for x in self._parameters_] else: names = [adjust(x.name) for x in self._parameters_]
if add_self: names = map(lambda x: adjust(self.name) + "." + x, names) if add_self: names = map(lambda x: adjust(self.name) + "." + x, names)
return names return names
def _get_param_names(self): def _get_param_names(self):
n = np.array([p.hierarchy_name() + '[' + str(i) + ']' for p in self.flattened_parameters for i in p._indices()]) n = np.array([p.hierarchy_name() + '[' + str(i) + ']' for p in self.flattened_parameters for i in p._indices()])
return n return n
@ -710,12 +706,14 @@ class Parameterizable(OptimizationHandlable):
super(Parameterizable, self).__init__(*args, **kwargs) super(Parameterizable, self).__init__(*args, **kwargs)
from GPy.core.parameterization.lists_and_dicts import ArrayList from GPy.core.parameterization.lists_and_dicts import ArrayList
self._parameters_ = ArrayList() self._parameters_ = ArrayList()
self._param_array_ = None
self.size = 0 self.size = 0
self._added_names_ = set() self._added_names_ = set()
self.__visited = False # for traversing in reverse order we need to know if we were here already
@property @property
def param_array(self): def param_array(self):
if not hasattr(self, '_param_array_'): if self._param_array_ is None:
self._param_array_ = np.empty(self.size, dtype=np.float64) self._param_array_ = np.empty(self.size, dtype=np.float64)
return self._param_array_ return self._param_array_
@ -723,6 +721,42 @@ class Parameterizable(OptimizationHandlable):
def param_array(self, arr): def param_array(self, arr):
self._param_array_ = arr self._param_array_ = arr
def traverse(self, visit, *args, **kwargs):
"""
Traverse the hierarchy performing visit(self, *args, **kwargs) at every node passed by.
See "visitor pattern" in literature. This is implemented in pre-order fashion.
Example:
Collect all children:
children = []
self.traverse(children.append)
print children
"""
if not self.__visited:
visit(self, *args, **kwargs)
self.__visited = True
for c in self._parameters_:
c.traverse(visit, *args, **kwargs)
def traverse_parents(self, visit, *args, **kwargs):
"""
Traverse the hierarchy upwards, visiting all parents and their children.
See "visitor pattern" in literature. This is implemented in pre-order fashion.
Example:
parents = []
self.traverse_parents(parents.append)
print parents
"""
if not self.__visited:
visit(self, *args, **kwargs)
self.__visited = True
if self.has_parent():
self._parent_.traverse_parents(visit, *args, **kwargs)
self._parent_.traverse(visit, *args, **kwargs)
self.__visited = False
#========================================================================= #=========================================================================
# Gradient handling # Gradient handling
#========================================================================= #=========================================================================