mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 03:52:39 +02:00
Fixed parameterized oddity where it was updating all constrained parameters as soon as any were constrained rather than after all are constrained@
@
This commit is contained in:
parent
1c9151a7d0
commit
b12fb6a2a8
3 changed files with 87 additions and 69 deletions
|
|
@ -12,32 +12,32 @@ __index_name__ = "Index"
|
||||||
__tie_name__ = "Tied to"
|
__tie_name__ = "Tied to"
|
||||||
__precision__ = numpy.get_printoptions()['precision'] # numpy printing precision used, sublassing numpy ndarray after all
|
__precision__ = numpy.get_printoptions()['precision'] # numpy printing precision used, sublassing numpy ndarray after all
|
||||||
__print_threshold__ = 5
|
__print_threshold__ = 5
|
||||||
######
|
######
|
||||||
|
|
||||||
class Float(numpy.float64, Constrainable):
|
class Float(numpy.float64, Constrainable):
|
||||||
def __init__(self, f, base):
|
def __init__(self, f, base):
|
||||||
super(Float,self).__init__(f)
|
super(Float,self).__init__(f)
|
||||||
self._base = base
|
self._base = base
|
||||||
|
|
||||||
|
|
||||||
class Param(ObservableArray, Constrainable):
|
class Param(ObservableArray, Constrainable):
|
||||||
"""
|
"""
|
||||||
Parameter object for GPy models.
|
Parameter object for GPy models.
|
||||||
|
|
||||||
:param name: name of the parameter to be printed
|
:param name: name of the parameter to be printed
|
||||||
:param input_array: array which this parameter handles
|
:param input_array: array which this parameter handles
|
||||||
|
|
||||||
You can add/remove constraints by calling constrain on the parameter itself, e.g:
|
You can add/remove constraints by calling constrain on the parameter itself, e.g:
|
||||||
|
|
||||||
- self[:,1].constrain_positive()
|
- self[:,1].constrain_positive()
|
||||||
- self[0].tie_to(other)
|
- self[0].tie_to(other)
|
||||||
- self.untie()
|
- self.untie()
|
||||||
- self[:3,:].unconstrain()
|
- self[:3,:].unconstrain()
|
||||||
- self[1].fix()
|
- self[1].fix()
|
||||||
|
|
||||||
Fixing parameters will fix them to the value they are right now. If you change
|
Fixing parameters will fix them to the value they are right now. If you change
|
||||||
the fixed value, it will be fixed to the new value!
|
the fixed value, it will be fixed to the new value!
|
||||||
|
|
||||||
See :py:class:`GPy.core.parameterized.Parameterized` for more details on constraining etc.
|
See :py:class:`GPy.core.parameterized.Parameterized` for more details on constraining etc.
|
||||||
|
|
||||||
This ndarray can be stored in lists and checked if it is in.
|
This ndarray can be stored in lists and checked if it is in.
|
||||||
|
|
@ -46,7 +46,7 @@ class Param(ObservableArray, Constrainable):
|
||||||
>>> x = np.random.normal(size=(10,3))
|
>>> x = np.random.normal(size=(10,3))
|
||||||
>>> x in [[1], x, [3]]
|
>>> x in [[1], x, [3]]
|
||||||
True
|
True
|
||||||
|
|
||||||
WARNING: This overrides the functionality of x==y!!!
|
WARNING: This overrides the functionality of x==y!!!
|
||||||
Use numpy.equal(x,y) for element-wise equality testing.
|
Use numpy.equal(x,y) for element-wise equality testing.
|
||||||
"""
|
"""
|
||||||
|
|
@ -68,7 +68,7 @@ class Param(ObservableArray, Constrainable):
|
||||||
|
|
||||||
def __init__(self, name, input_array):
|
def __init__(self, name, input_array):
|
||||||
super(Param, self).__init__(name=name)
|
super(Param, self).__init__(name=name)
|
||||||
|
|
||||||
def __array_finalize__(self, obj):
|
def __array_finalize__(self, obj):
|
||||||
# see InfoArray.__array_finalize__ for comments
|
# see InfoArray.__array_finalize__ for comments
|
||||||
if obj is None: return
|
if obj is None: return
|
||||||
|
|
@ -86,7 +86,7 @@ class Param(ObservableArray, Constrainable):
|
||||||
self._original_ = getattr(obj, '_original_', None)
|
self._original_ = getattr(obj, '_original_', None)
|
||||||
self._name = getattr(obj, 'name', None)
|
self._name = getattr(obj, 'name', None)
|
||||||
self.gradient = getattr(obj, 'gradient', None)
|
self.gradient = getattr(obj, 'gradient', None)
|
||||||
|
|
||||||
def __array_wrap__(self, out_arr, context=None):
|
def __array_wrap__(self, out_arr, context=None):
|
||||||
return out_arr.view(numpy.ndarray)
|
return out_arr.view(numpy.ndarray)
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
|
|
@ -94,7 +94,7 @@ class Param(ObservableArray, Constrainable):
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
def __reduce_ex__(self):
|
def __reduce_ex__(self):
|
||||||
func, args, state = super(Param, self).__reduce__()
|
func, args, state = super(Param, self).__reduce__()
|
||||||
return func, args, (state,
|
return func, args, (state,
|
||||||
(self.name,
|
(self.name,
|
||||||
self._direct_parent_,
|
self._direct_parent_,
|
||||||
self._parent_index_,
|
self._parent_index_,
|
||||||
|
|
@ -132,13 +132,13 @@ class Param(ObservableArray, Constrainable):
|
||||||
self.flat = param
|
self.flat = param
|
||||||
self._notify_tied_parameters()
|
self._notify_tied_parameters()
|
||||||
self._notify_observers()
|
self._notify_observers()
|
||||||
|
|
||||||
def _get_params(self):
|
def _get_params(self):
|
||||||
return self.flat
|
return self.flat
|
||||||
# @property
|
# @property
|
||||||
# def name(self):
|
# def name(self):
|
||||||
# """
|
# """
|
||||||
# Name of this parameter.
|
# Name of this parameter.
|
||||||
# This can be a callable without parameters. The callable will be called
|
# This can be a callable without parameters. The callable will be called
|
||||||
# every time the name property is accessed.
|
# every time the name property is accessed.
|
||||||
# """
|
# """
|
||||||
|
|
@ -163,7 +163,7 @@ class Param(ObservableArray, Constrainable):
|
||||||
def constrain_fixed(self, warning=True):
|
def constrain_fixed(self, warning=True):
|
||||||
"""
|
"""
|
||||||
Constrain this paramter to be fixed to the current value it carries.
|
Constrain this paramter to be fixed to the current value it carries.
|
||||||
|
|
||||||
:param warning: print a warning for overwriting constraints.
|
:param warning: print a warning for overwriting constraints.
|
||||||
"""
|
"""
|
||||||
self._highest_parent_._fix(self,warning)
|
self._highest_parent_._fix(self,warning)
|
||||||
|
|
@ -179,18 +179,18 @@ class Param(ObservableArray, Constrainable):
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
def tie_to(self, param):
|
def tie_to(self, param):
|
||||||
"""
|
"""
|
||||||
:param param: the parameter object to tie this parameter to.
|
:param param: the parameter object to tie this parameter to.
|
||||||
Can be ParamConcatenation (retrieved by regexp search)
|
Can be ParamConcatenation (retrieved by regexp search)
|
||||||
|
|
||||||
Tie this parameter to the given parameter.
|
Tie this parameter to the given parameter.
|
||||||
Broadcasting is not allowed, but you can tie a whole dimension to
|
Broadcasting is not allowed, but you can tie a whole dimension to
|
||||||
one parameter: self[:,0].tie_to(other), where other is a one-value
|
one parameter: self[:,0].tie_to(other), where other is a one-value
|
||||||
parameter.
|
parameter.
|
||||||
|
|
||||||
Note: For now only one parameter can have ties, so all of a parameter
|
Note: For now only one parameter can have ties, so all of a parameter
|
||||||
will be removed, when re-tieing!
|
will be removed, when re-tieing!
|
||||||
"""
|
"""
|
||||||
#Note: this method will tie to the parameter which is the last in
|
#Note: this method will tie to the parameter which is the last in
|
||||||
# the chain of ties. Thus, if you tie to a tied parameter,
|
# the chain of ties. Thus, if you tie to a tied parameter,
|
||||||
# this tie will be created to the parameter the param is tied
|
# this tie will be created to the parameter the param is tied
|
||||||
# to.
|
# to.
|
||||||
|
|
@ -200,12 +200,12 @@ class Param(ObservableArray, Constrainable):
|
||||||
if param.size != 1:
|
if param.size != 1:
|
||||||
raise NotImplementedError, "Broadcast tying is not implemented yet"
|
raise NotImplementedError, "Broadcast tying is not implemented yet"
|
||||||
try:
|
try:
|
||||||
if self._original_:
|
if self._original_:
|
||||||
self[:] = param
|
self[:] = param
|
||||||
else: # this happens when indexing created a copy of the array
|
else: # this happens when indexing created a copy of the array
|
||||||
self._direct_parent_._get_original(self)[self._current_slice_] = param
|
self._direct_parent_._get_original(self)[self._current_slice_] = param
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError("Trying to tie {} with shape {} to {} with shape {}".format(self.name, self.shape, param.name, param.shape))
|
raise ValueError("Trying to tie {} with shape {} to {} with shape {}".format(self.name, self.shape, param.name, param.shape))
|
||||||
if param is self:
|
if param is self:
|
||||||
raise RuntimeError, 'Cyclic tieing is not allowed'
|
raise RuntimeError, 'Cyclic tieing is not allowed'
|
||||||
# if len(param._tied_to_) > 0:
|
# if len(param._tied_to_) > 0:
|
||||||
|
|
@ -293,7 +293,7 @@ class Param(ObservableArray, Constrainable):
|
||||||
def unset_prior(self, *priors):
|
def unset_prior(self, *priors):
|
||||||
"""
|
"""
|
||||||
:param priors: priors to remove from this parameter
|
:param priors: priors to remove from this parameter
|
||||||
|
|
||||||
Remove all priors from this parameter
|
Remove all priors from this parameter
|
||||||
"""
|
"""
|
||||||
self._highest_parent_._remove_prior(self, *priors)
|
self._highest_parent_._remove_prior(self, *priors)
|
||||||
|
|
@ -324,7 +324,7 @@ class Param(ObservableArray, Constrainable):
|
||||||
if numpy.all(si == Ellipsis):
|
if numpy.all(si == Ellipsis):
|
||||||
continue
|
continue
|
||||||
if isinstance(si, slice):
|
if isinstance(si, slice):
|
||||||
a = si.indices(self._realshape_[i])[0]
|
a = si.indices(self._realshape_[i])[0]
|
||||||
elif isinstance(si, (list,numpy.ndarray,tuple)):
|
elif isinstance(si, (list,numpy.ndarray,tuple)):
|
||||||
a = si[0]
|
a = si[0]
|
||||||
else: a = si
|
else: a = si
|
||||||
|
|
@ -424,7 +424,7 @@ class Param(ObservableArray, Constrainable):
|
||||||
slice_index = self._current_slice_
|
slice_index = self._current_slice_
|
||||||
if isinstance(slice_index, (tuple, list)):
|
if isinstance(slice_index, (tuple, list)):
|
||||||
clean_curr_slice = [s for s in slice_index if numpy.any(s != Ellipsis)]
|
clean_curr_slice = [s for s in slice_index if numpy.any(s != Ellipsis)]
|
||||||
if (all(isinstance(n, (numpy.ndarray, list, tuple)) for n in clean_curr_slice)
|
if (all(isinstance(n, (numpy.ndarray, list, tuple)) for n in clean_curr_slice)
|
||||||
and len(set(map(len,clean_curr_slice))) <= 1):
|
and len(set(map(len,clean_curr_slice))) <= 1):
|
||||||
return numpy.fromiter(itertools.izip(*clean_curr_slice),
|
return numpy.fromiter(itertools.izip(*clean_curr_slice),
|
||||||
dtype=[('',int)]*self._realndim_,count=len(clean_curr_slice[0])).view((int, self._realndim_))
|
dtype=[('',int)]*self._realndim_,count=len(clean_curr_slice[0])).view((int, self._realndim_))
|
||||||
|
|
@ -443,7 +443,7 @@ class Param(ObservableArray, Constrainable):
|
||||||
if self._realsize_ < 2:
|
if self._realsize_ < 2:
|
||||||
return name
|
return name
|
||||||
ind = self._indices()
|
ind = self._indices()
|
||||||
if ind.size > 4: indstr = ','.join(map(str,ind[:2])) + "..." + ','.join(map(str,ind[-2:]))
|
if ind.size > 4: indstr = ','.join(map(str,ind[:2])) + "..." + ','.join(map(str,ind[-2:]))
|
||||||
else: indstr = ','.join(map(str,ind))
|
else: indstr = ','.join(map(str,ind))
|
||||||
return name+'['+indstr+']'
|
return name+'['+indstr+']'
|
||||||
def __str__(self, constr_matrix=None, indices=None, ties=None, lc=None, lx=None, li=None, lt=None):
|
def __str__(self, constr_matrix=None, indices=None, ties=None, lc=None, lx=None, li=None, lt=None):
|
||||||
|
|
@ -477,7 +477,7 @@ class ParamConcatenation(object):
|
||||||
for p in params:
|
for p in params:
|
||||||
for p in p.flattened_parameters:
|
for p in p.flattened_parameters:
|
||||||
if p not in self.params:
|
if p not in self.params:
|
||||||
self.params.append(p)
|
self.params.append(p)
|
||||||
self._param_sizes = [p.size for p in self.params]
|
self._param_sizes = [p.size for p in self.params]
|
||||||
startstops = numpy.cumsum([0] + self._param_sizes)
|
startstops = numpy.cumsum([0] + self._param_sizes)
|
||||||
self._param_slices_ = [slice(start, stop) for start,stop in zip(startstops, startstops[1:])]
|
self._param_slices_ = [slice(start, stop) for start,stop in zip(startstops, startstops[1:])]
|
||||||
|
|
@ -485,15 +485,15 @@ class ParamConcatenation(object):
|
||||||
# Get/set items, enable broadcasting
|
# Get/set items, enable broadcasting
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
def __getitem__(self, s):
|
def __getitem__(self, s):
|
||||||
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
|
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
|
||||||
params = [p._get_params()[ind[ps]] for p,ps in zip(self.params, self._param_slices_) if numpy.any(p._get_params()[ind[ps]])]
|
params = [p._get_params()[ind[ps]] for p,ps in zip(self.params, self._param_slices_) if numpy.any(p._get_params()[ind[ps]])]
|
||||||
if len(params)==1: return params[0]
|
if len(params)==1: return params[0]
|
||||||
return ParamConcatenation(params)
|
return ParamConcatenation(params)
|
||||||
def __setitem__(self, s, val, update=True):
|
def __setitem__(self, s, val, update=True):
|
||||||
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
|
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
|
||||||
vals = self._vals(); vals[s] = val; del val
|
vals = self._vals(); vals[s] = val; del val
|
||||||
[numpy.place(p, ind[ps], vals[ps]) and p._notify_tied_parameters()
|
[numpy.place(p, ind[ps], vals[ps]) and p._notify_tied_parameters()
|
||||||
for p, ps in zip(self.params, self._param_slices_)]
|
for p, ps in zip(self.params, self._param_slices_)]
|
||||||
if update:
|
if update:
|
||||||
self.params[0]._highest_parent_.parameters_changed()
|
self.params[0]._highest_parent_.parameters_changed()
|
||||||
def _vals(self):
|
def _vals(self):
|
||||||
|
|
@ -501,38 +501,55 @@ class ParamConcatenation(object):
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# parameter operations:
|
# parameter operations:
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
|
def update_all_params(self):
|
||||||
|
self.params[0]._highest_parent_.parameters_changed()
|
||||||
|
|
||||||
def constrain(self, constraint, warning=True):
|
def constrain(self, constraint, warning=True):
|
||||||
[param.constrain(constraint) for param in self.params]
|
[param.constrain(constraint, update=False) for param in self.params]
|
||||||
|
self.update_all_params()
|
||||||
constrain.__doc__ = Param.constrain.__doc__
|
constrain.__doc__ = Param.constrain.__doc__
|
||||||
|
|
||||||
def constrain_positive(self, warning=True):
|
def constrain_positive(self, warning=True):
|
||||||
[param.constrain_positive(warning) for param in self.params]
|
[param.constrain_positive(warning, update=False) for param in self.params]
|
||||||
|
self.update_all_params()
|
||||||
constrain_positive.__doc__ = Param.constrain_positive.__doc__
|
constrain_positive.__doc__ = Param.constrain_positive.__doc__
|
||||||
|
|
||||||
def constrain_fixed(self, warning=True):
|
def constrain_fixed(self, warning=True):
|
||||||
[param.constrain_fixed(warning) for param in self.params]
|
[param.constrain_fixed(warning) for param in self.params]
|
||||||
constrain_fixed.__doc__ = Param.constrain_fixed.__doc__
|
constrain_fixed.__doc__ = Param.constrain_fixed.__doc__
|
||||||
fix = constrain_fixed
|
fix = constrain_fixed
|
||||||
|
|
||||||
def constrain_negative(self, warning=True):
|
def constrain_negative(self, warning=True):
|
||||||
[param.constrain_negative(warning) for param in self.params]
|
[param.constrain_negative(warning, update=False) for param in self.params]
|
||||||
|
self.update_all_params()
|
||||||
constrain_negative.__doc__ = Param.constrain_negative.__doc__
|
constrain_negative.__doc__ = Param.constrain_negative.__doc__
|
||||||
|
|
||||||
def constrain_bounded(self, lower, upper, warning=True):
|
def constrain_bounded(self, lower, upper, warning=True):
|
||||||
[param.constrain_bounded(lower, upper, warning) for param in self.params]
|
[param.constrain_bounded(lower, upper, warning, update=False) for param in self.params]
|
||||||
|
self.update_all_params()
|
||||||
constrain_bounded.__doc__ = Param.constrain_bounded.__doc__
|
constrain_bounded.__doc__ = Param.constrain_bounded.__doc__
|
||||||
|
|
||||||
def unconstrain(self, *constraints):
|
def unconstrain(self, *constraints):
|
||||||
[param.unconstrain(*constraints) for param in self.params]
|
[param.unconstrain(*constraints) for param in self.params]
|
||||||
unconstrain.__doc__ = Param.unconstrain.__doc__
|
unconstrain.__doc__ = Param.unconstrain.__doc__
|
||||||
|
|
||||||
def unconstrain_negative(self):
|
def unconstrain_negative(self):
|
||||||
[param.unconstrain_negative() for param in self.params]
|
[param.unconstrain_negative() for param in self.params]
|
||||||
unconstrain_negative.__doc__ = Param.unconstrain_negative.__doc__
|
unconstrain_negative.__doc__ = Param.unconstrain_negative.__doc__
|
||||||
|
|
||||||
def unconstrain_positive(self):
|
def unconstrain_positive(self):
|
||||||
[param.unconstrain_positive() for param in self.params]
|
[param.unconstrain_positive() for param in self.params]
|
||||||
unconstrain_positive.__doc__ = Param.unconstrain_positive.__doc__
|
unconstrain_positive.__doc__ = Param.unconstrain_positive.__doc__
|
||||||
|
|
||||||
def unconstrain_fixed(self):
|
def unconstrain_fixed(self):
|
||||||
[param.unconstrain_fixed() for param in self.params]
|
[param.unconstrain_fixed() for param in self.params]
|
||||||
unconstrain_fixed.__doc__ = Param.unconstrain_fixed.__doc__
|
unconstrain_fixed.__doc__ = Param.unconstrain_fixed.__doc__
|
||||||
unfix = unconstrain_fixed
|
unfix = unconstrain_fixed
|
||||||
|
|
||||||
def unconstrain_bounded(self, lower, upper):
|
def unconstrain_bounded(self, lower, upper):
|
||||||
[param.unconstrain_bounded(lower, upper) for param in self.params]
|
[param.unconstrain_bounded(lower, upper) for param in self.params]
|
||||||
unconstrain_bounded.__doc__ = Param.unconstrain_bounded.__doc__
|
unconstrain_bounded.__doc__ = Param.unconstrain_bounded.__doc__
|
||||||
|
|
||||||
def untie(self, *ties):
|
def untie(self, *ties):
|
||||||
[param.untie(*ties) for param in self.params]
|
[param.untie(*ties) for param in self.params]
|
||||||
__lt__ = lambda self, val: self._vals()<val
|
__lt__ = lambda self, val: self._vals()<val
|
||||||
|
|
@ -557,9 +574,9 @@ class ParamConcatenation(object):
|
||||||
return "\n{}\n".format(" -"+"- | -".join(['-'*l for l in [li,lx,lc,lt]])).join(strings)
|
return "\n{}\n".format(" -"+"- | -".join(['-'*l for l in [li,lx,lc,lt]])).join(strings)
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "\n".join(map(repr,self.params))
|
return "\n".join(map(repr,self.params))
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
|
||||||
from GPy.core.parameterized import Parameterized
|
from GPy.core.parameterized import Parameterized
|
||||||
from GPy.core.parameter import Param
|
from GPy.core.parameter import Param
|
||||||
|
|
@ -570,16 +587,16 @@ if __name__ == '__main__':
|
||||||
p = Param("q_mean", X)
|
p = Param("q_mean", X)
|
||||||
p1 = Param("q_variance", numpy.random.rand(*p.shape))
|
p1 = Param("q_variance", numpy.random.rand(*p.shape))
|
||||||
p2 = Param("Y", numpy.random.randn(p.shape[0],1))
|
p2 = Param("Y", numpy.random.randn(p.shape[0],1))
|
||||||
|
|
||||||
p3 = Param("variance", numpy.random.rand())
|
p3 = Param("variance", numpy.random.rand())
|
||||||
p4 = Param("lengthscale", numpy.random.rand(2))
|
p4 = Param("lengthscale", numpy.random.rand(2))
|
||||||
|
|
||||||
m = Parameterized()
|
m = Parameterized()
|
||||||
rbf = Parameterized(name='rbf')
|
rbf = Parameterized(name='rbf')
|
||||||
|
|
||||||
rbf.add_parameter(p3,p4)
|
rbf.add_parameter(p3,p4)
|
||||||
m.add_parameter(p,p1,rbf)
|
m.add_parameter(p,p1,rbf)
|
||||||
|
|
||||||
print "setting params"
|
print "setting params"
|
||||||
#print m.q_v[3:5,[1,4,5]]
|
#print m.q_v[3:5,[1,4,5]]
|
||||||
print "constraining variance"
|
print "constraining variance"
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ class Observable(object):
|
||||||
def _notify_observers(self):
|
def _notify_observers(self):
|
||||||
[callble(self) for callble in self._observers_.itervalues()]
|
[callble(self) for callble in self._observers_.itervalues()]
|
||||||
|
|
||||||
|
|
||||||
class Pickleable(object):
|
class Pickleable(object):
|
||||||
def _getstate(self):
|
def _getstate(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -36,9 +36,9 @@ class Pickleable(object):
|
||||||
Set the state (memento pattern) of this class to the given state.
|
Set the state (memento pattern) of this class to the given state.
|
||||||
Usually this is just the counterpart to _getstate, such that
|
Usually this is just the counterpart to _getstate, such that
|
||||||
an object is a copy of another when calling
|
an object is a copy of another when calling
|
||||||
|
|
||||||
copy = <classname>.__new__(*args,**kw)._setstate(<to_be_copied>._getstate())
|
copy = <classname>.__new__(*args,**kw)._setstate(<to_be_copied>._getstate())
|
||||||
|
|
||||||
See python doc "pickling" (`__getstate__` and `__setstate__`) for details.
|
See python doc "pickling" (`__getstate__` and `__setstate__`) for details.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError, "To be able to use pickling you need to implement this method"
|
raise NotImplementedError, "To be able to use pickling you need to implement this method"
|
||||||
|
|
@ -49,14 +49,14 @@ class Pickleable(object):
|
||||||
|
|
||||||
class Parentable(object):
|
class Parentable(object):
|
||||||
def __init__(self, direct_parent=None, highest_parent=None, parent_index=None):
|
def __init__(self, direct_parent=None, highest_parent=None, parent_index=None):
|
||||||
super(Parentable,self).__init__()
|
super(Parentable,self).__init__()
|
||||||
self._direct_parent_ = direct_parent
|
self._direct_parent_ = direct_parent
|
||||||
self._parent_index_ = parent_index
|
self._parent_index_ = parent_index
|
||||||
self._highest_parent_ = highest_parent
|
self._highest_parent_ = highest_parent
|
||||||
|
|
||||||
def has_parent(self):
|
def has_parent(self):
|
||||||
return self._direct_parent_ is not None and self._highest_parent_ is not None
|
return self._direct_parent_ is not None and self._highest_parent_ is not None
|
||||||
|
|
||||||
class Nameable(Parentable):
|
class Nameable(Parentable):
|
||||||
_name = None
|
_name = None
|
||||||
def __init__(self, name, direct_parent=None, highest_parent=None, parent_index=None):
|
def __init__(self, name, direct_parent=None, highest_parent=None, parent_index=None):
|
||||||
|
|
@ -69,10 +69,10 @@ class Nameable(Parentable):
|
||||||
@name.setter
|
@name.setter
|
||||||
def name(self, name):
|
def name(self, name):
|
||||||
from_name = self.name
|
from_name = self.name
|
||||||
self._name = name
|
self._name = name
|
||||||
if self.has_parent():
|
if self.has_parent():
|
||||||
self._direct_parent_._name_changed(self, from_name)
|
self._direct_parent_._name_changed(self, from_name)
|
||||||
|
|
||||||
class Constrainable(Nameable):
|
class Constrainable(Nameable):
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
super(Constrainable,self).__init__(name)
|
super(Constrainable,self).__init__(name)
|
||||||
|
|
@ -84,7 +84,7 @@ class Constrainable(Nameable):
|
||||||
:param transform: the :py:class:`GPy.core.transformations.Transformation`
|
:param transform: the :py:class:`GPy.core.transformations.Transformation`
|
||||||
to constrain the this parameter to.
|
to constrain the this parameter to.
|
||||||
:param warning: print a warning if re-constraining parameters.
|
:param warning: print a warning if re-constraining parameters.
|
||||||
|
|
||||||
Constrain the parameter to the given
|
Constrain the parameter to the given
|
||||||
:py:class:`GPy.core.transformations.Transformation`.
|
:py:class:`GPy.core.transformations.Transformation`.
|
||||||
"""
|
"""
|
||||||
|
|
@ -97,37 +97,37 @@ class Constrainable(Nameable):
|
||||||
self._add_constrain(p, transform, warning)
|
self._add_constrain(p, transform, warning)
|
||||||
if update:
|
if update:
|
||||||
self.parameters_changed()
|
self.parameters_changed()
|
||||||
|
|
||||||
def constrain_positive(self, warning=True):
|
def constrain_positive(self, warning=True, update=True):
|
||||||
"""
|
"""
|
||||||
:param warning: print a warning if re-constraining parameters.
|
:param warning: print a warning if re-constraining parameters.
|
||||||
|
|
||||||
Constrain this parameter to the default positive constraint.
|
Constrain this parameter to the default positive constraint.
|
||||||
"""
|
"""
|
||||||
self.constrain(Logexp(), warning)
|
self.constrain(Logexp(), warning=warning, update=update)
|
||||||
|
|
||||||
def constrain_negative(self, warning=True):
|
def constrain_negative(self, warning=True, update=True):
|
||||||
"""
|
"""
|
||||||
:param warning: print a warning if re-constraining parameters.
|
:param warning: print a warning if re-constraining parameters.
|
||||||
|
|
||||||
Constrain this parameter to the default negative constraint.
|
Constrain this parameter to the default negative constraint.
|
||||||
"""
|
"""
|
||||||
self.constrain(NegativeLogexp(), warning)
|
self.constrain(NegativeLogexp(), warning=warning, update=update)
|
||||||
|
|
||||||
def constrain_bounded(self, lower, upper, warning=True):
|
def constrain_bounded(self, lower, upper, warning=True, update=True):
|
||||||
"""
|
"""
|
||||||
:param lower, upper: the limits to bound this parameter to
|
:param lower, upper: the limits to bound this parameter to
|
||||||
:param warning: print a warning if re-constraining parameters.
|
:param warning: print a warning if re-constraining parameters.
|
||||||
|
|
||||||
Constrain this parameter to lie within the given range.
|
Constrain this parameter to lie within the given range.
|
||||||
"""
|
"""
|
||||||
self.constrain(Logistic(lower, upper), warning)
|
self.constrain(Logistic(lower, upper), warning=warning, update=update)
|
||||||
|
|
||||||
def unconstrain(self, *transforms):
|
def unconstrain(self, *transforms):
|
||||||
"""
|
"""
|
||||||
:param transforms: The transformations to unconstrain from.
|
:param transforms: The transformations to unconstrain from.
|
||||||
|
|
||||||
remove all :py:class:`GPy.core.transformations.Transformation`
|
remove all :py:class:`GPy.core.transformations.Transformation`
|
||||||
transformats of this parameter object.
|
transformats of this parameter object.
|
||||||
"""
|
"""
|
||||||
if self.has_parent():
|
if self.has_parent():
|
||||||
|
|
@ -138,20 +138,20 @@ class Constrainable(Nameable):
|
||||||
|
|
||||||
def unconstrain_positive(self):
|
def unconstrain_positive(self):
|
||||||
"""
|
"""
|
||||||
Remove positive constraint of this parameter.
|
Remove positive constraint of this parameter.
|
||||||
"""
|
"""
|
||||||
self.unconstrain(Logexp())
|
self.unconstrain(Logexp())
|
||||||
|
|
||||||
def unconstrain_negative(self):
|
def unconstrain_negative(self):
|
||||||
"""
|
"""
|
||||||
Remove negative constraint of this parameter.
|
Remove negative constraint of this parameter.
|
||||||
"""
|
"""
|
||||||
self.unconstrain(NegativeLogexp())
|
self.unconstrain(NegativeLogexp())
|
||||||
|
|
||||||
def unconstrain_bounded(self, lower, upper):
|
def unconstrain_bounded(self, lower, upper):
|
||||||
"""
|
"""
|
||||||
:param lower, upper: the limits to unbound this parameter from
|
:param lower, upper: the limits to unbound this parameter from
|
||||||
|
|
||||||
Remove (lower, upper) bounded constrain from this parameter/
|
Remove (lower, upper) bounded constrain from this parameter/
|
||||||
"""
|
"""
|
||||||
self.unconstrain(Logistic(lower, upper))
|
self.unconstrain(Logistic(lower, upper))
|
||||||
|
|
|
||||||
|
|
@ -233,7 +233,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
||||||
elif not (pname in not_unique):
|
elif not (pname in not_unique):
|
||||||
self.__dict__[pname] = p
|
self.__dict__[pname] = p
|
||||||
self._added_names_.add(pname)
|
self._added_names_.add(pname)
|
||||||
|
|
||||||
def _connect_highest_parent(self, highest_parent):
|
def _connect_highest_parent(self, highest_parent):
|
||||||
self._highest_parent_ = highest_parent
|
self._highest_parent_ = highest_parent
|
||||||
if not hasattr(self, "_parameters_") or len(self._parameters_) < 1:
|
if not hasattr(self, "_parameters_") or len(self._parameters_) < 1:
|
||||||
|
|
@ -242,7 +242,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
||||||
for p in self._parameters_:
|
for p in self._parameters_:
|
||||||
p._highest_parent_ = highest_parent
|
p._highest_parent_ = highest_parent
|
||||||
p._connect_highest_parent(highest_parent)
|
p._connect_highest_parent(highest_parent)
|
||||||
|
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# Pickling operations
|
# Pickling operations
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
|
|
@ -413,10 +413,10 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# Fixing parameters:
|
# Fixing parameters:
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
def _fix(self, param, warning=True):
|
def _fix(self, param, warning=True, update=True):
|
||||||
f = self._add_constrain(param, __fixed__, warning)
|
f = self._add_constrain(param, __fixed__, warning)
|
||||||
self._set_fixed(f)
|
self._set_fixed(f)
|
||||||
def _unfix(self, param):
|
def _unfix(self, param, update=True):
|
||||||
if self._has_fixes():
|
if self._has_fixes():
|
||||||
f = self._remove_constrain(param, __fixed__)
|
f = self._remove_constrain(param, __fixed__)
|
||||||
self._set_unfixed(f)
|
self._set_unfixed(f)
|
||||||
|
|
@ -438,7 +438,8 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
||||||
# if advanced indexing is activated it happens that the array is a copy
|
# if advanced indexing is activated it happens that the array is a copy
|
||||||
# you can retrieve the original param through this method, by passing
|
# you can retrieve the original param through this method, by passing
|
||||||
# the copy here
|
# the copy here
|
||||||
return self._parameters_[param._parent_index_]
|
#return self._parameters_[param._parent_index_]
|
||||||
|
return param._direct_parent_._parameters_[param._parent_index_]
|
||||||
def hirarchy_name(self):
|
def hirarchy_name(self):
|
||||||
if self.has_parent():
|
if self.has_parent():
|
||||||
return self._direct_parent_.hirarchy_name() + adjust_name_for_printing(self.name) + "."
|
return self._direct_parent_.hirarchy_name() + adjust_name_for_printing(self.name) + "."
|
||||||
|
|
@ -452,7 +453,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
||||||
# if removing constraints before adding new is not wanted, just delete the above line!
|
# if removing constraints before adding new is not wanted, just delete the above line!
|
||||||
self.constraints.add(transform, rav_i)
|
self.constraints.add(transform, rav_i)
|
||||||
param = self._get_original(param)
|
param = self._get_original(param)
|
||||||
param._set_params(transform.initialize(param._get_params()))
|
param._set_params(transform.initialize(param._get_params()), update=False)
|
||||||
if warning and any(reconstrained):
|
if warning and any(reconstrained):
|
||||||
# if you want to print the whole params object, which was reconstrained use:
|
# if you want to print the whole params object, which was reconstrained use:
|
||||||
# m = str(param[self._backtranslate_index(param, reconstrained)])
|
# m = str(param[self._backtranslate_index(param, reconstrained)])
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue