[parameterized] init greatly improved

This commit is contained in:
mzwiessele 2014-07-02 08:50:59 -07:00
parent 8a8818aa91
commit 89f7f2dc73
3 changed files with 26 additions and 16 deletions

View file

@ -853,6 +853,9 @@ class OptimizationHandlable(Indexable):
pi._propagate_param_grad(parray[pislice], garray[pislice])
pi_old_size += pi.size
def _connect_parameters(self):
pass
class Parameterizable(OptimizationHandlable):
"""
A parameterisable class.

View file

@ -10,9 +10,15 @@ from parameter_core import HierarchyError, Parameterizable, adjust_name_for_prin
class ParametersChangedMeta(type):
def __call__(self, *args, **kw):
instance = super(ParametersChangedMeta, self).__call__(*args, **kw)
instance.parameters_changed()
return instance
self._in_init_ = True
#import ipdb;ipdb.set_trace()
self = super(ParametersChangedMeta, self).__call__(*args, **kw)
self._in_init_ = False
self._highest_parent_._connect_parameters()
self._highest_parent_._notify_parent_change()
self._highest_parent_._connect_fixes()
self.parameters_changed()
return self
class Parameterized(Parameterizable):
"""
@ -64,14 +70,12 @@ class Parameterized(Parameterizable):
#===========================================================================
def __init__(self, name=None, parameters=[], *a, **kw):
super(Parameterized, self).__init__(name=name, *a, **kw)
self._in_init_ = True
self.size = sum(p.size for p in self.parameters)
self.add_observer(self, self._parameters_changed_notification, -100)
if not self._has_fixes():
self._fixes_ = None
self._param_slices_ = []
self._connect_parameters()
del self._in_init_
#self._connect_parameters()
self.add_parameters(*parameters)
def build_pydot(self, G=None):
@ -146,15 +150,16 @@ class Parameterized(Parameterizable):
parent.size += param.size
parent = parent._parent_
self._connect_parameters()
self._notify_parent_change()
if not self._in_init_:
self._connect_parameters()
self._notify_parent_change()
self._highest_parent_._connect_parameters(ignore_added_names=_ignore_added_names)
self._highest_parent_._notify_parent_change()
self._highest_parent_._connect_fixes()
self._highest_parent_._connect_parameters(ignore_added_names=_ignore_added_names)
self._highest_parent_._notify_parent_change()
self._highest_parent_._connect_fixes()
else:
raise HierarchyError, """Parameter exists already and no copy made"""
raise HierarchyError, """Parameter exists already, try making a copy"""
def add_parameters(self, *parameters):
@ -209,19 +214,21 @@ class Parameterized(Parameterizable):
old_size = 0
self._param_slices_ = []
for i, p in enumerate(self.parameters):
if not p.param_array.flags['C_CONTIGUOUS']:
raise ValueError, "This should not happen! Please write an email to the developers with the code, which reproduces this error. All parameter arrays must be C_CONTIGUOUS"
p._parent_ = self
p._parent_index_ = i
pslice = slice(old_size, old_size + p.size)
# first connect all children
p._propagate_param_grad(self.param_array[pslice], self.gradient_full[pslice])
# then connect children to self
self.param_array[pslice] = p.param_array.flat # , requirements=['C', 'W']).ravel(order='C')
self.gradient_full[pslice] = p.gradient_full.flat # , requirements=['C', 'W']).ravel(order='C')
if not p.param_array.flags['C_CONTIGUOUS']:
raise ValueError, "This should not happen! Please write an email to the developers with the code, which reproduces this error. All parameter arrays must be C_CONTIGUOUS"
p.param_array.data = self.param_array[pslice].data
p.gradient_full.data = self.gradient_full[pslice].data

View file

@ -223,7 +223,7 @@ class VarDTCMissingData(LatentFunctionInference):
inan = self._inan
has_none = True
if has_none:
print "caching missing data slices, this can take several minutes..."
print "caching missing data slices, this can take several minutes depending on the number of unique dimensions of the data..."
from ...util.subarray_and_sorting import common_subarrays
self._subarray_indices = []
csa = common_subarrays(inan, 1)