[parameterized] init greatly improved

This commit is contained in:
mzwiessele 2014-07-02 08:50:59 -07:00
parent 8a8818aa91
commit 89f7f2dc73
3 changed files with 26 additions and 16 deletions

View file

@ -10,9 +10,15 @@ from parameter_core import HierarchyError, Parameterizable, adjust_name_for_prin
class ParametersChangedMeta(type):
def __call__(self, *args, **kw):
instance = super(ParametersChangedMeta, self).__call__(*args, **kw)
instance.parameters_changed()
return instance
self._in_init_ = True
#import ipdb;ipdb.set_trace()
self = super(ParametersChangedMeta, self).__call__(*args, **kw)
self._in_init_ = False
self._highest_parent_._connect_parameters()
self._highest_parent_._notify_parent_change()
self._highest_parent_._connect_fixes()
self.parameters_changed()
return self
class Parameterized(Parameterizable):
"""
@ -64,14 +70,12 @@ class Parameterized(Parameterizable):
#===========================================================================
def __init__(self, name=None, parameters=[], *a, **kw):
super(Parameterized, self).__init__(name=name, *a, **kw)
self._in_init_ = True
self.size = sum(p.size for p in self.parameters)
self.add_observer(self, self._parameters_changed_notification, -100)
if not self._has_fixes():
self._fixes_ = None
self._param_slices_ = []
self._connect_parameters()
del self._in_init_
#self._connect_parameters()
self.add_parameters(*parameters)
def build_pydot(self, G=None):
@ -146,15 +150,16 @@ class Parameterized(Parameterizable):
parent.size += param.size
parent = parent._parent_
self._connect_parameters()
self._notify_parent_change()
if not self._in_init_:
self._connect_parameters()
self._notify_parent_change()
self._highest_parent_._connect_parameters(ignore_added_names=_ignore_added_names)
self._highest_parent_._notify_parent_change()
self._highest_parent_._connect_fixes()
self._highest_parent_._connect_parameters(ignore_added_names=_ignore_added_names)
self._highest_parent_._notify_parent_change()
self._highest_parent_._connect_fixes()
else:
raise HierarchyError, """Parameter exists already and no copy made"""
raise HierarchyError, """Parameter exists already, try making a copy"""
def add_parameters(self, *parameters):
@ -209,19 +214,21 @@ class Parameterized(Parameterizable):
old_size = 0
self._param_slices_ = []
for i, p in enumerate(self.parameters):
if not p.param_array.flags['C_CONTIGUOUS']:
raise ValueError, "This should not happen! Please write an email to the developers with the code, which reproduces this error. All parameter arrays must be C_CONTIGUOUS"
p._parent_ = self
p._parent_index_ = i
pslice = slice(old_size, old_size + p.size)
# first connect all children
p._propagate_param_grad(self.param_array[pslice], self.gradient_full[pslice])
# then connect children to self
self.param_array[pslice] = p.param_array.flat # , requirements=['C', 'W']).ravel(order='C')
self.gradient_full[pslice] = p.gradient_full.flat # , requirements=['C', 'W']).ravel(order='C')
if not p.param_array.flags['C_CONTIGUOUS']:
raise ValueError, "This should not happen! Please write an email to the developers with the code, which reproduces this error. All parameter arrays must be C_CONTIGUOUS"
p.param_array.data = self.param_array[pslice].data
p.gradient_full.data = self.gradient_full[pslice].data