mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 19:42:39 +02:00
[parameterized] init greatly improved
This commit is contained in:
parent
8a8818aa91
commit
89f7f2dc73
3 changed files with 26 additions and 16 deletions
|
|
@ -10,9 +10,15 @@ from parameter_core import HierarchyError, Parameterizable, adjust_name_for_prin
|
|||
|
||||
class ParametersChangedMeta(type):
|
||||
def __call__(self, *args, **kw):
|
||||
instance = super(ParametersChangedMeta, self).__call__(*args, **kw)
|
||||
instance.parameters_changed()
|
||||
return instance
|
||||
self._in_init_ = True
|
||||
#import ipdb;ipdb.set_trace()
|
||||
self = super(ParametersChangedMeta, self).__call__(*args, **kw)
|
||||
self._in_init_ = False
|
||||
self._highest_parent_._connect_parameters()
|
||||
self._highest_parent_._notify_parent_change()
|
||||
self._highest_parent_._connect_fixes()
|
||||
self.parameters_changed()
|
||||
return self
|
||||
|
||||
class Parameterized(Parameterizable):
|
||||
"""
|
||||
|
|
@ -64,14 +70,12 @@ class Parameterized(Parameterizable):
|
|||
#===========================================================================
|
||||
def __init__(self, name=None, parameters=[], *a, **kw):
|
||||
super(Parameterized, self).__init__(name=name, *a, **kw)
|
||||
self._in_init_ = True
|
||||
self.size = sum(p.size for p in self.parameters)
|
||||
self.add_observer(self, self._parameters_changed_notification, -100)
|
||||
if not self._has_fixes():
|
||||
self._fixes_ = None
|
||||
self._param_slices_ = []
|
||||
self._connect_parameters()
|
||||
del self._in_init_
|
||||
#self._connect_parameters()
|
||||
self.add_parameters(*parameters)
|
||||
|
||||
def build_pydot(self, G=None):
|
||||
|
|
@ -146,15 +150,16 @@ class Parameterized(Parameterizable):
|
|||
parent.size += param.size
|
||||
parent = parent._parent_
|
||||
|
||||
self._connect_parameters()
|
||||
self._notify_parent_change()
|
||||
if not self._in_init_:
|
||||
self._connect_parameters()
|
||||
self._notify_parent_change()
|
||||
|
||||
self._highest_parent_._connect_parameters(ignore_added_names=_ignore_added_names)
|
||||
self._highest_parent_._notify_parent_change()
|
||||
self._highest_parent_._connect_fixes()
|
||||
self._highest_parent_._connect_parameters(ignore_added_names=_ignore_added_names)
|
||||
self._highest_parent_._notify_parent_change()
|
||||
self._highest_parent_._connect_fixes()
|
||||
|
||||
else:
|
||||
raise HierarchyError, """Parameter exists already and no copy made"""
|
||||
raise HierarchyError, """Parameter exists already, try making a copy"""
|
||||
|
||||
|
||||
def add_parameters(self, *parameters):
|
||||
|
|
@ -209,19 +214,21 @@ class Parameterized(Parameterizable):
|
|||
old_size = 0
|
||||
self._param_slices_ = []
|
||||
for i, p in enumerate(self.parameters):
|
||||
if not p.param_array.flags['C_CONTIGUOUS']:
|
||||
raise ValueError, "This should not happen! Please write an email to the developers with the code, which reproduces this error. All parameter arrays must be C_CONTIGUOUS"
|
||||
|
||||
p._parent_ = self
|
||||
p._parent_index_ = i
|
||||
|
||||
pslice = slice(old_size, old_size + p.size)
|
||||
|
||||
# first connect all children
|
||||
p._propagate_param_grad(self.param_array[pslice], self.gradient_full[pslice])
|
||||
|
||||
# then connect children to self
|
||||
self.param_array[pslice] = p.param_array.flat # , requirements=['C', 'W']).ravel(order='C')
|
||||
self.gradient_full[pslice] = p.gradient_full.flat # , requirements=['C', 'W']).ravel(order='C')
|
||||
|
||||
if not p.param_array.flags['C_CONTIGUOUS']:
|
||||
raise ValueError, "This should not happen! Please write an email to the developers with the code, which reproduces this error. All parameter arrays must be C_CONTIGUOUS"
|
||||
|
||||
p.param_array.data = self.param_array[pslice].data
|
||||
p.gradient_full.data = self.gradient_full[pslice].data
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue