object without args

This commit is contained in:
Max Zwiessele 2014-03-13 11:51:54 +00:00
parent e9c96632ba
commit b95cc90ffb
2 changed files with 70 additions and 70 deletions

View file

@ -94,15 +94,15 @@ class Param(OptimizationHandlable, ObservableArray):
@property @property
def _param_array_(self): def _param_array_(self):
return self return self
@property @property
def gradient(self): def gradient(self):
return self._gradient_array_[self._current_slice_] return self._gradient_array_[self._current_slice_]
@gradient.setter @gradient.setter
def gradient(self, val): def gradient(self, val):
self.gradient[:] = val self.gradient[:] = val
#=========================================================================== #===========================================================================
# Pickling operations # Pickling operations
#=========================================================================== #===========================================================================
@ -135,7 +135,7 @@ class Param(OptimizationHandlable, ObservableArray):
self._parent_index_ = state.pop() self._parent_index_ = state.pop()
self._parent_ = state.pop() self._parent_ = state.pop()
self.name = state.pop() self.name = state.pop()
def copy(self, *args): def copy(self, *args):
constr = self.constraints.copy() constr = self.constraints.copy()
priors = self.priors.copy() priors = self.priors.copy()
@ -151,13 +151,13 @@ class Param(OptimizationHandlable, ObservableArray):
# if trigger_parent: min_priority = None # if trigger_parent: min_priority = None
# else: min_priority = -numpy.inf # else: min_priority = -numpy.inf
# self.notify_observers(None, min_priority) # self.notify_observers(None, min_priority)
# #
# def _get_params(self): # def _get_params(self):
# return self.flat # return self.flat
# #
# def _collect_gradient(self, target): # def _collect_gradient(self, target):
# target += self.gradient.flat # target += self.gradient.flat
# #
# def _set_gradient(self, g): # def _set_gradient(self, g):
# self.gradient = g.reshape(self._realshape_) # self.gradient = g.reshape(self._realshape_)
@ -173,10 +173,10 @@ class Param(OptimizationHandlable, ObservableArray):
try: new_arr._current_slice_ = s; new_arr._original_ = self.base is new_arr.base try: new_arr._current_slice_ = s; new_arr._original_ = self.base is new_arr.base
except AttributeError: pass # returning 0d array or float, double etc except AttributeError: pass # returning 0d array or float, double etc
return new_arr return new_arr
def __setitem__(self, s, val): def __setitem__(self, s, val):
super(Param, self).__setitem__(s, val) super(Param, self).__setitem__(s, val)
#=========================================================================== #===========================================================================
# Index Operations: # Index Operations:
#=========================================================================== #===========================================================================
@ -195,7 +195,7 @@ class Param(OptimizationHandlable, ObservableArray):
a = self._realshape_[i] + a a = self._realshape_[i] + a
internal_offset += a * extended_realshape[i] internal_offset += a * extended_realshape[i]
return internal_offset return internal_offset
def _raveled_index(self, slice_index=None): def _raveled_index(self, slice_index=None):
# return an index array on the raveled array, which is formed by the current_slice # return an index array on the raveled array, which is formed by the current_slice
# of this object # of this object
@ -203,7 +203,7 @@ class Param(OptimizationHandlable, ObservableArray):
ind = self._indices(slice_index) ind = self._indices(slice_index)
if ind.ndim < 2: ind = ind[:, None] if ind.ndim < 2: ind = ind[:, None]
return numpy.asarray(numpy.apply_along_axis(lambda x: numpy.sum(extended_realshape * x), 1, ind), dtype=int) return numpy.asarray(numpy.apply_along_axis(lambda x: numpy.sum(extended_realshape * x), 1, ind), dtype=int)
def _expand_index(self, slice_index=None): def _expand_index(self, slice_index=None):
# this calculates the full indexing arrays from the slicing objects given by get_item for _real..._ attributes # this calculates the full indexing arrays from the slicing objects given by get_item for _real..._ attributes
# it basically translates slices to their respective index arrays and turns negative indices around # it basically translates slices to their respective index arrays and turns negative indices around
@ -245,7 +245,7 @@ class Param(OptimizationHandlable, ObservableArray):
#=========================================================================== #===========================================================================
@property @property
def _description_str(self): def _description_str(self):
if self.size <= 1: if self.size <= 1:
return [str(self.view(numpy.ndarray)[0])] return [str(self.view(numpy.ndarray)[0])]
else: return [str(self.shape)] else: return [str(self.shape)]
def parameter_names(self, add_self=False, adjust_for_printing=False): def parameter_names(self, add_self=False, adjust_for_printing=False):
@ -356,7 +356,7 @@ class ParamConcatenation(object):
self._param_sizes = [p.size for p in self.params] self._param_sizes = [p.size for p in self.params]
startstops = numpy.cumsum([0] + self._param_sizes) startstops = numpy.cumsum([0] + self._param_sizes)
self._param_slices_ = [slice(start, stop) for start,stop in zip(startstops, startstops[1:])] self._param_slices_ = [slice(start, stop) for start,stop in zip(startstops, startstops[1:])]
parents = dict() parents = dict()
for p in self.params: for p in self.params:
if p.has_parent(): if p.has_parent():
@ -396,7 +396,7 @@ class ParamConcatenation(object):
def update_all_params(self): def update_all_params(self):
for par in self.parents: for par in self.parents:
par.notify_observers(-numpy.inf) par.notify_observers(-numpy.inf)
def constrain(self, constraint, warning=True): def constrain(self, constraint, warning=True):
[param.constrain(constraint, trigger_parent=False) for param in self.params] [param.constrain(constraint, trigger_parent=False) for param in self.params]
self.update_all_params() self.update_all_params()

