Dont call parameters_changed ever yourself anymore and parameters are now inplace once in memory

This commit is contained in:
Max Zwiessele 2014-03-04 17:32:46 +00:00
parent 56d749ded8
commit 0df263956f
21 changed files with 601 additions and 284 deletions

View file

@ -11,6 +11,12 @@ from parameter_core import Pickleable, Parameterizable, adjust_name_for_printing
from transformations import __fixed__
from lists_and_dicts import ArrayList
class ParametersChangedMeta(type):
def __call__(self, *args, **kw):
instance = super(ParametersChangedMeta, self).__call__(*args, **kw)
instance.parameters_changed()
return instance
class Parameterized(Parameterizable, Pickleable, Gradcheckable):
"""
Parameterized class
@ -53,6 +59,12 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
If you want to operate on all parameters use m[''] to wildcard select all paramters
and concatenate them. Printing m[''] will result in printing of all parameters in detail.
"""
#===========================================================================
# Metaclass for parameters changed after init.
# This makes sure, that parameters changed will always be called after __init__
# **Never** call parameters_changed() yourself
__metaclass__ = ParametersChangedMeta
#===========================================================================
def __init__(self, name=None, *a, **kw):
super(Parameterized, self).__init__(name=name, parent=None, parent_index=None, *a, **kw)
self._in_init_ = True
@ -88,39 +100,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
return G
return node
#===========================================================================
# Pickling operations
#===========================================================================
def pickle(self, f, protocol=-1):
"""
:param f: either filename or open file object to write to.
if it is an open buffer, you have to make sure to close
it properly.
:param protocol: pickling protocol to use, python-pickle for details.
"""
if isinstance(f, str):
with open(f, 'w') as f:
cPickle.dump(self, f, protocol)
else:
cPickle.dump(self, f, protocol)
def copy(self):
c = super(Parameterized, self).copy()
c.add_observer(c, c._parameters_changed_notification, -100)
return c
def __getstate__(self):
if self._has_get_set_state():
return self._getstate()
return self.__dict__
def __setstate__(self, state):
if self._has_get_set_state():
self._setstate(state) # set state
# self._set_params(self._get_params()) # restore all values
return
self.__dict__ = state
def _has_get_set_state(self):
return '_getstate' in vars(self.__class__) and '_setstate' in vars(self.__class__)
def _getstate(self):
"""
Get the current state of the class,
@ -149,25 +129,33 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
self._connect_parameters()
self.parameters_changed()
#===========================================================================
# Override copy to handle programmatically added observers
#===========================================================================
def copy(self):
c = super(Pickleable, self).copy()
c.add_observer(c, c._parameters_changed_notification, -100)
return c
#===========================================================================
# Gradient control
#===========================================================================
def _transform_gradients(self, g):
if self.has_parent():
return g
x = self._get_params()
[numpy.put(g, i, g[i] * c.gradfactor(x[i])) for c, i in self.constraints.iteritems() if c != __fixed__]
for p in self.flattened_parameters:
for t, i in p._tied_to_me_.iteritems():
g[self._offset_for(p) + numpy.array(list(i))] += g[self._raveled_index_for(t)]
[numpy.put(g, i, g[i] * c.gradfactor(self._param_array_[i])) for c, i in self.constraints.iteritems() if c != __fixed__]
if self._has_fixes(): return g[self._fixes_]
return g
#===========================================================================
# Indexable
#===========================================================================
def _offset_for(self, param):
# get the offset in the parameterized index array for param
if param.has_parent():
if param._direct_parent_._get_original(param) in self._parameters_:
return self._param_slices_[param._direct_parent_._get_original(param)._parent_index_].start
return self._offset_for(param._direct_parent_) + param._direct_parent_._offset_for(param)
if param._parent_._get_original(param) in self._parameters_:
return self._param_slices_[param._parent_._get_original(param)._parent_index_].start
return self._offset_for(param._parent_) + param._parent_._offset_for(param)
return 0
def _raveled_index_for(self, param):
@ -229,8 +217,8 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
return ParamConcatenation(paramlist)
def __setitem__(self, name, value, paramlist=None):
if isinstance(name, slice):
self[''][name] = value
if isinstance(name, (slice, tuple, np.ndarray)):
self._param_array_[name] = value
else:
try: param = self.__getitem__(name, paramlist)
except AttributeError as a: raise a