[copy] is now fully functional, only hierarchy observers will be copied and pickled

This commit is contained in:
mzwiessele 2014-05-23 11:11:21 +01:00
parent 96e1e13f7e
commit 3525e45b2f
9 changed files with 89 additions and 59 deletions

View file

@ -276,5 +276,9 @@ class GP(Model):
TODO: valid args
"""
self.inference_method.on_optimization_start()
super(GP, self).optimize(optimizer, start, **kwargs)
self.inference_method.on_optimization_end()
try:
super(GP, self).optimize(optimizer, start, **kwargs)
except KeyboardInterrupt:
print "KeyboardInterrupt caught, calling on_optimization_end() to round things up"
self.inference_method.on_optimization_end()
raise

View file

@ -77,8 +77,18 @@ class ObserverList(object):
self._poc.insert(ins, (priority, weakref.ref(observer), callble))
def __str__(self):
from . import ObsAr, Param
from parameter_core import Parameterizable
ret = []
curr_p = None
def frmt(o):
if isinstance(o, ObsAr):
return 'ObsArr <{}>'.format(hex(id(o)))
elif isinstance(o, (Param,Parameterizable)):
return '{}'.format(o.hierarchy_name())
else:
return repr(o)
for p, o, c in self:
curr = ''
if curr_p != p:
@ -87,8 +97,9 @@ class ObserverList(object):
else: curr_pre = " "*len(pre)
curr_p = p
curr += curr_pre
ret.append(curr + ", ".join(map(repr, [o,c])))
return '\n'.join(ret)
ret.append(curr + ", ".join([frmt(o), str(c)]))
return '\n'.join(ret)
def flush(self):
"""

View file

@ -30,9 +30,15 @@ class ObsAr(np.ndarray, Pickleable, Observable):
def __array_wrap__(self, out_arr, context=None):
return out_arr.view(np.ndarray)
def _setup_observers(self):
# do not setup anything, as observable arrays do not have default observers
pass
def copy(self):
from lists_and_dicts import ObserverList
memo = {}
memo[id(self)] = self
memo[id(self.observers)] = ObserverList()
return self.__deepcopy__(memo)
def __deepcopy__(self, memo):

View file

@ -173,36 +173,6 @@ class Param(Parameterizable, ObsAr):
def _ensure_fixes(self):
if not self._has_fixes(): self._fixes_ = numpy.ones(self._realsize_, dtype=bool)
#===========================================================================
# parameterizable
#===========================================================================
def traverse(self, visit, *args, **kwargs):
"""
Traverse the hierarchy performing visit(self, *args, **kwargs) at every node passed by.
See "visitor pattern" in literature. This is implemented in pre-order fashion.
This will function will just call visit on self, as Param are leaf nodes.
"""
self.__visited = True
visit(self, *args, **kwargs)
self.__visited = False
def traverse_parents(self, visit, *args, **kwargs):
"""
Traverse the hierarchy upwards, visiting all parents and their children, except self.
See "visitor pattern" in literature. This is implemented in pre-order fashion.
Example:
parents = []
self.traverse_parents(parents.append)
print parents
"""
if self.has_parent():
self.__visited = True
self._parent_._traverse_parents(visit, *args, **kwargs)
self.__visited = False
#===========================================================================
# Convenience
#===========================================================================
@ -217,13 +187,24 @@ class Param(Parameterizable, ObsAr):
#===========================================================================
# Pickling and copying
#===========================================================================
def copy(self):
return Parameterizable.copy(self, which=self)
def __deepcopy__(self, memo):
s = self.__new__(self.__class__, name=self.name, input_array=self.view(numpy.ndarray).copy())
memo[id(self)] = s
memo[id(self)] = s
import copy
Pickleable.__setstate__(s, copy.deepcopy(self.__getstate__(), memo))
return s
def _setup_observers(self):
"""
Setup the default observers
1: pass through to parent, if present
"""
if self.has_parent():
self.add_observer(self._parent_, self._parent_._pass_through_notify_observers, -np.inf)
#===========================================================================
# Printing -> done
#===========================================================================