View file

@ -1,7 +1,7 @@
# Copyright (c) 2012, GPy authors (see AUTHORS.txt). # Copyright (c) 2012, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt) # Licensed under the BSD 3-clause license (see LICENSE.txt)
""" """
Core module for parameterization. Core module for parameterization.
This module implements all parameterization techniques, split up in modular bits. This module implements all parameterization techniques, split up in modular bits.
HierarchyError: HierarchyError:
@ -41,7 +41,7 @@ class Observable(object):
""" """
_updated = True _updated = True
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(Observable, self).__init__(*args, **kwargs) super(Observable, self).__init__()
self._observer_callables_ = [] self._observer_callables_ = []
def add_observer(self, observer, callble, priority=0): def add_observer(self, observer, callble, priority=0):
@ -61,7 +61,7 @@ class Observable(object):
def notify_observers(self, which=None, min_priority=None): def notify_observers(self, which=None, min_priority=None):
""" """
Notifies all observers. Which is the element, which kicked off this Notifies all observers. Which is the element, which kicked off this
notification loop. notification loop.
NOTE: notifies only observers with priority p > min_priority! NOTE: notifies only observers with priority p > min_priority!
@ -91,11 +91,11 @@ class Observable(object):
class Pickleable(object): class Pickleable(object):
""" """
Make an object pickleable (See python doc 'pickling'). Make an object pickleable (See python doc 'pickling').
This class allows for pickling support by Memento pattern. This class allows for pickling support by Memento pattern.
_getstate returns a memento of the class, which gets pickled. _getstate returns a memento of the class, which gets pickled.
_setstate(<memento>) (re-)sets the state of the class to the memento _setstate(<memento>) (re-)sets the state of the class to the memento
""" """
#=========================================================================== #===========================================================================
# Pickling operations # Pickling operations
@ -112,14 +112,14 @@ class Pickleable(object):
with open(f, 'w') as f: with open(f, 'w') as f:
cPickle.dump(self, f, protocol) cPickle.dump(self, f, protocol)
else: else:
cPickle.dump(self, f, protocol) cPickle.dump(self, f, protocol)
def __getstate__(self): def __getstate__(self):
if self._has_get_set_state(): if self._has_get_set_state():
return self._getstate() return self._getstate()
return self.__dict__ return self.__dict__
def __setstate__(self, state): def __setstate__(self, state):
if self._has_get_set_state(): if self._has_get_set_state():
self._setstate(state) self._setstate(state)
# TODO: maybe parameters_changed() here? # TODO: maybe parameters_changed() here?
return return
self.__dict__ = state self.__dict__ = state
@ -160,7 +160,7 @@ class Parentable(object):
_parent_ = None _parent_ = None
_parent_index_ = None _parent_index_ = None
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(Parentable, self).__init__(*args, **kwargs) super(Parentable, self).__init__()
def has_parent(self): def has_parent(self):
""" """
@ -201,18 +201,18 @@ class Gradcheckable(Parentable):
Adds the functionality for an object to be gradcheckable. Adds the functionality for an object to be gradcheckable.
It is just a thin wrapper of a call to the highest parent for now. It is just a thin wrapper of a call to the highest parent for now.
TODO: Can be done better, by only changing parameters of the current parameter handle, TODO: Can be done better, by only changing parameters of the current parameter handle,
such that object hierarchy only has to change for those. such that object hierarchy only has to change for those.
""" """
def __init__(self, *a, **kw): def __init__(self, *a, **kw):
super(Gradcheckable, self).__init__(*a, **kw) super(Gradcheckable, self).__init__(*a, **kw)
def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3, _debug=False): def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3, _debug=False):
""" """
Check the gradient of this parameter with respect to the highest parent's Check the gradient of this parameter with respect to the highest parent's
objective function. objective function.
This is a three point estimate of the gradient, wiggling at the parameters This is a three point estimate of the gradient, wiggling at the parameters
with a stepsize step. with a stepsize step.
The check passes if either the ratio or the difference between numerical and The check passes if either the ratio or the difference between numerical and
analytical gradient is smaller then tolerance. analytical gradient is smaller then tolerance.
:param bool verbose: whether each parameter shall be checked individually. :param bool verbose: whether each parameter shall be checked individually.
@ -275,22 +275,22 @@ class Indexable(object):
The raveled index of an object is the index for its parameters in a flattened int array. The raveled index of an object is the index for its parameters in a flattened int array.
""" """
def __init__(self, *a, **kw): def __init__(self, *a, **kw):
super(Indexable, self).__init__(*a, **kw) super(Indexable, self).__init__()
def _raveled_index(self): def _raveled_index(self):
""" """
Flattened array of ints, specifying the index of this object. Flattened array of ints, specifying the index of this object.
This has to account for shaped parameters! This has to account for shaped parameters!
""" """
raise NotImplementedError, "Need to be able to get the raveled Index" raise NotImplementedError, "Need to be able to get the raveled Index"
def _internal_offset(self): def _internal_offset(self):
""" """
The offset for this parameter inside its parent. The offset for this parameter inside its parent.
This has to account for shaped parameters! This has to account for shaped parameters!
""" """
return 0 return 0
def _offset_for(self, param): def _offset_for(self, param):
""" """
Return the offset of the param inside this parameterized object. Return the offset of the param inside this parameterized object.
@ -298,15 +298,15 @@ class Indexable(object):
basically just sums up the parameter sizes which come before param. basically just sums up the parameter sizes which come before param.
""" """
raise NotImplementedError, "shouldnt happen, offset required from non parameterization object?" raise NotImplementedError, "shouldnt happen, offset required from non parameterization object?"
def _raveled_index_for(self, param): def _raveled_index_for(self, param):
""" """
get the raveled index for a param get the raveled index for a param
that is an int array, containing the indexes for the flattened that is an int array, containing the indexes for the flattened
param inside this parameterized logic. param inside this parameterized logic.
""" """
raise NotImplementedError, "shouldnt happen, raveld index transformation required from non parameterization object?" raise NotImplementedError, "shouldnt happen, raveld index transformation required from non parameterization object?"
class Constrainable(Nameable, Indexable): class Constrainable(Nameable, Indexable):
""" """
@ -315,7 +315,7 @@ class Constrainable(Nameable, Indexable):
Adding a constraint to a Parameter means to tell the highest parent that Adding a constraint to a Parameter means to tell the highest parent that
the constraint was added and making sure that all parameters covered the constraint was added and making sure that all parameters covered
by this object are indeed conforming to the constraint. by this object are indeed conforming to the constraint.
:func:`constrain()` and :func:`unconstrain()` are main methods here :func:`constrain()` and :func:`unconstrain()` are main methods here
""" """
def __init__(self, name, default_constraint=None, *a, **kw): def __init__(self, name, default_constraint=None, *a, **kw):
@ -326,7 +326,7 @@ class Constrainable(Nameable, Indexable):
self.priors = ParameterIndexOperations() self.priors = ParameterIndexOperations()
if self._default_constraint_ is not None: if self._default_constraint_ is not None:
self.constrain(self._default_constraint_) self.constrain(self._default_constraint_)
def _disconnect_parent(self, constr=None, *args, **kw): def _disconnect_parent(self, constr=None, *args, **kw):
""" """
From Parentable: From Parentable:
@ -340,7 +340,7 @@ class Constrainable(Nameable, Indexable):
self._parent_index_ = None self._parent_index_ = None
self._connect_fixes() self._connect_fixes()
self._notify_parent_change() self._notify_parent_change()
#=========================================================================== #===========================================================================
# Fixing Parameters: # Fixing Parameters:
#=========================================================================== #===========================================================================
@ -356,20 +356,20 @@ class Constrainable(Nameable, Indexable):
rav_i = self._highest_parent_._raveled_index_for(self) rav_i = self._highest_parent_._raveled_index_for(self)
self._highest_parent_._set_fixed(rav_i) self._highest_parent_._set_fixed(rav_i)
fix = constrain_fixed fix = constrain_fixed
def unconstrain_fixed(self): def unconstrain_fixed(self):
""" """
This parameter will no longer be fixed. This parameter will no longer be fixed.
""" """
unconstrained = self.unconstrain(__fixed__) unconstrained = self.unconstrain(__fixed__)
self._highest_parent_._set_unfixed(unconstrained) self._highest_parent_._set_unfixed(unconstrained)
unfix = unconstrain_fixed unfix = unconstrain_fixed
def _set_fixed(self, index): def _set_fixed(self, index):
if not self._has_fixes(): self._fixes_ = np.ones(self.size, dtype=bool) if not self._has_fixes(): self._fixes_ = np.ones(self.size, dtype=bool)
self._fixes_[index] = FIXED self._fixes_[index] = FIXED
if np.all(self._fixes_): self._fixes_ = None # ==UNFIXED if np.all(self._fixes_): self._fixes_ = None # ==UNFIXED
def _set_unfixed(self, index): def _set_unfixed(self, index):
if not self._has_fixes(): self._fixes_ = np.ones(self.size, dtype=bool) if not self._has_fixes(): self._fixes_ = np.ones(self.size, dtype=bool)
# rav_i = self._raveled_index_for(param)[index] # rav_i = self._raveled_index_for(param)[index]
@ -383,7 +383,7 @@ class Constrainable(Nameable, Indexable):
self._fixes_[fixed_indices] = FIXED self._fixes_[fixed_indices] = FIXED
else: else:
self._fixes_ = None self._fixes_ = None
def _has_fixes(self): def _has_fixes(self):
return hasattr(self, "_fixes_") and self._fixes_ is not None return hasattr(self, "_fixes_") and self._fixes_ is not None
@ -398,21 +398,21 @@ class Constrainable(Nameable, Indexable):
""" """
repriorized = self.unset_priors() repriorized = self.unset_priors()
self._add_to_index_operations(self.priors, repriorized, prior, warning) self._add_to_index_operations(self.priors, repriorized, prior, warning)
def unset_priors(self, *priors): def unset_priors(self, *priors):
""" """
Un-set all priors given from this parameter handle. Un-set all priors given from this parameter handle.
""" """
return self._remove_from_index_operations(self.priors, priors) return self._remove_from_index_operations(self.priors, priors)
def log_prior(self): def log_prior(self):
"""evaluate the prior""" """evaluate the prior"""
if self.priors.size > 0: if self.priors.size > 0:
x = self._get_params() x = self._get_params()
return reduce(lambda a, b: a + b, [p.lnpdf(x[ind]).sum() for p, ind in self.priors.iteritems()], 0) return reduce(lambda a, b: a + b, [p.lnpdf(x[ind]).sum() for p, ind in self.priors.iteritems()], 0)
return 0. return 0.
def _log_prior_gradients(self): def _log_prior_gradients(self):
"""evaluate the gradients of the priors""" """evaluate the gradients of the priors"""
if self.priors.size > 0: if self.priors.size > 0:
@ -421,7 +421,7 @@ class Constrainable(Nameable, Indexable):
[np.put(ret, ind, p.lnpdf_grad(x[ind])) for p, ind in self.priors.iteritems()] [np.put(ret, ind, p.lnpdf_grad(x[ind])) for p, ind in self.priors.iteritems()]
return ret return ret
return 0. return 0.
#=========================================================================== #===========================================================================
# Constrain operations -> done # Constrain operations -> done
#=========================================================================== #===========================================================================
@ -448,7 +448,7 @@ class Constrainable(Nameable, Indexable):
transformats of this parameter object. transformats of this parameter object.
""" """
return self._remove_from_index_operations(self.constraints, transforms) return self._remove_from_index_operations(self.constraints, transforms)
def constrain_positive(self, warning=True, trigger_parent=True): def constrain_positive(self, warning=True, trigger_parent=True):
""" """
:param warning: print a warning if re-constraining parameters. :param warning: print a warning if re-constraining parameters.
@ -493,7 +493,7 @@ class Constrainable(Nameable, Indexable):
Remove (lower, upper) bounded constrain from this parameter/ Remove (lower, upper) bounded constrain from this parameter/
""" """
self.unconstrain(Logistic(lower, upper)) self.unconstrain(Logistic(lower, upper))
def _parent_changed(self, parent): def _parent_changed(self, parent):
""" """
From Parentable: From Parentable:
@ -522,7 +522,7 @@ class Constrainable(Nameable, Indexable):
def _remove_from_index_operations(self, which, what): def _remove_from_index_operations(self, which, what):
""" """
Helper preventing copy code. Helper preventing copy code.
Remove given what (transform prior etc) from which param index ops. Remove given what (transform prior etc) from which param index ops.
""" """
if len(what) == 0: if len(what) == 0:
transforms = which.properties() transforms = which.properties()
@ -532,7 +532,7 @@ class Constrainable(Nameable, Indexable):
removed = np.union1d(removed, unconstrained) removed = np.union1d(removed, unconstrained)
if t is __fixed__: if t is __fixed__:
self._highest_parent_._set_unfixed(unconstrained) self._highest_parent_._set_unfixed(unconstrained)
return removed return removed
class OptimizationHandlable(Constrainable, Observable): class OptimizationHandlable(Constrainable, Observable):
@ -543,13 +543,13 @@ class OptimizationHandlable(Constrainable, Observable):
""" """
def __init__(self, name, default_constraint=None, *a, **kw): def __init__(self, name, default_constraint=None, *a, **kw):
super(OptimizationHandlable, self).__init__(name, default_constraint=default_constraint, *a, **kw) super(OptimizationHandlable, self).__init__(name, default_constraint=default_constraint, *a, **kw)
def transform(self): def transform(self):
[np.put(self._param_array_, ind, c.finv(self._param_array_[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__] [np.put(self._param_array_, ind, c.finv(self._param_array_[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__]
def untransform(self): def untransform(self):
[np.put(self._param_array_, ind, c.f(self._param_array_[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__] [np.put(self._param_array_, ind, c.f(self._param_array_[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__]
def _get_params_transformed(self): def _get_params_transformed(self):
# transformed parameters (apply transformation rules) # transformed parameters (apply transformation rules)
p = self._param_array_.copy() p = self._param_array_.copy()
@ -565,23 +565,23 @@ class OptimizationHandlable(Constrainable, Observable):
else: self._param_array_[:] = p else: self._param_array_[:] = p
self.untransform() self.untransform()
self._trigger_params_changed() self._trigger_params_changed()
def _trigger_params_changed(self, trigger_parent=True): def _trigger_params_changed(self, trigger_parent=True):
[p._trigger_params_changed(trigger_parent=False) for p in self._parameters_] [p._trigger_params_changed(trigger_parent=False) for p in self._parameters_]
if trigger_parent: min_priority = None if trigger_parent: min_priority = None
else: min_priority = -np.inf else: min_priority = -np.inf
self.notify_observers(None, min_priority) self.notify_observers(None, min_priority)
def _size_transformed(self): def _size_transformed(self):
return self.size - self.constraints[__fixed__].size return self.size - self.constraints[__fixed__].size
# #
# def _untransform_params(self, p): # def _untransform_params(self, p):
# # inverse apply transformations for parameters # # inverse apply transformations for parameters
# #p = p.copy() # #p = p.copy()
# if self._has_fixes(): tmp = self._get_params(); tmp[self._fixes_] = p; p = tmp; del tmp # if self._has_fixes(): tmp = self._get_params(); tmp[self._fixes_] = p; p = tmp; del tmp
# [np.put(p, ind, c.f(p[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__] # [np.put(p, ind, c.f(p[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__]
# return p # return p
# #
# def _get_params(self): # def _get_params(self):
# """ # """
# get all parameters # get all parameters
@ -592,7 +592,7 @@ class OptimizationHandlable(Constrainable, Observable):
# return p # return p
# [np.put(p, ind, par._get_params()) for ind, par in itertools.izip(self._param)] # [np.put(p, ind, par._get_params()) for ind, par in itertools.izip(self._param)]
# return p # return p
# def _set_params(self, params, trigger_parent=True): # def _set_params(self, params, trigger_parent=True):
# self._param_array_.flat = params # self._param_array_.flat = params
# if trigger_parent: min_priority = None # if trigger_parent: min_priority = None
@ -600,14 +600,14 @@ class OptimizationHandlable(Constrainable, Observable):
# self.notify_observers(None, min_priority) # self.notify_observers(None, min_priority)
# don't overwrite this anymore! # don't overwrite this anymore!
#raise NotImplementedError, "Abstract superclass: This needs to be implemented in Param and Parameterizable" #raise NotImplementedError, "Abstract superclass: This needs to be implemented in Param and Parameterizable"
#=========================================================================== #===========================================================================
# Optimization handles: # Optimization handles:
#=========================================================================== #===========================================================================
def _get_param_names(self): def _get_param_names(self):
n = np.array([p.hierarchy_name() + '[' + str(i) + ']' for p in self.flattened_parameters for i in p._indices()]) n = np.array([p.hierarchy_name() + '[' + str(i) + ']' for p in self.flattened_parameters for i in p._indices()])
return n return n
def _get_param_names_transformed(self): def _get_param_names_transformed(self):
n = self._get_param_names() n = self._get_param_names()
if self._has_fixes(): if self._has_fixes():
@ -621,7 +621,7 @@ class OptimizationHandlable(Constrainable, Observable):
""" """
Randomize the model. Randomize the model.
Make this draw from the prior if one exists, else draw from given random generator Make this draw from the prior if one exists, else draw from given random generator
:param rand_gen: numpy random number generator which takes args and kwargs :param rand_gen: numpy random number generator which takes args and kwargs
:param flaot loc: loc parameter for random number generator :param flaot loc: loc parameter for random number generator
:param float scale: scale parameter for random number generator :param float scale: scale parameter for random number generator
@ -663,7 +663,7 @@ class Parameterizable(OptimizationHandlable):
def parameter_names(self, add_self=False, adjust_for_printing=False, recursive=True): def parameter_names(self, add_self=False, adjust_for_printing=False, recursive=True):
""" """
Get the names of all parameters of this model. Get the names of all parameters of this model.
:param bool add_self: whether to add the own name in front of names :param bool add_self: whether to add the own name in front of names
:param bool adjust_for_printing: whether to call `adjust_name_for_printing` on names :param bool adjust_for_printing: whether to call `adjust_name_for_printing` on names
@ -712,7 +712,7 @@ class Parameterizable(OptimizationHandlable):
#========================================================================= #=========================================================================
@property @property
def gradient(self): def gradient(self):
return self._gradient_array_ return self._gradient_array_
@gradient.setter @gradient.setter
def gradient(self, val): def gradient(self, val):
@ -821,8 +821,8 @@ class Parameterizable(OptimizationHandlable):
# connect parameterlist to this parameterized object # connect parameterlist to this parameterized object
# This just sets up the right connection for the params objects # This just sets up the right connection for the params objects
# to be used as parameters # to be used as parameters
# it also sets the constraints for each parameter to the constraints # it also sets the constraints for each parameter to the constraints
# of their respective parents # of their respective parents
if not hasattr(self, "_parameters_") or len(self._parameters_) < 1: if not hasattr(self, "_parameters_") or len(self._parameters_) < 1:
# no parameters for this class # no parameters for this class
return return
@ -837,7 +837,7 @@ class Parameterizable(OptimizationHandlable):
pslice = slice(old_size, old_size+p.size) pslice = slice(old_size, old_size+p.size)
# first connect all children # first connect all children
p._propagate_param_grad(self._param_array_[pslice], self._gradient_array_[pslice]) p._propagate_param_grad(self._param_array_[pslice], self._gradient_array_[pslice])
# then connect children to self # then connect children to self
self._param_array_[pslice] = p._param_array_.ravel()#, requirements=['C', 'W']).ravel(order='C') self._param_array_[pslice] = p._param_array_.ravel()#, requirements=['C', 'W']).ravel(order='C')
self._gradient_array_[pslice] = p._gradient_array_.ravel()#, requirements=['C', 'W']).ravel(order='C') self._gradient_array_[pslice] = p._gradient_array_.ravel()#, requirements=['C', 'W']).ravel(order='C')
@ -879,7 +879,7 @@ class Parameterizable(OptimizationHandlable):
dc[k] = copy.deepcopy(v) dc[k] = copy.deepcopy(v)
if k == '_parameters_': if k == '_parameters_':
params = [p.copy() for p in v] params = [p.copy() for p in v]
dc['_parent_'] = None dc['_parent_'] = None
dc['_parent_index_'] = None dc['_parent_index_'] = None
dc['_observer_callables_'] = [] dc['_observer_callables_'] = []
@ -890,12 +890,12 @@ class Parameterizable(OptimizationHandlable):
s = self.__new__(self.__class__) s = self.__new__(self.__class__)
s.__dict__ = dc s.__dict__ = dc
for p in params: for p in params:
s.add_parameter(p, _ignore_added_names=True) s.add_parameter(p, _ignore_added_names=True)
return s return s
#=========================================================================== #===========================================================================
# From being parentable, we have to define the parent_change notification # From being parentable, we have to define the parent_change notification
#=========================================================================== #===========================================================================