[parameterized] init greatly improved

This commit is contained in:
mzwiessele 2014-07-02 08:50:59 -07:00
parent 8a8818aa91
commit 89f7f2dc73
3 changed files with 26 additions and 16 deletions

View file

@ -853,6 +853,9 @@ class OptimizationHandlable(Indexable):
pi._propagate_param_grad(parray[pislice], garray[pislice]) pi._propagate_param_grad(parray[pislice], garray[pislice])
pi_old_size += pi.size pi_old_size += pi.size
def _connect_parameters(self):
pass
class Parameterizable(OptimizationHandlable): class Parameterizable(OptimizationHandlable):
""" """
A parameterisable class. A parameterisable class.

View file

@ -10,9 +10,15 @@ from parameter_core import HierarchyError, Parameterizable, adjust_name_for_prin
class ParametersChangedMeta(type): class ParametersChangedMeta(type):
def __call__(self, *args, **kw): def __call__(self, *args, **kw):
instance = super(ParametersChangedMeta, self).__call__(*args, **kw) self._in_init_ = True
instance.parameters_changed() #import ipdb;ipdb.set_trace()
return instance self = super(ParametersChangedMeta, self).__call__(*args, **kw)
self._in_init_ = False
self._highest_parent_._connect_parameters()
self._highest_parent_._notify_parent_change()
self._highest_parent_._connect_fixes()
self.parameters_changed()
return self
class Parameterized(Parameterizable): class Parameterized(Parameterizable):
""" """
@ -64,14 +70,12 @@ class Parameterized(Parameterizable):
#=========================================================================== #===========================================================================
def __init__(self, name=None, parameters=[], *a, **kw): def __init__(self, name=None, parameters=[], *a, **kw):
super(Parameterized, self).__init__(name=name, *a, **kw) super(Parameterized, self).__init__(name=name, *a, **kw)
self._in_init_ = True
self.size = sum(p.size for p in self.parameters) self.size = sum(p.size for p in self.parameters)
self.add_observer(self, self._parameters_changed_notification, -100) self.add_observer(self, self._parameters_changed_notification, -100)
if not self._has_fixes(): if not self._has_fixes():
self._fixes_ = None self._fixes_ = None
self._param_slices_ = [] self._param_slices_ = []
self._connect_parameters() #self._connect_parameters()
del self._in_init_
self.add_parameters(*parameters) self.add_parameters(*parameters)
def build_pydot(self, G=None): def build_pydot(self, G=None):
@ -146,15 +150,16 @@ class Parameterized(Parameterizable):
parent.size += param.size parent.size += param.size
parent = parent._parent_ parent = parent._parent_
self._connect_parameters() if not self._in_init_:
self._notify_parent_change() self._connect_parameters()
self._notify_parent_change()
self._highest_parent_._connect_parameters(ignore_added_names=_ignore_added_names) self._highest_parent_._connect_parameters(ignore_added_names=_ignore_added_names)
self._highest_parent_._notify_parent_change() self._highest_parent_._notify_parent_change()
self._highest_parent_._connect_fixes() self._highest_parent_._connect_fixes()
else: else:
raise HierarchyError, """Parameter exists already and no copy made""" raise HierarchyError, """Parameter exists already, try making a copy"""
def add_parameters(self, *parameters): def add_parameters(self, *parameters):
@ -209,19 +214,21 @@ class Parameterized(Parameterizable):
old_size = 0 old_size = 0
self._param_slices_ = [] self._param_slices_ = []
for i, p in enumerate(self.parameters): for i, p in enumerate(self.parameters):
if not p.param_array.flags['C_CONTIGUOUS']:
raise ValueError, "This should not happen! Please write an email to the developers with the code, which reproduces this error. All parameter arrays must be C_CONTIGUOUS"
p._parent_ = self p._parent_ = self
p._parent_index_ = i p._parent_index_ = i
pslice = slice(old_size, old_size + p.size) pslice = slice(old_size, old_size + p.size)
# first connect all children # first connect all children
p._propagate_param_grad(self.param_array[pslice], self.gradient_full[pslice]) p._propagate_param_grad(self.param_array[pslice], self.gradient_full[pslice])
# then connect children to self # then connect children to self
self.param_array[pslice] = p.param_array.flat # , requirements=['C', 'W']).ravel(order='C') self.param_array[pslice] = p.param_array.flat # , requirements=['C', 'W']).ravel(order='C')
self.gradient_full[pslice] = p.gradient_full.flat # , requirements=['C', 'W']).ravel(order='C') self.gradient_full[pslice] = p.gradient_full.flat # , requirements=['C', 'W']).ravel(order='C')
if not p.param_array.flags['C_CONTIGUOUS']:
raise ValueError, "This should not happen! Please write an email to the developers with the code, which reproduces this error. All parameter arrays must be C_CONTIGUOUS"
p.param_array.data = self.param_array[pslice].data p.param_array.data = self.param_array[pslice].data
p.gradient_full.data = self.gradient_full[pslice].data p.gradient_full.data = self.gradient_full[pslice].data

View file

@ -223,7 +223,7 @@ class VarDTCMissingData(LatentFunctionInference):
inan = self._inan inan = self._inan
has_none = True has_none = True
if has_none: if has_none:
print "caching missing data slices, this can take several minutes..." print "caching missing data slices, this can take several minutes depending on the number of unique dimensions of the data..."
from ...util.subarray_and_sorting import common_subarrays from ...util.subarray_and_sorting import common_subarrays
self._subarray_indices = [] self._subarray_indices = []
csa = common_subarrays(inan, 1) csa = common_subarrays(inan, 1)