[parameterized] handle updates inside init

This commit is contained in:
Max Zwiessele 2014-11-03 14:19:38 +00:00
parent d0a5420f2f
commit 6242a75f6a
3 changed files with 23 additions and 18 deletions

View file

@ -221,8 +221,6 @@ class ParameterIndexOperationsView(object):
def shift_left(self, start, size): def shift_left(self, start, size):
self._param_index_ops.shift_left(start+self._offset, size) self._param_index_ops.shift_left(start+self._offset, size)
self._offset -= size
self._size -= size
def clear(self): def clear(self):
for i, ind in self.items(): for i, ind in self.items():

View file

@ -18,7 +18,7 @@ import numpy as np
import re import re
import logging import logging
__updated__ = '2014-10-28' __updated__ = '2014-11-03'
class HierarchyError(Exception): class HierarchyError(Exception):
""" """
@ -924,7 +924,7 @@ class Parameterizable(OptimizationHandlable):
!WARNING!: setting the parameter array MUST always be done in memory: !WARNING!: setting the parameter array MUST always be done in memory:
m.param_array[:] = m_copy.param_array m.param_array[:] = m_copy.param_array
""" """
if self.__dict__.get('_param_array_', None) is None: if (self.__dict__.get('_param_array_', None) is None) or (self._param_array_.size != self.size):
self._param_array_ = np.empty(self.size, dtype=np.float64) self._param_array_ = np.empty(self.size, dtype=np.float64)
return self._param_array_ return self._param_array_
@ -1002,7 +1002,7 @@ class Parameterizable(OptimizationHandlable):
#========================================================================= #=========================================================================
@property @property
def gradient(self): def gradient(self):
if self.__dict__.get('_gradient_array_', None) is None: if (self.__dict__.get('_gradient_array_', None) is None) or self._gradient_array_.size != self.size:
self._gradient_array_ = np.empty(self.size, dtype=np.float64) self._gradient_array_ = np.empty(self.size, dtype=np.float64)
return self._gradient_array_ return self._gradient_array_

View file

@ -9,6 +9,7 @@ from param import ParamConcatenation
from parameter_core import HierarchyError, Parameterizable, adjust_name_for_printing from parameter_core import HierarchyError, Parameterizable, adjust_name_for_printing
import logging import logging
from GPy.core.parameterization.index_operations import ParameterIndexOperationsView
logger = logging.getLogger("parameters changed meta") logger = logging.getLogger("parameters changed meta")
class ParametersChangedMeta(type): class ParametersChangedMeta(type):
@ -20,7 +21,7 @@ class ParametersChangedMeta(type):
self._in_init_ = False self._in_init_ = False
logger.debug("connecting parameters") logger.debug("connecting parameters")
self._highest_parent_._connect_parameters() self._highest_parent_._connect_parameters()
self._highest_parent_._notify_parent_change() #self._highest_parent_._notify_parent_change()
self._highest_parent_._connect_fixes() self._highest_parent_._connect_fixes()
logger.debug("calling parameters changed") logger.debug("calling parameters changed")
self.parameters_changed() self.parameters_changed()
@ -140,6 +141,8 @@ class Parameterized(Parameterizable):
self.priors.shift_right(start, param.size) self.priors.shift_right(start, param.size)
self.constraints.update(param.constraints, self.size) self.constraints.update(param.constraints, self.size)
self.priors.update(param.priors, self.size) self.priors.update(param.priors, self.size)
param._parent_ = self
param._parent_index_ = len(self.parameters)
self.parameters.append(param) self.parameters.append(param)
else: else:
start = sum(p.size for p in self.parameters[:index]) start = sum(p.size for p in self.parameters[:index])
@ -147,19 +150,23 @@ class Parameterized(Parameterizable):
self.priors.shift_right(start, param.size) self.priors.shift_right(start, param.size)
self.constraints.update(param.constraints, start) self.constraints.update(param.constraints, start)
self.priors.update(param.priors, start) self.priors.update(param.priors, start)
param._parent_ = self
param._parent_index_ = index if index>=0 else len(self.parameters[:index])
for p in self.parameters[index:]:
p._parent_index_ += 1
self.parameters.insert(index, param) self.parameters.insert(index, param)
self._notify_parent_change()
param.add_observer(self, self._pass_through_notify_observers, -np.inf) param.add_observer(self, self._pass_through_notify_observers, -np.inf)
parent = self parent = self
while parent is not None: while parent is not None:
parent.size += param.size parent.size += param.size
parent = parent._parent_ parent = parent._parent_
self._notify_parent_change()
if not self._in_init_: if not self._in_init_:
self._connect_parameters() #self._connect_parameters()
self._notify_parent_change() #self._notify_parent_change()
self._highest_parent_._connect_parameters(ignore_added_names=_ignore_added_names) self._highest_parent_._connect_parameters(ignore_added_names=_ignore_added_names)
self._highest_parent_._notify_parent_change() self._highest_parent_._notify_parent_change()