Fixed parameterized oddity where it was updating all constrained parameters as soon as any were constrained rather than after all are constrained@

@
This commit is contained in:
Alan Saul 2014-02-06 16:22:08 +00:00
parent 1c9151a7d0
commit b12fb6a2a8
3 changed files with 87 additions and 69 deletions

View file

@ -12,32 +12,32 @@ __index_name__ = "Index"
__tie_name__ = "Tied to" __tie_name__ = "Tied to"
__precision__ = numpy.get_printoptions()['precision'] # numpy printing precision used, sublassing numpy ndarray after all __precision__ = numpy.get_printoptions()['precision'] # numpy printing precision used, sublassing numpy ndarray after all
__print_threshold__ = 5 __print_threshold__ = 5
###### ######
class Float(numpy.float64, Constrainable): class Float(numpy.float64, Constrainable):
def __init__(self, f, base): def __init__(self, f, base):
super(Float,self).__init__(f) super(Float,self).__init__(f)
self._base = base self._base = base
class Param(ObservableArray, Constrainable): class Param(ObservableArray, Constrainable):
""" """
Parameter object for GPy models. Parameter object for GPy models.
:param name: name of the parameter to be printed :param name: name of the parameter to be printed
:param input_array: array which this parameter handles :param input_array: array which this parameter handles
You can add/remove constraints by calling constrain on the parameter itself, e.g: You can add/remove constraints by calling constrain on the parameter itself, e.g:
- self[:,1].constrain_positive() - self[:,1].constrain_positive()
- self[0].tie_to(other) - self[0].tie_to(other)
- self.untie() - self.untie()
- self[:3,:].unconstrain() - self[:3,:].unconstrain()
- self[1].fix() - self[1].fix()
Fixing parameters will fix them to the value they are right now. If you change Fixing parameters will fix them to the value they are right now. If you change
the fixed value, it will be fixed to the new value! the fixed value, it will be fixed to the new value!
See :py:class:`GPy.core.parameterized.Parameterized` for more details on constraining etc. See :py:class:`GPy.core.parameterized.Parameterized` for more details on constraining etc.
This ndarray can be stored in lists and checked if it is in. This ndarray can be stored in lists and checked if it is in.
@ -46,7 +46,7 @@ class Param(ObservableArray, Constrainable):
>>> x = np.random.normal(size=(10,3)) >>> x = np.random.normal(size=(10,3))
>>> x in [[1], x, [3]] >>> x in [[1], x, [3]]
True True
WARNING: This overrides the functionality of x==y!!! WARNING: This overrides the functionality of x==y!!!
Use numpy.equal(x,y) for element-wise equality testing. Use numpy.equal(x,y) for element-wise equality testing.
""" """
@ -68,7 +68,7 @@ class Param(ObservableArray, Constrainable):
def __init__(self, name, input_array): def __init__(self, name, input_array):
super(Param, self).__init__(name=name) super(Param, self).__init__(name=name)
def __array_finalize__(self, obj): def __array_finalize__(self, obj):
# see InfoArray.__array_finalize__ for comments # see InfoArray.__array_finalize__ for comments
if obj is None: return if obj is None: return
@ -86,7 +86,7 @@ class Param(ObservableArray, Constrainable):
self._original_ = getattr(obj, '_original_', None) self._original_ = getattr(obj, '_original_', None)
self._name = getattr(obj, 'name', None) self._name = getattr(obj, 'name', None)
self.gradient = getattr(obj, 'gradient', None) self.gradient = getattr(obj, 'gradient', None)
def __array_wrap__(self, out_arr, context=None): def __array_wrap__(self, out_arr, context=None):
return out_arr.view(numpy.ndarray) return out_arr.view(numpy.ndarray)
#=========================================================================== #===========================================================================
@ -94,7 +94,7 @@ class Param(ObservableArray, Constrainable):
#=========================================================================== #===========================================================================
def __reduce_ex__(self): def __reduce_ex__(self):
func, args, state = super(Param, self).__reduce__() func, args, state = super(Param, self).__reduce__()
return func, args, (state, return func, args, (state,
(self.name, (self.name,
self._direct_parent_, self._direct_parent_,
self._parent_index_, self._parent_index_,
@ -132,13 +132,13 @@ class Param(ObservableArray, Constrainable):
self.flat = param self.flat = param
self._notify_tied_parameters() self._notify_tied_parameters()
self._notify_observers() self._notify_observers()
def _get_params(self): def _get_params(self):
return self.flat return self.flat
# @property # @property
# def name(self): # def name(self):
# """ # """
# Name of this parameter. # Name of this parameter.
# This can be a callable without parameters. The callable will be called # This can be a callable without parameters. The callable will be called
# every time the name property is accessed. # every time the name property is accessed.
# """ # """
@ -163,7 +163,7 @@ class Param(ObservableArray, Constrainable):
def constrain_fixed(self, warning=True): def constrain_fixed(self, warning=True):
""" """
Constrain this paramter to be fixed to the current value it carries. Constrain this paramter to be fixed to the current value it carries.
:param warning: print a warning for overwriting constraints. :param warning: print a warning for overwriting constraints.
""" """
self._highest_parent_._fix(self,warning) self._highest_parent_._fix(self,warning)
@ -179,18 +179,18 @@ class Param(ObservableArray, Constrainable):
#=========================================================================== #===========================================================================
def tie_to(self, param): def tie_to(self, param):
""" """
:param param: the parameter object to tie this parameter to. :param param: the parameter object to tie this parameter to.
Can be ParamConcatenation (retrieved by regexp search) Can be ParamConcatenation (retrieved by regexp search)
Tie this parameter to the given parameter. Tie this parameter to the given parameter.
Broadcasting is not allowed, but you can tie a whole dimension to Broadcasting is not allowed, but you can tie a whole dimension to
one parameter: self[:,0].tie_to(other), where other is a one-value one parameter: self[:,0].tie_to(other), where other is a one-value
parameter. parameter.
Note: For now only one parameter can have ties, so all of a parameter Note: For now only one parameter can have ties, so all of a parameter
will be removed, when re-tieing! will be removed, when re-tieing!
""" """
#Note: this method will tie to the parameter which is the last in #Note: this method will tie to the parameter which is the last in
# the chain of ties. Thus, if you tie to a tied parameter, # the chain of ties. Thus, if you tie to a tied parameter,
# this tie will be created to the parameter the param is tied # this tie will be created to the parameter the param is tied
# to. # to.
@ -200,12 +200,12 @@ class Param(ObservableArray, Constrainable):
if param.size != 1: if param.size != 1:
raise NotImplementedError, "Broadcast tying is not implemented yet" raise NotImplementedError, "Broadcast tying is not implemented yet"
try: try:
if self._original_: if self._original_:
self[:] = param self[:] = param
else: # this happens when indexing created a copy of the array else: # this happens when indexing created a copy of the array
self._direct_parent_._get_original(self)[self._current_slice_] = param self._direct_parent_._get_original(self)[self._current_slice_] = param
except ValueError: except ValueError:
raise ValueError("Trying to tie {} with shape {} to {} with shape {}".format(self.name, self.shape, param.name, param.shape)) raise ValueError("Trying to tie {} with shape {} to {} with shape {}".format(self.name, self.shape, param.name, param.shape))
if param is self: if param is self:
raise RuntimeError, 'Cyclic tieing is not allowed' raise RuntimeError, 'Cyclic tieing is not allowed'
# if len(param._tied_to_) > 0: # if len(param._tied_to_) > 0:
@ -293,7 +293,7 @@ class Param(ObservableArray, Constrainable):
def unset_prior(self, *priors): def unset_prior(self, *priors):
""" """
:param priors: priors to remove from this parameter :param priors: priors to remove from this parameter
Remove all priors from this parameter Remove all priors from this parameter
""" """
self._highest_parent_._remove_prior(self, *priors) self._highest_parent_._remove_prior(self, *priors)
@ -324,7 +324,7 @@ class Param(ObservableArray, Constrainable):
if numpy.all(si == Ellipsis): if numpy.all(si == Ellipsis):
continue continue
if isinstance(si, slice): if isinstance(si, slice):
a = si.indices(self._realshape_[i])[0] a = si.indices(self._realshape_[i])[0]
elif isinstance(si, (list,numpy.ndarray,tuple)): elif isinstance(si, (list,numpy.ndarray,tuple)):
a = si[0] a = si[0]
else: a = si else: a = si
@ -424,7 +424,7 @@ class Param(ObservableArray, Constrainable):
slice_index = self._current_slice_ slice_index = self._current_slice_
if isinstance(slice_index, (tuple, list)): if isinstance(slice_index, (tuple, list)):
clean_curr_slice = [s for s in slice_index if numpy.any(s != Ellipsis)] clean_curr_slice = [s for s in slice_index if numpy.any(s != Ellipsis)]
if (all(isinstance(n, (numpy.ndarray, list, tuple)) for n in clean_curr_slice) if (all(isinstance(n, (numpy.ndarray, list, tuple)) for n in clean_curr_slice)
and len(set(map(len,clean_curr_slice))) <= 1): and len(set(map(len,clean_curr_slice))) <= 1):
return numpy.fromiter(itertools.izip(*clean_curr_slice), return numpy.fromiter(itertools.izip(*clean_curr_slice),
dtype=[('',int)]*self._realndim_,count=len(clean_curr_slice[0])).view((int, self._realndim_)) dtype=[('',int)]*self._realndim_,count=len(clean_curr_slice[0])).view((int, self._realndim_))
@ -443,7 +443,7 @@ class Param(ObservableArray, Constrainable):
if self._realsize_ < 2: if self._realsize_ < 2:
return name return name
ind = self._indices() ind = self._indices()
if ind.size > 4: indstr = ','.join(map(str,ind[:2])) + "..." + ','.join(map(str,ind[-2:])) if ind.size > 4: indstr = ','.join(map(str,ind[:2])) + "..." + ','.join(map(str,ind[-2:]))
else: indstr = ','.join(map(str,ind)) else: indstr = ','.join(map(str,ind))
return name+'['+indstr+']' return name+'['+indstr+']'
def __str__(self, constr_matrix=None, indices=None, ties=None, lc=None, lx=None, li=None, lt=None): def __str__(self, constr_matrix=None, indices=None, ties=None, lc=None, lx=None, li=None, lt=None):
@ -477,7 +477,7 @@ class ParamConcatenation(object):
for p in params: for p in params:
for p in p.flattened_parameters: for p in p.flattened_parameters:
if p not in self.params: if p not in self.params:
self.params.append(p) self.params.append(p)
self._param_sizes = [p.size for p in self.params] self._param_sizes = [p.size for p in self.params]
startstops = numpy.cumsum([0] + self._param_sizes) startstops = numpy.cumsum([0] + self._param_sizes)
self._param_slices_ = [slice(start, stop) for start,stop in zip(startstops, startstops[1:])] self._param_slices_ = [slice(start, stop) for start,stop in zip(startstops, startstops[1:])]
@ -485,15 +485,15 @@ class ParamConcatenation(object):
# Get/set items, enable broadcasting # Get/set items, enable broadcasting
#=========================================================================== #===========================================================================
def __getitem__(self, s): def __getitem__(self, s):
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True; ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
params = [p._get_params()[ind[ps]] for p,ps in zip(self.params, self._param_slices_) if numpy.any(p._get_params()[ind[ps]])] params = [p._get_params()[ind[ps]] for p,ps in zip(self.params, self._param_slices_) if numpy.any(p._get_params()[ind[ps]])]
if len(params)==1: return params[0] if len(params)==1: return params[0]
return ParamConcatenation(params) return ParamConcatenation(params)
def __setitem__(self, s, val, update=True): def __setitem__(self, s, val, update=True):
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True; ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
vals = self._vals(); vals[s] = val; del val vals = self._vals(); vals[s] = val; del val
[numpy.place(p, ind[ps], vals[ps]) and p._notify_tied_parameters() [numpy.place(p, ind[ps], vals[ps]) and p._notify_tied_parameters()
for p, ps in zip(self.params, self._param_slices_)] for p, ps in zip(self.params, self._param_slices_)]
if update: if update:
self.params[0]._highest_parent_.parameters_changed() self.params[0]._highest_parent_.parameters_changed()
def _vals(self): def _vals(self):
@ -501,38 +501,55 @@ class ParamConcatenation(object):
#=========================================================================== #===========================================================================
# parameter operations: # parameter operations:
#=========================================================================== #===========================================================================
def update_all_params(self):
self.params[0]._highest_parent_.parameters_changed()
def constrain(self, constraint, warning=True): def constrain(self, constraint, warning=True):
[param.constrain(constraint) for param in self.params] [param.constrain(constraint, update=False) for param in self.params]
self.update_all_params()
constrain.__doc__ = Param.constrain.__doc__ constrain.__doc__ = Param.constrain.__doc__
def constrain_positive(self, warning=True): def constrain_positive(self, warning=True):
[param.constrain_positive(warning) for param in self.params] [param.constrain_positive(warning, update=False) for param in self.params]
self.update_all_params()
constrain_positive.__doc__ = Param.constrain_positive.__doc__ constrain_positive.__doc__ = Param.constrain_positive.__doc__
def constrain_fixed(self, warning=True): def constrain_fixed(self, warning=True):
[param.constrain_fixed(warning) for param in self.params] [param.constrain_fixed(warning) for param in self.params]
constrain_fixed.__doc__ = Param.constrain_fixed.__doc__ constrain_fixed.__doc__ = Param.constrain_fixed.__doc__
fix = constrain_fixed fix = constrain_fixed
def constrain_negative(self, warning=True): def constrain_negative(self, warning=True):
[param.constrain_negative(warning) for param in self.params] [param.constrain_negative(warning, update=False) for param in self.params]
self.update_all_params()
constrain_negative.__doc__ = Param.constrain_negative.__doc__ constrain_negative.__doc__ = Param.constrain_negative.__doc__
def constrain_bounded(self, lower, upper, warning=True): def constrain_bounded(self, lower, upper, warning=True):
[param.constrain_bounded(lower, upper, warning) for param in self.params] [param.constrain_bounded(lower, upper, warning, update=False) for param in self.params]
self.update_all_params()
constrain_bounded.__doc__ = Param.constrain_bounded.__doc__ constrain_bounded.__doc__ = Param.constrain_bounded.__doc__
def unconstrain(self, *constraints): def unconstrain(self, *constraints):
[param.unconstrain(*constraints) for param in self.params] [param.unconstrain(*constraints) for param in self.params]
unconstrain.__doc__ = Param.unconstrain.__doc__ unconstrain.__doc__ = Param.unconstrain.__doc__
def unconstrain_negative(self): def unconstrain_negative(self):
[param.unconstrain_negative() for param in self.params] [param.unconstrain_negative() for param in self.params]
unconstrain_negative.__doc__ = Param.unconstrain_negative.__doc__ unconstrain_negative.__doc__ = Param.unconstrain_negative.__doc__
def unconstrain_positive(self): def unconstrain_positive(self):
[param.unconstrain_positive() for param in self.params] [param.unconstrain_positive() for param in self.params]
unconstrain_positive.__doc__ = Param.unconstrain_positive.__doc__ unconstrain_positive.__doc__ = Param.unconstrain_positive.__doc__
def unconstrain_fixed(self): def unconstrain_fixed(self):
[param.unconstrain_fixed() for param in self.params] [param.unconstrain_fixed() for param in self.params]
unconstrain_fixed.__doc__ = Param.unconstrain_fixed.__doc__ unconstrain_fixed.__doc__ = Param.unconstrain_fixed.__doc__
unfix = unconstrain_fixed unfix = unconstrain_fixed
def unconstrain_bounded(self, lower, upper): def unconstrain_bounded(self, lower, upper):
[param.unconstrain_bounded(lower, upper) for param in self.params] [param.unconstrain_bounded(lower, upper) for param in self.params]
unconstrain_bounded.__doc__ = Param.unconstrain_bounded.__doc__ unconstrain_bounded.__doc__ = Param.unconstrain_bounded.__doc__
def untie(self, *ties): def untie(self, *ties):
[param.untie(*ties) for param in self.params] [param.untie(*ties) for param in self.params]
__lt__ = lambda self, val: self._vals()<val __lt__ = lambda self, val: self._vals()<val
@ -557,9 +574,9 @@ class ParamConcatenation(object):
return "\n{}\n".format(" -"+"- | -".join(['-'*l for l in [li,lx,lc,lt]])).join(strings) return "\n{}\n".format(" -"+"- | -".join(['-'*l for l in [li,lx,lc,lt]])).join(strings)
def __repr__(self): def __repr__(self):
return "\n".join(map(repr,self.params)) return "\n".join(map(repr,self.params))
if __name__ == '__main__': if __name__ == '__main__':
from GPy.core.parameterized import Parameterized from GPy.core.parameterized import Parameterized
from GPy.core.parameter import Param from GPy.core.parameter import Param
@ -570,16 +587,16 @@ if __name__ == '__main__':
p = Param("q_mean", X) p = Param("q_mean", X)
p1 = Param("q_variance", numpy.random.rand(*p.shape)) p1 = Param("q_variance", numpy.random.rand(*p.shape))
p2 = Param("Y", numpy.random.randn(p.shape[0],1)) p2 = Param("Y", numpy.random.randn(p.shape[0],1))
p3 = Param("variance", numpy.random.rand()) p3 = Param("variance", numpy.random.rand())
p4 = Param("lengthscale", numpy.random.rand(2)) p4 = Param("lengthscale", numpy.random.rand(2))
m = Parameterized() m = Parameterized()
rbf = Parameterized(name='rbf') rbf = Parameterized(name='rbf')
rbf.add_parameter(p3,p4) rbf.add_parameter(p3,p4)
m.add_parameter(p,p1,rbf) m.add_parameter(p,p1,rbf)
print "setting params" print "setting params"
#print m.q_v[3:5,[1,4,5]] #print m.q_v[3:5,[1,4,5]]
print "constraining variance" print "constraining variance"

View file

@ -20,7 +20,7 @@ class Observable(object):
def _notify_observers(self): def _notify_observers(self):
[callble(self) for callble in self._observers_.itervalues()] [callble(self) for callble in self._observers_.itervalues()]
class Pickleable(object): class Pickleable(object):
def _getstate(self): def _getstate(self):
""" """
@ -36,9 +36,9 @@ class Pickleable(object):
Set the state (memento pattern) of this class to the given state. Set the state (memento pattern) of this class to the given state.
Usually this is just the counterpart to _getstate, such that Usually this is just the counterpart to _getstate, such that
an object is a copy of another when calling an object is a copy of another when calling
copy = <classname>.__new__(*args,**kw)._setstate(<to_be_copied>._getstate()) copy = <classname>.__new__(*args,**kw)._setstate(<to_be_copied>._getstate())
See python doc "pickling" (`__getstate__` and `__setstate__`) for details. See python doc "pickling" (`__getstate__` and `__setstate__`) for details.
""" """
raise NotImplementedError, "To be able to use pickling you need to implement this method" raise NotImplementedError, "To be able to use pickling you need to implement this method"
@ -49,14 +49,14 @@ class Pickleable(object):
class Parentable(object): class Parentable(object):
def __init__(self, direct_parent=None, highest_parent=None, parent_index=None): def __init__(self, direct_parent=None, highest_parent=None, parent_index=None):
super(Parentable,self).__init__() super(Parentable,self).__init__()
self._direct_parent_ = direct_parent self._direct_parent_ = direct_parent
self._parent_index_ = parent_index self._parent_index_ = parent_index
self._highest_parent_ = highest_parent self._highest_parent_ = highest_parent
def has_parent(self): def has_parent(self):
return self._direct_parent_ is not None and self._highest_parent_ is not None return self._direct_parent_ is not None and self._highest_parent_ is not None
class Nameable(Parentable): class Nameable(Parentable):
_name = None _name = None
def __init__(self, name, direct_parent=None, highest_parent=None, parent_index=None): def __init__(self, name, direct_parent=None, highest_parent=None, parent_index=None):
@ -69,10 +69,10 @@ class Nameable(Parentable):
@name.setter @name.setter
def name(self, name): def name(self, name):
from_name = self.name from_name = self.name
self._name = name self._name = name
if self.has_parent(): if self.has_parent():
self._direct_parent_._name_changed(self, from_name) self._direct_parent_._name_changed(self, from_name)
class Constrainable(Nameable): class Constrainable(Nameable):
def __init__(self, name): def __init__(self, name):
super(Constrainable,self).__init__(name) super(Constrainable,self).__init__(name)
@ -84,7 +84,7 @@ class Constrainable(Nameable):
:param transform: the :py:class:`GPy.core.transformations.Transformation` :param transform: the :py:class:`GPy.core.transformations.Transformation`
to constrain the this parameter to. to constrain the this parameter to.
:param warning: print a warning if re-constraining parameters. :param warning: print a warning if re-constraining parameters.
Constrain the parameter to the given Constrain the parameter to the given
:py:class:`GPy.core.transformations.Transformation`. :py:class:`GPy.core.transformations.Transformation`.
""" """
@ -97,37 +97,37 @@ class Constrainable(Nameable):
self._add_constrain(p, transform, warning) self._add_constrain(p, transform, warning)
if update: if update:
self.parameters_changed() self.parameters_changed()
def constrain_positive(self, warning=True): def constrain_positive(self, warning=True, update=True):
""" """
:param warning: print a warning if re-constraining parameters. :param warning: print a warning if re-constraining parameters.
Constrain this parameter to the default positive constraint. Constrain this parameter to the default positive constraint.
""" """
self.constrain(Logexp(), warning) self.constrain(Logexp(), warning=warning, update=update)
def constrain_negative(self, warning=True): def constrain_negative(self, warning=True, update=True):
""" """
:param warning: print a warning if re-constraining parameters. :param warning: print a warning if re-constraining parameters.
Constrain this parameter to the default negative constraint. Constrain this parameter to the default negative constraint.
""" """
self.constrain(NegativeLogexp(), warning) self.constrain(NegativeLogexp(), warning=warning, update=update)
def constrain_bounded(self, lower, upper, warning=True): def constrain_bounded(self, lower, upper, warning=True, update=True):
""" """
:param lower, upper: the limits to bound this parameter to :param lower, upper: the limits to bound this parameter to
:param warning: print a warning if re-constraining parameters. :param warning: print a warning if re-constraining parameters.
Constrain this parameter to lie within the given range. Constrain this parameter to lie within the given range.
""" """
self.constrain(Logistic(lower, upper), warning) self.constrain(Logistic(lower, upper), warning=warning, update=update)
def unconstrain(self, *transforms): def unconstrain(self, *transforms):
""" """
:param transforms: The transformations to unconstrain from. :param transforms: The transformations to unconstrain from.
remove all :py:class:`GPy.core.transformations.Transformation` remove all :py:class:`GPy.core.transformations.Transformation`
transformats of this parameter object. transformats of this parameter object.
""" """
if self.has_parent(): if self.has_parent():
@ -138,20 +138,20 @@ class Constrainable(Nameable):
def unconstrain_positive(self): def unconstrain_positive(self):
""" """
Remove positive constraint of this parameter. Remove positive constraint of this parameter.
""" """
self.unconstrain(Logexp()) self.unconstrain(Logexp())
def unconstrain_negative(self): def unconstrain_negative(self):
""" """
Remove negative constraint of this parameter. Remove negative constraint of this parameter.
""" """
self.unconstrain(NegativeLogexp()) self.unconstrain(NegativeLogexp())
def unconstrain_bounded(self, lower, upper): def unconstrain_bounded(self, lower, upper):
""" """
:param lower, upper: the limits to unbound this parameter from :param lower, upper: the limits to unbound this parameter from
Remove (lower, upper) bounded constrain from this parameter/ Remove (lower, upper) bounded constrain from this parameter/
""" """
self.unconstrain(Logistic(lower, upper)) self.unconstrain(Logistic(lower, upper))

View file

@ -233,7 +233,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
elif not (pname in not_unique): elif not (pname in not_unique):
self.__dict__[pname] = p self.__dict__[pname] = p
self._added_names_.add(pname) self._added_names_.add(pname)
def _connect_highest_parent(self, highest_parent): def _connect_highest_parent(self, highest_parent):
self._highest_parent_ = highest_parent self._highest_parent_ = highest_parent
if not hasattr(self, "_parameters_") or len(self._parameters_) < 1: if not hasattr(self, "_parameters_") or len(self._parameters_) < 1:
@ -242,7 +242,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
for p in self._parameters_: for p in self._parameters_:
p._highest_parent_ = highest_parent p._highest_parent_ = highest_parent
p._connect_highest_parent(highest_parent) p._connect_highest_parent(highest_parent)
#=========================================================================== #===========================================================================
# Pickling operations # Pickling operations
#=========================================================================== #===========================================================================
@ -413,10 +413,10 @@ class Parameterized(Constrainable, Pickleable, Observable):
#=========================================================================== #===========================================================================
# Fixing parameters: # Fixing parameters:
#=========================================================================== #===========================================================================
def _fix(self, param, warning=True): def _fix(self, param, warning=True, update=True):
f = self._add_constrain(param, __fixed__, warning) f = self._add_constrain(param, __fixed__, warning)
self._set_fixed(f) self._set_fixed(f)
def _unfix(self, param): def _unfix(self, param, update=True):
if self._has_fixes(): if self._has_fixes():
f = self._remove_constrain(param, __fixed__) f = self._remove_constrain(param, __fixed__)
self._set_unfixed(f) self._set_unfixed(f)
@ -438,7 +438,8 @@ class Parameterized(Constrainable, Pickleable, Observable):
# if advanced indexing is activated it happens that the array is a copy # if advanced indexing is activated it happens that the array is a copy
# you can retrieve the original param through this method, by passing # you can retrieve the original param through this method, by passing
# the copy here # the copy here
return self._parameters_[param._parent_index_] #return self._parameters_[param._parent_index_]
return param._direct_parent_._parameters_[param._parent_index_]
def hirarchy_name(self): def hirarchy_name(self):
if self.has_parent(): if self.has_parent():
return self._direct_parent_.hirarchy_name() + adjust_name_for_printing(self.name) + "." return self._direct_parent_.hirarchy_name() + adjust_name_for_printing(self.name) + "."
@ -452,7 +453,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
# if removing constraints before adding new is not wanted, just delete the above line! # if removing constraints before adding new is not wanted, just delete the above line!
self.constraints.add(transform, rav_i) self.constraints.add(transform, rav_i)
param = self._get_original(param) param = self._get_original(param)
param._set_params(transform.initialize(param._get_params())) param._set_params(transform.initialize(param._get_params()), update=False)
if warning and any(reconstrained): if warning and any(reconstrained):
# if you want to print the whole params object, which was reconstrained use: # if you want to print the whole params object, which was reconstrained use:
# m = str(param[self._backtranslate_index(param, reconstrained)]) # m = str(param[self._backtranslate_index(param, reconstrained)])