[parameterized] handle updates inside init

This commit is contained in:
Max Zwiessele 2014-11-03 14:19:38 +00:00
parent d0a5420f2f
commit 6242a75f6a
3 changed files with 23 additions and 18 deletions

View file

@ -26,39 +26,39 @@ class ParameterIndexOperations(object):
This object wraps a dictionary, whos keys are _operations_ that we'd like This object wraps a dictionary, whos keys are _operations_ that we'd like
to apply to a parameter array, and whose values are np integer arrays which to apply to a parameter array, and whose values are np integer arrays which
index the parameter array appropriately. index the parameter array appropriately.
A model instance will contain one instance of this class for each thing A model instance will contain one instance of this class for each thing
that needs indexing (i.e. constraints, ties and priors). Parameters within that needs indexing (i.e. constraints, ties and priors). Parameters within
the model constain instances of the ParameterIndexOperationsView class, the model constain instances of the ParameterIndexOperationsView class,
which can map from a 'local' index (starting 0) to this global index. which can map from a 'local' index (starting 0) to this global index.
Here's an illustration: Here's an illustration:
#======================================================================= #=======================================================================
model : 0 1 2 3 4 5 6 7 8 9 model : 0 1 2 3 4 5 6 7 8 9
key1: 4 5 key1: 4 5
key2: 7 8 key2: 7 8
param1: 0 1 2 3 4 5 param1: 0 1 2 3 4 5
key1: 2 3 key1: 2 3
key2: 5 key2: 5
param2: 0 1 2 3 4 param2: 0 1 2 3 4
key1: 0 key1: 0
key2: 2 3 key2: 2 3
#======================================================================= #=======================================================================
The views of this global index have a subset of the keys in this global The views of this global index have a subset of the keys in this global
(model) index. (model) index.
Adding a new key (e.g. a constraint) to a view will cause the view to pass Adding a new key (e.g. a constraint) to a view will cause the view to pass
the new key to the global index, along with the local index and an offset. the new key to the global index, along with the local index and an offset.
This global index then stores the key and the appropriate global index This global index then stores the key and the appropriate global index
(which can be seen by the view). (which can be seen by the view).
See also: See also:
ParameterIndexOperationsView ParameterIndexOperationsView
""" """
_offset = 0 _offset = 0
def __init__(self, constraints=None): def __init__(self, constraints=None):
@ -221,8 +221,6 @@ class ParameterIndexOperationsView(object):
def shift_left(self, start, size): def shift_left(self, start, size):
self._param_index_ops.shift_left(start+self._offset, size) self._param_index_ops.shift_left(start+self._offset, size)
self._offset -= size
self._size -= size
def clear(self): def clear(self):
for i, ind in self.items(): for i, ind in self.items():

View file

@ -18,7 +18,7 @@ import numpy as np
import re import re
import logging import logging
__updated__ = '2014-10-28' __updated__ = '2014-11-03'
class HierarchyError(Exception): class HierarchyError(Exception):
""" """
@ -924,7 +924,7 @@ class Parameterizable(OptimizationHandlable):
!WARNING!: setting the parameter array MUST always be done in memory: !WARNING!: setting the parameter array MUST always be done in memory:
m.param_array[:] = m_copy.param_array m.param_array[:] = m_copy.param_array
""" """
if self.__dict__.get('_param_array_', None) is None: if (self.__dict__.get('_param_array_', None) is None) or (self._param_array_.size != self.size):
self._param_array_ = np.empty(self.size, dtype=np.float64) self._param_array_ = np.empty(self.size, dtype=np.float64)
return self._param_array_ return self._param_array_
@ -1002,7 +1002,7 @@ class Parameterizable(OptimizationHandlable):
#========================================================================= #=========================================================================
@property @property
def gradient(self): def gradient(self):
if self.__dict__.get('_gradient_array_', None) is None: if (self.__dict__.get('_gradient_array_', None) is None) or self._gradient_array_.size != self.size:
self._gradient_array_ = np.empty(self.size, dtype=np.float64) self._gradient_array_ = np.empty(self.size, dtype=np.float64)
return self._gradient_array_ return self._gradient_array_

View file

@ -9,6 +9,7 @@ from param import ParamConcatenation
from parameter_core import HierarchyError, Parameterizable, adjust_name_for_printing from parameter_core import HierarchyError, Parameterizable, adjust_name_for_printing
import logging import logging
from GPy.core.parameterization.index_operations import ParameterIndexOperationsView
logger = logging.getLogger("parameters changed meta") logger = logging.getLogger("parameters changed meta")
class ParametersChangedMeta(type): class ParametersChangedMeta(type):
@ -20,7 +21,7 @@ class ParametersChangedMeta(type):
self._in_init_ = False self._in_init_ = False
logger.debug("connecting parameters") logger.debug("connecting parameters")
self._highest_parent_._connect_parameters() self._highest_parent_._connect_parameters()
self._highest_parent_._notify_parent_change() #self._highest_parent_._notify_parent_change()
self._highest_parent_._connect_fixes() self._highest_parent_._connect_fixes()
logger.debug("calling parameters changed") logger.debug("calling parameters changed")
self.parameters_changed() self.parameters_changed()
@ -140,6 +141,8 @@ class Parameterized(Parameterizable):
self.priors.shift_right(start, param.size) self.priors.shift_right(start, param.size)
self.constraints.update(param.constraints, self.size) self.constraints.update(param.constraints, self.size)
self.priors.update(param.priors, self.size) self.priors.update(param.priors, self.size)
param._parent_ = self
param._parent_index_ = len(self.parameters)
self.parameters.append(param) self.parameters.append(param)
else: else:
start = sum(p.size for p in self.parameters[:index]) start = sum(p.size for p in self.parameters[:index])
@ -147,19 +150,23 @@ class Parameterized(Parameterizable):
self.priors.shift_right(start, param.size) self.priors.shift_right(start, param.size)
self.constraints.update(param.constraints, start) self.constraints.update(param.constraints, start)
self.priors.update(param.priors, start) self.priors.update(param.priors, start)
param._parent_ = self
param._parent_index_ = index if index>=0 else len(self.parameters[:index])
for p in self.parameters[index:]:
p._parent_index_ += 1
self.parameters.insert(index, param) self.parameters.insert(index, param)
self._notify_parent_change()
param.add_observer(self, self._pass_through_notify_observers, -np.inf) param.add_observer(self, self._pass_through_notify_observers, -np.inf)
parent = self parent = self
while parent is not None: while parent is not None:
parent.size += param.size parent.size += param.size
parent = parent._parent_ parent = parent._parent_
self._notify_parent_change()
if not self._in_init_: if not self._in_init_:
self._connect_parameters() #self._connect_parameters()
self._notify_parent_change() #self._notify_parent_change()
self._highest_parent_._connect_parameters(ignore_added_names=_ignore_added_names) self._highest_parent_._connect_parameters(ignore_added_names=_ignore_added_names)
self._highest_parent_._notify_parent_change() self._highest_parent_._notify_parent_change()