View file

@ -92,7 +92,7 @@ class Observable(object):
for poc in self.observers:
_, obs, clble = poc
if callble is not None:
if (obs == observer) and (callble == clble):
if (obs is observer) and (callble == clble):
to_remove.append(poc)
else:
if obs is observer:
@ -199,28 +199,32 @@ class Pickleable(object):
#===========================================================================
# copy and pickling
#===========================================================================
def copy(self):
def copy(self, memo=None, which=None):
"""
Returns a (deep) copy of the current parameter handle.
All connections to parents of the copy will be cut.
:param dict memo: memo for deepcopy
:param Parameterized which: parameterized object which started the copy process [default: self]
"""
#raise NotImplementedError, "Copy is not yet implemented, TODO: Observable hierarchy"
if memo is None:
memo = {}
import copy
memo = {}
# the next part makes sure that we do not include parents in any form:
parents = []
self.traverse_parents(parents.append) # collect parents
if which is None:
which = self
which.traverse_parents(parents.append) # collect parents
for p in parents:
memo[id(p)] = None # set all parents to be None, so they will not be copied
memo[id(self.gradient)] = None # reset the gradient
memo[id(self.param_array)] = None # and param_array
memo[id(self.optimizer_array)] = None # and param_array
memo[id(self._fixes_)] = None # fixes have to be reset, as this is now highest parent
c = copy.deepcopy(self, memo) # and start the copy
c._parent_index_ = None
c._trigger_params_changed()
return c
if not memo.has_key(id(p)):memo[id(p)] = None # set all parents to be None, so they will not be copied
if not memo.has_key(id(self.gradient)):memo[id(self.gradient)] = None # reset the gradient
if not memo.has_key(id(self._fixes_)):memo[id(self._fixes_)] = None # fixes have to be reset, as this is now highest parent
copy = copy.deepcopy(self, memo) # and start the copy
copy._parent_index_ = None
copy._trigger_params_changed()
return copy
def __deepcopy__(self, memo):
s = self.__new__(self.__class__) # fresh instance
@ -234,6 +238,7 @@ class Pickleable(object):
'_gradient_array_', # as well as gradients
'_optimizer_copy_',
'logger',
'observers',
'_fixes_', # and fixes
'_Cacher_wrap__cachers', # never pickle cachers
]
@ -245,6 +250,9 @@ class Pickleable(object):
def __setstate__(self, state):
self.__dict__.update(state)
from lists_and_dicts import ObserverList
self.observers = ObserverList()
self._setup_observers()
self._optimizer_copy_transformed = False
@ -982,7 +990,16 @@ class Parameterizable(OptimizationHandlable):
self.parameters_changed()
def _pass_through_notify_observers(self, me, which=None):
self.notify_observers(which=which)
def _setup_observers(self):
"""
Setup the default observers
1: parameters_changed_notify
2: pass through to parent, if present
"""
self.add_observer(self, self._parameters_changed_notification, -100)
if self.has_parent():
self.add_observer(self._parent_, self._parent_._pass_through_notify_observers, -np.inf)
#===========================================================================
# From being parentable, we have to define the parent_change notification
#===========================================================================

View file

@ -292,12 +292,16 @@ class Parameterized(Parameterizable):
except Exception as e:
print "WARNING: caught exception {!s}, trying to continue".format(e)
def copy(self):
c = super(Parameterized, self).copy()
c._connect_parameters()
c._connect_fixes()
c._notify_parent_change()
return c
def copy(self, memo=None):
if memo is None:
memo = {}
memo[id(self.optimizer_array)] = None # and param_array
memo[id(self.param_array)] = None # and param_array
copy = super(Parameterized, self).copy(memo)
copy._connect_parameters()
copy._connect_fixes()
copy._notify_parent_change()
return copy
#===========================================================================
# Printing: