mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 12:02:38 +02:00
logic edits for copy
This commit is contained in:
parent
441a9f524d
commit
a98334e009
5 changed files with 144 additions and 138 deletions
|
|
@ -88,114 +88,6 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
|||
return G
|
||||
return node
|
||||
|
||||
|
||||
def add_parameter(self, param, index=None):
|
||||
"""
|
||||
:param parameters: the parameters to add
|
||||
:type parameters: list of or one :py:class:`GPy.core.param.Param`
|
||||
:param [index]: index of where to put parameters
|
||||
|
||||
|
||||
Add all parameters to this param class, you can insert parameters
|
||||
at any given index using the :func:`list.insert` syntax
|
||||
"""
|
||||
# if param.has_parent():
|
||||
# raise AttributeError, "parameter {} already in another model, create new object (or copy) for adding".format(param._short())
|
||||
if param in self._parameters_ and index is not None:
|
||||
self.remove_parameter(param)
|
||||
self.add_parameter(param, index)
|
||||
elif param not in self._parameters_:
|
||||
if param.has_parent():
|
||||
parent = param._direct_parent_
|
||||
while parent is not None:
|
||||
if parent is self:
|
||||
from GPy.core.parameterization.parameter_core import HierarchyError
|
||||
raise HierarchyError, "You cannot add a parameter twice into the hirarchy"
|
||||
parent = parent._direct_parent_
|
||||
param._direct_parent_.remove_parameter(param)
|
||||
# make sure the size is set
|
||||
if index is None:
|
||||
self.constraints.update(param.constraints, self.size)
|
||||
self.priors.update(param.priors, self.size)
|
||||
self._parameters_.append(param)
|
||||
else:
|
||||
start = sum(p.size for p in self._parameters_[:index])
|
||||
self.constraints.shift_right(start, param.size)
|
||||
self.priors.shift_right(start, param.size)
|
||||
self.constraints.update(param.constraints, start)
|
||||
self.priors.update(param.priors, start)
|
||||
self._parameters_.insert(index, param)
|
||||
|
||||
param.add_observer(self, self._pass_through_notify_observers, -np.inf)
|
||||
|
||||
self.size += param.size
|
||||
|
||||
self._connect_parameters()
|
||||
self._notify_parent_change()
|
||||
self._connect_fixes()
|
||||
else:
|
||||
raise RuntimeError, """Parameter exists already added and no copy made"""
|
||||
|
||||
|
||||
def add_parameters(self, *parameters):
|
||||
"""
|
||||
convenience method for adding several
|
||||
parameters without gradient specification
|
||||
"""
|
||||
[self.add_parameter(p) for p in parameters]
|
||||
|
||||
def remove_parameter(self, param):
|
||||
"""
|
||||
:param param: param object to remove from being a parameter of this parameterized object.
|
||||
"""
|
||||
if not param in self._parameters_:
|
||||
raise RuntimeError, "Parameter {} does not belong to this object, remove parameters directly from their respective parents".format(param._short())
|
||||
|
||||
start = sum([p.size for p in self._parameters_[:param._parent_index_]])
|
||||
self._remove_parameter_name(param)
|
||||
self.size -= param.size
|
||||
del self._parameters_[param._parent_index_]
|
||||
|
||||
param._disconnect_parent()
|
||||
param.remove_observer(self, self._pass_through_notify_observers)
|
||||
self.constraints.shift_left(start, param.size)
|
||||
|
||||
self._connect_fixes()
|
||||
self._connect_parameters()
|
||||
self._notify_parent_change()
|
||||
|
||||
parent = self._direct_parent_
|
||||
while parent is not None:
|
||||
parent._connect_fixes()
|
||||
parent._connect_parameters()
|
||||
parent._notify_parent_change()
|
||||
parent = parent._direct_parent_
|
||||
|
||||
def _connect_parameters(self):
|
||||
# connect parameterlist to this parameterized object
|
||||
# This just sets up the right connection for the params objects
|
||||
# to be used as parameters
|
||||
# it also sets the constraints for each parameter to the constraints
|
||||
# of their respective parents
|
||||
if not hasattr(self, "_parameters_") or len(self._parameters_) < 1:
|
||||
# no parameters for this class
|
||||
return
|
||||
sizes = [0]
|
||||
self._param_slices_ = []
|
||||
for i, p in enumerate(self._parameters_):
|
||||
p._direct_parent_ = self
|
||||
p._parent_index_ = i
|
||||
sizes.append(p.size + sizes[-1])
|
||||
self._param_slices_.append(slice(sizes[-2], sizes[-1]))
|
||||
self._add_parameter_name(p)
|
||||
|
||||
#===========================================================================
|
||||
# notification system
|
||||
#===========================================================================
|
||||
def _parameters_changed_notification(self, which):
|
||||
self.parameters_changed()
|
||||
def _pass_through_notify_observers(self, which):
|
||||
self._notify_observers(which)
|
||||
#===========================================================================
|
||||
# Pickling operations
|
||||
#===========================================================================
|
||||
|
|
@ -212,6 +104,11 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
|||
else:
|
||||
cPickle.dump(self, f, protocol)
|
||||
|
||||
def copy(self):
|
||||
c = super(Parameterized, self).copy()
|
||||
c.add_observer(c, c._parameters_changed_notification, -100)
|
||||
return c
|
||||
|
||||
def __getstate__(self):
|
||||
if self._has_get_set_state():
|
||||
return self._getstate()
|
||||
|
|
@ -332,9 +229,13 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
|||
return ParamConcatenation(paramlist)
|
||||
|
||||
def __setitem__(self, name, value, paramlist=None):
|
||||
try: param = self.__getitem__(name, paramlist)
|
||||
except AttributeError as a: raise a
|
||||
param[:] = value
|
||||
if isinstance(name, slice):
|
||||
self[''][name] = value
|
||||
else:
|
||||
try: param = self.__getitem__(name, paramlist)
|
||||
except AttributeError as a: raise a
|
||||
param[:] = value
|
||||
|
||||
def __setattr__(self, name, val):
|
||||
# override the default behaviour, if setting a param, so broadcasting can by used
|
||||
if hasattr(self, '_parameters_'):
|
||||
|
|
@ -379,7 +280,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
|||
cl = max([len(str(x)) if x else 0 for x in constrs + ["Constraint"]])
|
||||
tl = max([len(str(x)) if x else 0 for x in ts + ["Tied to"]])
|
||||
pl = max([len(str(x)) if x else 0 for x in prirs + ["Prior"]])
|
||||
format_spec = " \033[1m{{name:<{0}s}}\033[0;0m | {{desc:^{1}s}} | {{const:^{2}s}} | {{pri:^{3}s}} | {{t:^{4}s}}".format(nl, sl, cl, pl, tl)
|
||||
format_spec = " \033[1m{{name:<{0}s}}\033[0;0m | {{desc:>{1}s}} | {{const:^{2}s}} | {{pri:^{3}s}} | {{t:^{4}s}}".format(nl, sl, cl, pl, tl)
|
||||
to_print = []
|
||||
for n, d, c, t, p in itertools.izip(names, desc, constrs, ts, prirs):
|
||||
to_print.append(format_spec.format(name=n, desc=d, const=c, t=t, pri=p))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue