mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 20:12:38 +02:00
[parameterized] init greatly improved
This commit is contained in:
parent
8a8818aa91
commit
89f7f2dc73
3 changed files with 26 additions and 16 deletions
|
|
@ -853,6 +853,9 @@ class OptimizationHandlable(Indexable):
|
||||||
pi._propagate_param_grad(parray[pislice], garray[pislice])
|
pi._propagate_param_grad(parray[pislice], garray[pislice])
|
||||||
pi_old_size += pi.size
|
pi_old_size += pi.size
|
||||||
|
|
||||||
|
def _connect_parameters(self):
|
||||||
|
pass
|
||||||
|
|
||||||
class Parameterizable(OptimizationHandlable):
|
class Parameterizable(OptimizationHandlable):
|
||||||
"""
|
"""
|
||||||
A parameterisable class.
|
A parameterisable class.
|
||||||
|
|
|
||||||
|
|
@ -10,9 +10,15 @@ from parameter_core import HierarchyError, Parameterizable, adjust_name_for_prin
|
||||||
|
|
||||||
class ParametersChangedMeta(type):
|
class ParametersChangedMeta(type):
|
||||||
def __call__(self, *args, **kw):
|
def __call__(self, *args, **kw):
|
||||||
instance = super(ParametersChangedMeta, self).__call__(*args, **kw)
|
self._in_init_ = True
|
||||||
instance.parameters_changed()
|
#import ipdb;ipdb.set_trace()
|
||||||
return instance
|
self = super(ParametersChangedMeta, self).__call__(*args, **kw)
|
||||||
|
self._in_init_ = False
|
||||||
|
self._highest_parent_._connect_parameters()
|
||||||
|
self._highest_parent_._notify_parent_change()
|
||||||
|
self._highest_parent_._connect_fixes()
|
||||||
|
self.parameters_changed()
|
||||||
|
return self
|
||||||
|
|
||||||
class Parameterized(Parameterizable):
|
class Parameterized(Parameterizable):
|
||||||
"""
|
"""
|
||||||
|
|
@ -64,14 +70,12 @@ class Parameterized(Parameterizable):
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
def __init__(self, name=None, parameters=[], *a, **kw):
|
def __init__(self, name=None, parameters=[], *a, **kw):
|
||||||
super(Parameterized, self).__init__(name=name, *a, **kw)
|
super(Parameterized, self).__init__(name=name, *a, **kw)
|
||||||
self._in_init_ = True
|
|
||||||
self.size = sum(p.size for p in self.parameters)
|
self.size = sum(p.size for p in self.parameters)
|
||||||
self.add_observer(self, self._parameters_changed_notification, -100)
|
self.add_observer(self, self._parameters_changed_notification, -100)
|
||||||
if not self._has_fixes():
|
if not self._has_fixes():
|
||||||
self._fixes_ = None
|
self._fixes_ = None
|
||||||
self._param_slices_ = []
|
self._param_slices_ = []
|
||||||
self._connect_parameters()
|
#self._connect_parameters()
|
||||||
del self._in_init_
|
|
||||||
self.add_parameters(*parameters)
|
self.add_parameters(*parameters)
|
||||||
|
|
||||||
def build_pydot(self, G=None):
|
def build_pydot(self, G=None):
|
||||||
|
|
@ -146,15 +150,16 @@ class Parameterized(Parameterizable):
|
||||||
parent.size += param.size
|
parent.size += param.size
|
||||||
parent = parent._parent_
|
parent = parent._parent_
|
||||||
|
|
||||||
self._connect_parameters()
|
if not self._in_init_:
|
||||||
self._notify_parent_change()
|
self._connect_parameters()
|
||||||
|
self._notify_parent_change()
|
||||||
|
|
||||||
self._highest_parent_._connect_parameters(ignore_added_names=_ignore_added_names)
|
self._highest_parent_._connect_parameters(ignore_added_names=_ignore_added_names)
|
||||||
self._highest_parent_._notify_parent_change()
|
self._highest_parent_._notify_parent_change()
|
||||||
self._highest_parent_._connect_fixes()
|
self._highest_parent_._connect_fixes()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise HierarchyError, """Parameter exists already and no copy made"""
|
raise HierarchyError, """Parameter exists already, try making a copy"""
|
||||||
|
|
||||||
|
|
||||||
def add_parameters(self, *parameters):
|
def add_parameters(self, *parameters):
|
||||||
|
|
@ -209,19 +214,21 @@ class Parameterized(Parameterizable):
|
||||||
old_size = 0
|
old_size = 0
|
||||||
self._param_slices_ = []
|
self._param_slices_ = []
|
||||||
for i, p in enumerate(self.parameters):
|
for i, p in enumerate(self.parameters):
|
||||||
|
if not p.param_array.flags['C_CONTIGUOUS']:
|
||||||
|
raise ValueError, "This should not happen! Please write an email to the developers with the code, which reproduces this error. All parameter arrays must be C_CONTIGUOUS"
|
||||||
|
|
||||||
p._parent_ = self
|
p._parent_ = self
|
||||||
p._parent_index_ = i
|
p._parent_index_ = i
|
||||||
|
|
||||||
pslice = slice(old_size, old_size + p.size)
|
pslice = slice(old_size, old_size + p.size)
|
||||||
|
|
||||||
# first connect all children
|
# first connect all children
|
||||||
p._propagate_param_grad(self.param_array[pslice], self.gradient_full[pslice])
|
p._propagate_param_grad(self.param_array[pslice], self.gradient_full[pslice])
|
||||||
|
|
||||||
# then connect children to self
|
# then connect children to self
|
||||||
self.param_array[pslice] = p.param_array.flat # , requirements=['C', 'W']).ravel(order='C')
|
self.param_array[pslice] = p.param_array.flat # , requirements=['C', 'W']).ravel(order='C')
|
||||||
self.gradient_full[pslice] = p.gradient_full.flat # , requirements=['C', 'W']).ravel(order='C')
|
self.gradient_full[pslice] = p.gradient_full.flat # , requirements=['C', 'W']).ravel(order='C')
|
||||||
|
|
||||||
if not p.param_array.flags['C_CONTIGUOUS']:
|
|
||||||
raise ValueError, "This should not happen! Please write an email to the developers with the code, which reproduces this error. All parameter arrays must be C_CONTIGUOUS"
|
|
||||||
|
|
||||||
p.param_array.data = self.param_array[pslice].data
|
p.param_array.data = self.param_array[pslice].data
|
||||||
p.gradient_full.data = self.gradient_full[pslice].data
|
p.gradient_full.data = self.gradient_full[pslice].data
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -223,7 +223,7 @@ class VarDTCMissingData(LatentFunctionInference):
|
||||||
inan = self._inan
|
inan = self._inan
|
||||||
has_none = True
|
has_none = True
|
||||||
if has_none:
|
if has_none:
|
||||||
print "caching missing data slices, this can take several minutes..."
|
print "caching missing data slices, this can take several minutes depending on the number of unique dimensions of the data..."
|
||||||
from ...util.subarray_and_sorting import common_subarrays
|
from ...util.subarray_and_sorting import common_subarrays
|
||||||
self._subarray_indices = []
|
self._subarray_indices = []
|
||||||
csa = common_subarrays(inan, 1)
|
csa = common_subarrays(inan, 1)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue