mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 19:42:39 +02:00
Fixed parameterized oddity where it was updating all constrained parameters as soon as any were constrained rather than after all are constrained@
@
This commit is contained in:
parent
1c9151a7d0
commit
b12fb6a2a8
3 changed files with 87 additions and 69 deletions
|
|
@ -12,32 +12,32 @@ __index_name__ = "Index"
|
|||
__tie_name__ = "Tied to"
|
||||
__precision__ = numpy.get_printoptions()['precision'] # numpy printing precision used, sublassing numpy ndarray after all
|
||||
__print_threshold__ = 5
|
||||
######
|
||||
######
|
||||
|
||||
class Float(numpy.float64, Constrainable):
|
||||
def __init__(self, f, base):
|
||||
super(Float,self).__init__(f)
|
||||
self._base = base
|
||||
|
||||
|
||||
|
||||
|
||||
class Param(ObservableArray, Constrainable):
|
||||
"""
|
||||
Parameter object for GPy models.
|
||||
|
||||
:param name: name of the parameter to be printed
|
||||
:param input_array: array which this parameter handles
|
||||
|
||||
|
||||
You can add/remove constraints by calling constrain on the parameter itself, e.g:
|
||||
|
||||
|
||||
- self[:,1].constrain_positive()
|
||||
- self[0].tie_to(other)
|
||||
- self.untie()
|
||||
- self[:3,:].unconstrain()
|
||||
- self[1].fix()
|
||||
|
||||
|
||||
Fixing parameters will fix them to the value they are right now. If you change
|
||||
the fixed value, it will be fixed to the new value!
|
||||
|
||||
|
||||
See :py:class:`GPy.core.parameterized.Parameterized` for more details on constraining etc.
|
||||
|
||||
This ndarray can be stored in lists and checked if it is in.
|
||||
|
|
@ -46,7 +46,7 @@ class Param(ObservableArray, Constrainable):
|
|||
>>> x = np.random.normal(size=(10,3))
|
||||
>>> x in [[1], x, [3]]
|
||||
True
|
||||
|
||||
|
||||
WARNING: This overrides the functionality of x==y!!!
|
||||
Use numpy.equal(x,y) for element-wise equality testing.
|
||||
"""
|
||||
|
|
@ -68,7 +68,7 @@ class Param(ObservableArray, Constrainable):
|
|||
|
||||
def __init__(self, name, input_array):
|
||||
super(Param, self).__init__(name=name)
|
||||
|
||||
|
||||
def __array_finalize__(self, obj):
|
||||
# see InfoArray.__array_finalize__ for comments
|
||||
if obj is None: return
|
||||
|
|
@ -86,7 +86,7 @@ class Param(ObservableArray, Constrainable):
|
|||
self._original_ = getattr(obj, '_original_', None)
|
||||
self._name = getattr(obj, 'name', None)
|
||||
self.gradient = getattr(obj, 'gradient', None)
|
||||
|
||||
|
||||
def __array_wrap__(self, out_arr, context=None):
|
||||
return out_arr.view(numpy.ndarray)
|
||||
#===========================================================================
|
||||
|
|
@ -94,7 +94,7 @@ class Param(ObservableArray, Constrainable):
|
|||
#===========================================================================
|
||||
def __reduce_ex__(self):
|
||||
func, args, state = super(Param, self).__reduce__()
|
||||
return func, args, (state,
|
||||
return func, args, (state,
|
||||
(self.name,
|
||||
self._direct_parent_,
|
||||
self._parent_index_,
|
||||
|
|
@ -132,13 +132,13 @@ class Param(ObservableArray, Constrainable):
|
|||
self.flat = param
|
||||
self._notify_tied_parameters()
|
||||
self._notify_observers()
|
||||
|
||||
|
||||
def _get_params(self):
|
||||
return self.flat
|
||||
# @property
|
||||
# def name(self):
|
||||
# """
|
||||
# Name of this parameter.
|
||||
# Name of this parameter.
|
||||
# This can be a callable without parameters. The callable will be called
|
||||
# every time the name property is accessed.
|
||||
# """
|
||||
|
|
@ -163,7 +163,7 @@ class Param(ObservableArray, Constrainable):
|
|||
def constrain_fixed(self, warning=True):
|
||||
"""
|
||||
Constrain this paramter to be fixed to the current value it carries.
|
||||
|
||||
|
||||
:param warning: print a warning for overwriting constraints.
|
||||
"""
|
||||
self._highest_parent_._fix(self,warning)
|
||||
|
|
@ -179,18 +179,18 @@ class Param(ObservableArray, Constrainable):
|
|||
#===========================================================================
|
||||
def tie_to(self, param):
|
||||
"""
|
||||
:param param: the parameter object to tie this parameter to.
|
||||
:param param: the parameter object to tie this parameter to.
|
||||
Can be ParamConcatenation (retrieved by regexp search)
|
||||
|
||||
|
||||
Tie this parameter to the given parameter.
|
||||
Broadcasting is not allowed, but you can tie a whole dimension to
|
||||
one parameter: self[:,0].tie_to(other), where other is a one-value
|
||||
parameter.
|
||||
|
||||
|
||||
Note: For now only one parameter can have ties, so all of a parameter
|
||||
will be removed, when re-tieing!
|
||||
"""
|
||||
#Note: this method will tie to the parameter which is the last in
|
||||
#Note: this method will tie to the parameter which is the last in
|
||||
# the chain of ties. Thus, if you tie to a tied parameter,
|
||||
# this tie will be created to the parameter the param is tied
|
||||
# to.
|
||||
|
|
@ -200,12 +200,12 @@ class Param(ObservableArray, Constrainable):
|
|||
if param.size != 1:
|
||||
raise NotImplementedError, "Broadcast tying is not implemented yet"
|
||||
try:
|
||||
if self._original_:
|
||||
if self._original_:
|
||||
self[:] = param
|
||||
else: # this happens when indexing created a copy of the array
|
||||
self._direct_parent_._get_original(self)[self._current_slice_] = param
|
||||
except ValueError:
|
||||
raise ValueError("Trying to tie {} with shape {} to {} with shape {}".format(self.name, self.shape, param.name, param.shape))
|
||||
raise ValueError("Trying to tie {} with shape {} to {} with shape {}".format(self.name, self.shape, param.name, param.shape))
|
||||
if param is self:
|
||||
raise RuntimeError, 'Cyclic tieing is not allowed'
|
||||
# if len(param._tied_to_) > 0:
|
||||
|
|
@ -293,7 +293,7 @@ class Param(ObservableArray, Constrainable):
|
|||
def unset_prior(self, *priors):
|
||||
"""
|
||||
:param priors: priors to remove from this parameter
|
||||
|
||||
|
||||
Remove all priors from this parameter
|
||||
"""
|
||||
self._highest_parent_._remove_prior(self, *priors)
|
||||
|
|
@ -324,7 +324,7 @@ class Param(ObservableArray, Constrainable):
|
|||
if numpy.all(si == Ellipsis):
|
||||
continue
|
||||
if isinstance(si, slice):
|
||||
a = si.indices(self._realshape_[i])[0]
|
||||
a = si.indices(self._realshape_[i])[0]
|
||||
elif isinstance(si, (list,numpy.ndarray,tuple)):
|
||||
a = si[0]
|
||||
else: a = si
|
||||
|
|
@ -424,7 +424,7 @@ class Param(ObservableArray, Constrainable):
|
|||
slice_index = self._current_slice_
|
||||
if isinstance(slice_index, (tuple, list)):
|
||||
clean_curr_slice = [s for s in slice_index if numpy.any(s != Ellipsis)]
|
||||
if (all(isinstance(n, (numpy.ndarray, list, tuple)) for n in clean_curr_slice)
|
||||
if (all(isinstance(n, (numpy.ndarray, list, tuple)) for n in clean_curr_slice)
|
||||
and len(set(map(len,clean_curr_slice))) <= 1):
|
||||
return numpy.fromiter(itertools.izip(*clean_curr_slice),
|
||||
dtype=[('',int)]*self._realndim_,count=len(clean_curr_slice[0])).view((int, self._realndim_))
|
||||
|
|
@ -443,7 +443,7 @@ class Param(ObservableArray, Constrainable):
|
|||
if self._realsize_ < 2:
|
||||
return name
|
||||
ind = self._indices()
|
||||
if ind.size > 4: indstr = ','.join(map(str,ind[:2])) + "..." + ','.join(map(str,ind[-2:]))
|
||||
if ind.size > 4: indstr = ','.join(map(str,ind[:2])) + "..." + ','.join(map(str,ind[-2:]))
|
||||
else: indstr = ','.join(map(str,ind))
|
||||
return name+'['+indstr+']'
|
||||
def __str__(self, constr_matrix=None, indices=None, ties=None, lc=None, lx=None, li=None, lt=None):
|
||||
|
|
@ -477,7 +477,7 @@ class ParamConcatenation(object):
|
|||
for p in params:
|
||||
for p in p.flattened_parameters:
|
||||
if p not in self.params:
|
||||
self.params.append(p)
|
||||
self.params.append(p)
|
||||
self._param_sizes = [p.size for p in self.params]
|
||||
startstops = numpy.cumsum([0] + self._param_sizes)
|
||||
self._param_slices_ = [slice(start, stop) for start,stop in zip(startstops, startstops[1:])]
|
||||
|
|
@ -485,15 +485,15 @@ class ParamConcatenation(object):
|
|||
# Get/set items, enable broadcasting
|
||||
#===========================================================================
|
||||
def __getitem__(self, s):
|
||||
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
|
||||
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
|
||||
params = [p._get_params()[ind[ps]] for p,ps in zip(self.params, self._param_slices_) if numpy.any(p._get_params()[ind[ps]])]
|
||||
if len(params)==1: return params[0]
|
||||
return ParamConcatenation(params)
|
||||
def __setitem__(self, s, val, update=True):
|
||||
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
|
||||
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
|
||||
vals = self._vals(); vals[s] = val; del val
|
||||
[numpy.place(p, ind[ps], vals[ps]) and p._notify_tied_parameters()
|
||||
for p, ps in zip(self.params, self._param_slices_)]
|
||||
[numpy.place(p, ind[ps], vals[ps]) and p._notify_tied_parameters()
|
||||
for p, ps in zip(self.params, self._param_slices_)]
|
||||
if update:
|
||||
self.params[0]._highest_parent_.parameters_changed()
|
||||
def _vals(self):
|
||||
|
|
@ -501,38 +501,55 @@ class ParamConcatenation(object):
|
|||
#===========================================================================
|
||||
# parameter operations:
|
||||
#===========================================================================
|
||||
def update_all_params(self):
|
||||
self.params[0]._highest_parent_.parameters_changed()
|
||||
|
||||
def constrain(self, constraint, warning=True):
|
||||
[param.constrain(constraint) for param in self.params]
|
||||
[param.constrain(constraint, update=False) for param in self.params]
|
||||
self.update_all_params()
|
||||
constrain.__doc__ = Param.constrain.__doc__
|
||||
|
||||
def constrain_positive(self, warning=True):
|
||||
[param.constrain_positive(warning) for param in self.params]
|
||||
[param.constrain_positive(warning, update=False) for param in self.params]
|
||||
self.update_all_params()
|
||||
constrain_positive.__doc__ = Param.constrain_positive.__doc__
|
||||
|
||||
def constrain_fixed(self, warning=True):
|
||||
[param.constrain_fixed(warning) for param in self.params]
|
||||
constrain_fixed.__doc__ = Param.constrain_fixed.__doc__
|
||||
fix = constrain_fixed
|
||||
|
||||
def constrain_negative(self, warning=True):
|
||||
[param.constrain_negative(warning) for param in self.params]
|
||||
[param.constrain_negative(warning, update=False) for param in self.params]
|
||||
self.update_all_params()
|
||||
constrain_negative.__doc__ = Param.constrain_negative.__doc__
|
||||
|
||||
def constrain_bounded(self, lower, upper, warning=True):
|
||||
[param.constrain_bounded(lower, upper, warning) for param in self.params]
|
||||
[param.constrain_bounded(lower, upper, warning, update=False) for param in self.params]
|
||||
self.update_all_params()
|
||||
constrain_bounded.__doc__ = Param.constrain_bounded.__doc__
|
||||
|
||||
def unconstrain(self, *constraints):
|
||||
[param.unconstrain(*constraints) for param in self.params]
|
||||
unconstrain.__doc__ = Param.unconstrain.__doc__
|
||||
|
||||
def unconstrain_negative(self):
|
||||
[param.unconstrain_negative() for param in self.params]
|
||||
unconstrain_negative.__doc__ = Param.unconstrain_negative.__doc__
|
||||
|
||||
def unconstrain_positive(self):
|
||||
[param.unconstrain_positive() for param in self.params]
|
||||
unconstrain_positive.__doc__ = Param.unconstrain_positive.__doc__
|
||||
|
||||
def unconstrain_fixed(self):
|
||||
[param.unconstrain_fixed() for param in self.params]
|
||||
unconstrain_fixed.__doc__ = Param.unconstrain_fixed.__doc__
|
||||
unfix = unconstrain_fixed
|
||||
|
||||
def unconstrain_bounded(self, lower, upper):
|
||||
[param.unconstrain_bounded(lower, upper) for param in self.params]
|
||||
unconstrain_bounded.__doc__ = Param.unconstrain_bounded.__doc__
|
||||
|
||||
def untie(self, *ties):
|
||||
[param.untie(*ties) for param in self.params]
|
||||
__lt__ = lambda self, val: self._vals()<val
|
||||
|
|
@ -557,9 +574,9 @@ class ParamConcatenation(object):
|
|||
return "\n{}\n".format(" -"+"- | -".join(['-'*l for l in [li,lx,lc,lt]])).join(strings)
|
||||
def __repr__(self):
|
||||
return "\n".join(map(repr,self.params))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
|
||||
|
||||
from GPy.core.parameterized import Parameterized
|
||||
from GPy.core.parameter import Param
|
||||
|
|
@ -570,16 +587,16 @@ if __name__ == '__main__':
|
|||
p = Param("q_mean", X)
|
||||
p1 = Param("q_variance", numpy.random.rand(*p.shape))
|
||||
p2 = Param("Y", numpy.random.randn(p.shape[0],1))
|
||||
|
||||
|
||||
p3 = Param("variance", numpy.random.rand())
|
||||
p4 = Param("lengthscale", numpy.random.rand(2))
|
||||
|
||||
|
||||
m = Parameterized()
|
||||
rbf = Parameterized(name='rbf')
|
||||
|
||||
|
||||
rbf.add_parameter(p3,p4)
|
||||
m.add_parameter(p,p1,rbf)
|
||||
|
||||
|
||||
print "setting params"
|
||||
#print m.q_v[3:5,[1,4,5]]
|
||||
print "constraining variance"
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class Observable(object):
|
|||
def _notify_observers(self):
|
||||
[callble(self) for callble in self._observers_.itervalues()]
|
||||
|
||||
|
||||
|
||||
class Pickleable(object):
|
||||
def _getstate(self):
|
||||
"""
|
||||
|
|
@ -36,9 +36,9 @@ class Pickleable(object):
|
|||
Set the state (memento pattern) of this class to the given state.
|
||||
Usually this is just the counterpart to _getstate, such that
|
||||
an object is a copy of another when calling
|
||||
|
||||
|
||||
copy = <classname>.__new__(*args,**kw)._setstate(<to_be_copied>._getstate())
|
||||
|
||||
|
||||
See python doc "pickling" (`__getstate__` and `__setstate__`) for details.
|
||||
"""
|
||||
raise NotImplementedError, "To be able to use pickling you need to implement this method"
|
||||
|
|
@ -49,14 +49,14 @@ class Pickleable(object):
|
|||
|
||||
class Parentable(object):
|
||||
def __init__(self, direct_parent=None, highest_parent=None, parent_index=None):
|
||||
super(Parentable,self).__init__()
|
||||
super(Parentable,self).__init__()
|
||||
self._direct_parent_ = direct_parent
|
||||
self._parent_index_ = parent_index
|
||||
self._highest_parent_ = highest_parent
|
||||
|
||||
|
||||
def has_parent(self):
|
||||
return self._direct_parent_ is not None and self._highest_parent_ is not None
|
||||
|
||||
|
||||
class Nameable(Parentable):
|
||||
_name = None
|
||||
def __init__(self, name, direct_parent=None, highest_parent=None, parent_index=None):
|
||||
|
|
@ -69,10 +69,10 @@ class Nameable(Parentable):
|
|||
@name.setter
|
||||
def name(self, name):
|
||||
from_name = self.name
|
||||
self._name = name
|
||||
self._name = name
|
||||
if self.has_parent():
|
||||
self._direct_parent_._name_changed(self, from_name)
|
||||
|
||||
|
||||
class Constrainable(Nameable):
|
||||
def __init__(self, name):
|
||||
super(Constrainable,self).__init__(name)
|
||||
|
|
@ -84,7 +84,7 @@ class Constrainable(Nameable):
|
|||
:param transform: the :py:class:`GPy.core.transformations.Transformation`
|
||||
to constrain the this parameter to.
|
||||
:param warning: print a warning if re-constraining parameters.
|
||||
|
||||
|
||||
Constrain the parameter to the given
|
||||
:py:class:`GPy.core.transformations.Transformation`.
|
||||
"""
|
||||
|
|
@ -97,37 +97,37 @@ class Constrainable(Nameable):
|
|||
self._add_constrain(p, transform, warning)
|
||||
if update:
|
||||
self.parameters_changed()
|
||||
|
||||
def constrain_positive(self, warning=True):
|
||||
|
||||
def constrain_positive(self, warning=True, update=True):
|
||||
"""
|
||||
:param warning: print a warning if re-constraining parameters.
|
||||
|
||||
|
||||
Constrain this parameter to the default positive constraint.
|
||||
"""
|
||||
self.constrain(Logexp(), warning)
|
||||
self.constrain(Logexp(), warning=warning, update=update)
|
||||
|
||||
def constrain_negative(self, warning=True):
|
||||
def constrain_negative(self, warning=True, update=True):
|
||||
"""
|
||||
:param warning: print a warning if re-constraining parameters.
|
||||
|
||||
|
||||
Constrain this parameter to the default negative constraint.
|
||||
"""
|
||||
self.constrain(NegativeLogexp(), warning)
|
||||
self.constrain(NegativeLogexp(), warning=warning, update=update)
|
||||
|
||||
def constrain_bounded(self, lower, upper, warning=True):
|
||||
def constrain_bounded(self, lower, upper, warning=True, update=True):
|
||||
"""
|
||||
:param lower, upper: the limits to bound this parameter to
|
||||
:param warning: print a warning if re-constraining parameters.
|
||||
|
||||
|
||||
Constrain this parameter to lie within the given range.
|
||||
"""
|
||||
self.constrain(Logistic(lower, upper), warning)
|
||||
self.constrain(Logistic(lower, upper), warning=warning, update=update)
|
||||
|
||||
def unconstrain(self, *transforms):
|
||||
"""
|
||||
:param transforms: The transformations to unconstrain from.
|
||||
|
||||
remove all :py:class:`GPy.core.transformations.Transformation`
|
||||
|
||||
remove all :py:class:`GPy.core.transformations.Transformation`
|
||||
transformats of this parameter object.
|
||||
"""
|
||||
if self.has_parent():
|
||||
|
|
@ -138,20 +138,20 @@ class Constrainable(Nameable):
|
|||
|
||||
def unconstrain_positive(self):
|
||||
"""
|
||||
Remove positive constraint of this parameter.
|
||||
Remove positive constraint of this parameter.
|
||||
"""
|
||||
self.unconstrain(Logexp())
|
||||
|
||||
def unconstrain_negative(self):
|
||||
"""
|
||||
Remove negative constraint of this parameter.
|
||||
Remove negative constraint of this parameter.
|
||||
"""
|
||||
self.unconstrain(NegativeLogexp())
|
||||
|
||||
def unconstrain_bounded(self, lower, upper):
|
||||
"""
|
||||
:param lower, upper: the limits to unbound this parameter from
|
||||
|
||||
|
||||
Remove (lower, upper) bounded constrain from this parameter/
|
||||
"""
|
||||
self.unconstrain(Logistic(lower, upper))
|
||||
|
|
|
|||
|
|
@ -233,7 +233,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
|||
elif not (pname in not_unique):
|
||||
self.__dict__[pname] = p
|
||||
self._added_names_.add(pname)
|
||||
|
||||
|
||||
def _connect_highest_parent(self, highest_parent):
|
||||
self._highest_parent_ = highest_parent
|
||||
if not hasattr(self, "_parameters_") or len(self._parameters_) < 1:
|
||||
|
|
@ -242,7 +242,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
|||
for p in self._parameters_:
|
||||
p._highest_parent_ = highest_parent
|
||||
p._connect_highest_parent(highest_parent)
|
||||
|
||||
|
||||
#===========================================================================
|
||||
# Pickling operations
|
||||
#===========================================================================
|
||||
|
|
@ -413,10 +413,10 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
|||
#===========================================================================
|
||||
# Fixing parameters:
|
||||
#===========================================================================
|
||||
def _fix(self, param, warning=True):
|
||||
def _fix(self, param, warning=True, update=True):
|
||||
f = self._add_constrain(param, __fixed__, warning)
|
||||
self._set_fixed(f)
|
||||
def _unfix(self, param):
|
||||
def _unfix(self, param, update=True):
|
||||
if self._has_fixes():
|
||||
f = self._remove_constrain(param, __fixed__)
|
||||
self._set_unfixed(f)
|
||||
|
|
@ -438,7 +438,8 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
|||
# if advanced indexing is activated it happens that the array is a copy
|
||||
# you can retrieve the original param through this method, by passing
|
||||
# the copy here
|
||||
return self._parameters_[param._parent_index_]
|
||||
#return self._parameters_[param._parent_index_]
|
||||
return param._direct_parent_._parameters_[param._parent_index_]
|
||||
def hirarchy_name(self):
|
||||
if self.has_parent():
|
||||
return self._direct_parent_.hirarchy_name() + adjust_name_for_printing(self.name) + "."
|
||||
|
|
@ -452,7 +453,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
|||
# if removing constraints before adding new is not wanted, just delete the above line!
|
||||
self.constraints.add(transform, rav_i)
|
||||
param = self._get_original(param)
|
||||
param._set_params(transform.initialize(param._get_params()))
|
||||
param._set_params(transform.initialize(param._get_params()), update=False)
|
||||
if warning and any(reconstrained):
|
||||
# if you want to print the whole params object, which was reconstrained use:
|
||||
# m = str(param[self._backtranslate_index(param, reconstrained)])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue