mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-11 15:15:15 +02:00
parameters have a update hirarchy, in which updates to parameters are hirarchically pursuit
This commit is contained in:
parent
7f05dd4a43
commit
d12f055115
2 changed files with 595 additions and 564 deletions
|
|
@ -6,6 +6,7 @@ Created on 4 Sep 2013
|
|||
import itertools
|
||||
import numpy
|
||||
from transformations import Logexp, NegativeLogexp, Logistic
|
||||
from GPy.util.misc import fast_array_equal
|
||||
|
||||
###### printing
|
||||
__constraints_name__ = "Constraint"
|
||||
|
|
@ -44,11 +45,13 @@ class Param(numpy.ndarray):
|
|||
obj._name_ = name
|
||||
obj._parent_ = None
|
||||
obj._parent_index_ = None
|
||||
obj._updates_parent_ = None
|
||||
obj._gradient_ = gradient
|
||||
obj._current_slice_ = (slice(obj.shape[0]),)
|
||||
obj._realshape_ = obj.shape
|
||||
obj._realsize_ = obj.size
|
||||
obj._realndim_ = obj.ndim
|
||||
obj._updated_ = False
|
||||
from index_operations import ParamDict
|
||||
obj._tied_to_me_ = ParamDict(set)
|
||||
obj._tied_to_ = []
|
||||
|
|
@ -67,7 +70,10 @@ class Param(numpy.ndarray):
|
|||
self._realshape_ = getattr(obj, '_realshape_', None)
|
||||
self._realsize_ = getattr(obj, '_realsize_', None)
|
||||
self._realndim_ = getattr(obj, '_realndim_', None)
|
||||
self._updated_ = getattr(obj, '_updated_', None)
|
||||
self._original_ = getattr(obj, '_original_', None)
|
||||
self._updates_parent_ = getattr(obj, '_updates_parent_', None)
|
||||
|
||||
def __array_wrap__(self, out_arr, context=None):
|
||||
return out_arr.view(numpy.ndarray)
|
||||
#===========================================================================
|
||||
|
|
@ -79,22 +85,30 @@ class Param(numpy.ndarray):
|
|||
(self._name_,
|
||||
self._parent_,
|
||||
self._parent_index_,
|
||||
self._updates_parent_,
|
||||
self._gradient_,
|
||||
self._current_slice_,
|
||||
self._realshape_,
|
||||
self._realsize_,
|
||||
self._realndim_,
|
||||
)
|
||||
self._tied_to_me_,
|
||||
self._tied_to_,
|
||||
self._updated_,
|
||||
)
|
||||
)
|
||||
def __setstate__(self, state):
|
||||
super(Param, self).__setstate__(state[0])
|
||||
state = list(state[1])
|
||||
self._updated_ = state.pop()
|
||||
self._tied_to_ = state.pop()
|
||||
self._tied_to_me_ = state.pop()
|
||||
self._realndim_ = state.pop()
|
||||
self._realsize_ = state.pop()
|
||||
self._realshape_ = state.pop()
|
||||
self._current_slice_ = state.pop()
|
||||
self._parent_index_ = state.pop()
|
||||
self._gradient_ = state.pop()
|
||||
self._updates_parent_ = state.pop()
|
||||
self._parent_index_ = state.pop()
|
||||
self._parent_ = state.pop()
|
||||
self._name_ = state.pop()
|
||||
#===========================================================================
|
||||
|
|
@ -145,7 +159,7 @@ class Param(numpy.ndarray):
|
|||
#===========================================================================
|
||||
# Constrain operations -> done
|
||||
#===========================================================================
|
||||
def constrain(self, transform, warning=True):
|
||||
def constrain(self, transform, warning=True, update=True):
|
||||
"""
|
||||
:param transform: the :py:class:`GPy.core.transformations.Transformation`
|
||||
to constrain the this parameter to.
|
||||
|
|
@ -155,10 +169,12 @@ class Param(numpy.ndarray):
|
|||
:py:class:`GPy.core.transformations.Transformation`.
|
||||
"""
|
||||
if self._original_: # this happens when indexing created a copy of the array
|
||||
self.__setitem__(slice(None), transform.initialize(self))
|
||||
self.__setitem__(slice(None), transform.initialize(self), update=False)
|
||||
else:
|
||||
self._parent_._get_original(self)[self._current_slice_] = transform.initialize(self)
|
||||
self._parent_._get_original(self).__setitem__(self._current_slice_, transform.initialize(self), update=False)
|
||||
self._parent_._add_constrain(self, transform, warning)
|
||||
if update:
|
||||
self._updates_parent_.parameters_changed()
|
||||
def constrain_positive(self, warning=True):
|
||||
"""
|
||||
:param warning: print a warning if re-constraining parameters.
|
||||
|
|
@ -236,11 +252,8 @@ class Param(numpy.ndarray):
|
|||
"""
|
||||
remove tie of this parameter to ties it was tied to.
|
||||
"""
|
||||
[t._remove_tie_listener(self) for t in self._tied_to_]
|
||||
def set_index(tied_to,untie_from):
|
||||
tied_to._current_slice_ = numpy.array(set(tied_to._raveled_index()) & set(untie_from._raveled_index()))
|
||||
return tied_to
|
||||
self._tied_to_ = [tied_to for tied_to in self._tied_to_ for untie_from in ties if tied_to._parent_index_ == untie_from._parent_index_ and set_index(tied_to)._current_slice_.size > 0]
|
||||
[self._parent_._get_original(t)._remove_tie_listener(self) for t in self._tied_to_]
|
||||
self._tied_to_ = [tied_to for tied_to in self._tied_to_ for t in tied_to._tied_to_me_ if self._parent_index_==self._parent_._get_original(t)._parent_index_]
|
||||
self._parent_._set_unfixed(self)
|
||||
# self._parent_._remove_tie(self, *params)
|
||||
def _fire_changed(self):
|
||||
|
|
@ -250,17 +263,25 @@ class Param(numpy.ndarray):
|
|||
self._tied_to_me_[tied_to_me] |= set(self._raveled_index())
|
||||
def _remove_tie_listener(self, to_remove):
|
||||
for t in self._tied_to_me_.keys():
|
||||
if t._parent_index_ == self._parent_index_:
|
||||
self._tied_to_me_[t] &= set(t._raveled_index())
|
||||
if len(self._tied_to_me_[t]) == 0:
|
||||
if t._parent_index_ == to_remove._parent_index_:
|
||||
new_index = list(set(t._raveled_index()) - set(to_remove._raveled_index()))
|
||||
if new_index:
|
||||
tmp = self._parent_._get_original(t)[numpy.unravel_index(new_index,t._realshape_)]
|
||||
self._tied_to_me_[tmp] = self._tied_to_me_[t]
|
||||
del self._tied_to_me_[t]
|
||||
if len(self._tied_to_me_[tmp]) == 0:
|
||||
del self._tied_to_me_[tmp]
|
||||
else:
|
||||
del self._tied_to_me_[t]
|
||||
def _on_change(self, val, ind):
|
||||
if not numpy.all(self==val[ind]):
|
||||
if not self._updated_: #not fast_array_equal(self, val[ind]):
|
||||
self._updated_ = True
|
||||
if self._original_:
|
||||
self[:] = val[ind]
|
||||
self.__setitem__(slice(None), val[ind], update=False)
|
||||
else: # this happens when indexing created a copy of the array
|
||||
self._parent_._get_original(self)[self._current_slice_] = val[ind]
|
||||
self._parent_._get_original(self).__setitem__(self._current_slice_, val[ind], update=False)
|
||||
self._fire_changed()
|
||||
self._updated_ = False
|
||||
#===========================================================================
|
||||
# Prior Operations
|
||||
#===========================================================================
|
||||
|
|
@ -294,10 +315,13 @@ class Param(numpy.ndarray):
|
|||
return new_arr
|
||||
def __getslice__(self, start, stop):
|
||||
return self.__getitem__(slice(start, stop))
|
||||
def __setitem__(self, *args, **kwargs):
|
||||
numpy.ndarray.__setitem__(self, *args, **kwargs)
|
||||
def __setslice__(self, start, stop, val):
|
||||
return self.__setitem__(slice(start, stop), val)
|
||||
def __setitem__(self, s, val, update=True):
|
||||
numpy.ndarray.__setitem__(self, s, val)
|
||||
self._fire_changed()
|
||||
self._parent_.parameters_changed()
|
||||
if update:
|
||||
self._updates_parent_.parameters_changed()
|
||||
#===========================================================================
|
||||
# Index Operations:
|
||||
#===========================================================================
|
||||
|
|
@ -322,7 +346,7 @@ class Param(numpy.ndarray):
|
|||
extended_realshape = numpy.cumprod((1,) + self._realshape_[:0:-1])[::-1]
|
||||
ind = self._indices(slice_index)
|
||||
if ind.ndim < 2: ind=ind[:,None]
|
||||
return numpy.apply_along_axis(lambda x: numpy.sum(extended_realshape*x), 1, ind)
|
||||
return numpy.asarray(numpy.apply_along_axis(lambda x: numpy.sum(extended_realshape*x), 1, ind), dtype=int)
|
||||
def _expand_index(self, slice_index=None):
|
||||
# this calculates the full indexing arrays from the slicing objects given by get_item for _real..._ attributes
|
||||
# it basically translates slices to their respective index arrays and turns negative indices around
|
||||
|
|
@ -428,19 +452,22 @@ class ParamConcatenation(object):
|
|||
self.params = params
|
||||
self._param_sizes = [p.size for p in self.params]
|
||||
startstops = numpy.cumsum([0] + self._param_sizes)
|
||||
self._param_slices = [slice(start, stop) for start,stop in zip(startstops, startstops[1:])]
|
||||
self._param_slices_ = [slice(start, stop) for start,stop in zip(startstops, startstops[1:])]
|
||||
#===========================================================================
|
||||
# Get/set items, enable broadcasting
|
||||
#===========================================================================
|
||||
def __getitem__(self, s):
|
||||
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
|
||||
params = [p.flatten()[ind[ps]] for p,ps in zip(self.params, self._param_slices) if numpy.any(p.flat[ind[ps]])]
|
||||
params = [p.flatten()[ind[ps]] for p,ps in zip(self.params, self._param_slices_) if numpy.any(p.flat[ind[ps]])]
|
||||
if len(params)==1: return params[0]
|
||||
return ParamConcatenation(params)
|
||||
def __setitem__(self, s, val):
|
||||
def __setitem__(self, s, val, update=True):
|
||||
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
|
||||
vals = self._vals(); vals[s] = val; del val
|
||||
[numpy.place(p, ind[ps], vals[ps]) for p, ps in zip(self.params, self._param_slices)]
|
||||
[numpy.place(p, ind[ps], vals[ps]) and p._fire_changed()
|
||||
for p, ps in zip(self.params, self._param_slices_)]
|
||||
if update:
|
||||
self.params[0]._updates_parent_.parameters_changed()
|
||||
def _vals(self):
|
||||
return numpy.hstack([p._get_params() for p in self.params])
|
||||
#===========================================================================
|
||||
|
|
@ -502,7 +529,7 @@ class ParamConcatenation(object):
|
|||
if __name__ == '__main__':
|
||||
from GPy.core.parameterized import Parameterized
|
||||
#X = numpy.random.randn(2,3,1,5,2,4,3)
|
||||
X = numpy.random.randn(100,20)
|
||||
X = numpy.random.randn(1000,20)
|
||||
print "random done"
|
||||
p = Param("q_mean", X, None)
|
||||
p1 = Param("q_variance", numpy.random.rand(*p.shape), None)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue