mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-12 21:42:39 +02:00
logic edits for copy
This commit is contained in:
parent
441a9f524d
commit
a98334e009
5 changed files with 144 additions and 138 deletions
|
|
@ -247,7 +247,8 @@ class Param(OptimizationHandlable, ObservableArray, Gradcheckable):
|
|||
#===========================================================================
|
||||
@property
|
||||
def _description_str(self):
|
||||
if self.size <= 1: return ["%f" % self]
|
||||
if self.size <= 1:
|
||||
return [str(numpy.take(self, 0))]
|
||||
else: return [str(self.shape)]
|
||||
def parameter_names(self, add_self=False, adjust_for_printing=False):
|
||||
if adjust_for_printing:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||
|
||||
from transformations import Transformation, Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED
|
||||
import heapq
|
||||
import numpy as np
|
||||
|
||||
__updated__ = '2013-12-16'
|
||||
|
||||
|
|
@ -22,7 +22,7 @@ class Observable(object):
|
|||
self._observer_callables_ = []
|
||||
|
||||
def add_observer(self, observer, callble, priority=0):
|
||||
heapq.heappush(self._observer_callables_, (priority, observer, callble))
|
||||
self._insert_sorted(priority, observer, callble)
|
||||
|
||||
def remove_observer(self, observer, callble=None):
|
||||
to_remove = []
|
||||
|
|
@ -51,10 +51,21 @@ class Observable(object):
|
|||
if which is None:
|
||||
which = self
|
||||
if min_priority is None:
|
||||
[callble(which) for _, _, callble in heapq.nlargest(len(self._observer_callables_), self._observer_callables_)]
|
||||
[callble(which) for _, _, callble in self._observer_callables_]
|
||||
else:
|
||||
[callble(which) for p, _, callble in heapq.nlargest(len(self._observer_callables_), self._observer_callables_) if p > min_priority]
|
||||
for p, _, callble in self._observer_callables_:
|
||||
if p <= min_priority:
|
||||
break
|
||||
callble(which)
|
||||
|
||||
def _insert_sorted(self, p, o, c):
|
||||
ins = 0
|
||||
for pr, _, _ in self._observer_callables_:
|
||||
if p > pr:
|
||||
break
|
||||
ins += 1
|
||||
self._observer_callables_.insert(ins, (p, o, c))
|
||||
|
||||
class Pickleable(object):
|
||||
def _getstate(self):
|
||||
"""
|
||||
|
|
@ -202,20 +213,17 @@ class Constrainable(Nameable, Indexable):
|
|||
unfix = unconstrain_fixed
|
||||
|
||||
def _set_fixed(self, index):
|
||||
import numpy as np
|
||||
if not self._has_fixes(): self._fixes_ = np.ones(self.size, dtype=bool)
|
||||
self._fixes_[index] = FIXED
|
||||
if np.all(self._fixes_): self._fixes_ = None # ==UNFIXED
|
||||
|
||||
def _set_unfixed(self, index):
|
||||
import numpy as np
|
||||
if not self._has_fixes(): self._fixes_ = np.ones(self.size, dtype=bool)
|
||||
# rav_i = self._raveled_index_for(param)[index]
|
||||
self._fixes_[index] = UNFIXED
|
||||
if np.all(self._fixes_): self._fixes_ = None # ==UNFIXED
|
||||
|
||||
def _connect_fixes(self):
|
||||
import numpy as np
|
||||
fixed_indices = self.constraints[__fixed__]
|
||||
if fixed_indices.size > 0:
|
||||
self._fixes_ = np.ones(self.size, dtype=bool) * UNFIXED
|
||||
|
|
@ -245,7 +253,6 @@ class Constrainable(Nameable, Indexable):
|
|||
|
||||
def _log_prior_gradients(self):
|
||||
"""evaluate the gradients of the priors"""
|
||||
import numpy as np
|
||||
if self.priors.size > 0:
|
||||
x = self._get_params()
|
||||
ret = np.zeros(x.size)
|
||||
|
|
@ -342,7 +349,6 @@ class Constrainable(Nameable, Indexable):
|
|||
def _remove_from_index_operations(self, which, transforms):
|
||||
if len(transforms) == 0:
|
||||
transforms = which.properties()
|
||||
import numpy as np
|
||||
removed = np.empty((0,), dtype=int)
|
||||
for t in transforms:
|
||||
unconstrained = which.remove(t, self._raveled_index())
|
||||
|
|
@ -404,7 +410,6 @@ class OptimizationHandlable(Constrainable, Observable):
|
|||
Randomize the model.
|
||||
Make this draw from the prior if one exists, else draw from N(0,1)
|
||||
"""
|
||||
import numpy as np
|
||||
# first take care of all parameters (from N(0,1))
|
||||
# x = self._get_params_transformed()
|
||||
x = np.random.randn(self._size_transformed())
|
||||
|
|
@ -413,9 +418,6 @@ class OptimizationHandlable(Constrainable, Observable):
|
|||
[np.put(x, ind, p.rvs(ind.size)) for p, ind in self.priors.iteritems() if not p is None]
|
||||
self._set_params(x)
|
||||
# self._set_params_transformed(self._get_params_transformed()) # makes sure all of the tied parameters get the same init (since there's only one prior object...)
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
class Parameterizable(OptimizationHandlable):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
@ -474,7 +476,113 @@ class Parameterizable(OptimizationHandlable):
|
|||
def _set_gradient(self, g):
|
||||
import itertools
|
||||
[p._set_gradient(g[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
||||
|
||||
def add_parameter(self, param, index=None):
|
||||
"""
|
||||
:param parameters: the parameters to add
|
||||
:type parameters: list of or one :py:class:`GPy.core.param.Param`
|
||||
:param [index]: index of where to put parameters
|
||||
|
||||
|
||||
Add all parameters to this param class, you can insert parameters
|
||||
at any given index using the :func:`list.insert` syntax
|
||||
"""
|
||||
# if param.has_parent():
|
||||
# raise AttributeError, "parameter {} already in another model, create new object (or copy) for adding".format(param._short())
|
||||
if param in self._parameters_ and index is not None:
|
||||
self.remove_parameter(param)
|
||||
self.add_parameter(param, index)
|
||||
elif param not in self._parameters_:
|
||||
if param.has_parent():
|
||||
parent = param._direct_parent_
|
||||
while parent is not None:
|
||||
if parent is self:
|
||||
raise HierarchyError, "You cannot add a parameter twice into the hirarchy"
|
||||
parent = parent._direct_parent_
|
||||
param._direct_parent_.remove_parameter(param)
|
||||
# make sure the size is set
|
||||
if index is None:
|
||||
self.constraints.update(param.constraints, self.size)
|
||||
self.priors.update(param.priors, self.size)
|
||||
self._parameters_.append(param)
|
||||
else:
|
||||
start = sum(p.size for p in self._parameters_[:index])
|
||||
self.constraints.shift_right(start, param.size)
|
||||
self.priors.shift_right(start, param.size)
|
||||
self.constraints.update(param.constraints, start)
|
||||
self.priors.update(param.priors, start)
|
||||
self._parameters_.insert(index, param)
|
||||
|
||||
param.add_observer(self, self._pass_through_notify_observers, -np.inf)
|
||||
|
||||
self.size += param.size
|
||||
|
||||
self._connect_parameters()
|
||||
self._notify_parent_change()
|
||||
self._connect_fixes()
|
||||
else:
|
||||
raise RuntimeError, """Parameter exists already added and no copy made"""
|
||||
|
||||
|
||||
def add_parameters(self, *parameters):
|
||||
"""
|
||||
convenience method for adding several
|
||||
parameters without gradient specification
|
||||
"""
|
||||
[self.add_parameter(p) for p in parameters]
|
||||
|
||||
def remove_parameter(self, param):
|
||||
"""
|
||||
:param param: param object to remove from being a parameter of this parameterized object.
|
||||
"""
|
||||
if not param in self._parameters_:
|
||||
raise RuntimeError, "Parameter {} does not belong to this object, remove parameters directly from their respective parents".format(param._short())
|
||||
|
||||
start = sum([p.size for p in self._parameters_[:param._parent_index_]])
|
||||
self._remove_parameter_name(param)
|
||||
self.size -= param.size
|
||||
del self._parameters_[param._parent_index_]
|
||||
|
||||
param._disconnect_parent()
|
||||
param.remove_observer(self, self._pass_through_notify_observers)
|
||||
self.constraints.shift_left(start, param.size)
|
||||
|
||||
self._connect_fixes()
|
||||
self._connect_parameters()
|
||||
self._notify_parent_change()
|
||||
|
||||
parent = self._direct_parent_
|
||||
while parent is not None:
|
||||
parent._connect_fixes()
|
||||
parent._connect_parameters()
|
||||
parent._notify_parent_change()
|
||||
parent = parent._direct_parent_
|
||||
|
||||
def _connect_parameters(self):
|
||||
# connect parameterlist to this parameterized object
|
||||
# This just sets up the right connection for the params objects
|
||||
# to be used as parameters
|
||||
# it also sets the constraints for each parameter to the constraints
|
||||
# of their respective parents
|
||||
if not hasattr(self, "_parameters_") or len(self._parameters_) < 1:
|
||||
# no parameters for this class
|
||||
return
|
||||
sizes = [0]
|
||||
self._param_slices_ = []
|
||||
for i, p in enumerate(self._parameters_):
|
||||
p._direct_parent_ = self
|
||||
p._parent_index_ = i
|
||||
sizes.append(p.size + sizes[-1])
|
||||
self._param_slices_.append(slice(sizes[-2], sizes[-1]))
|
||||
self._add_parameter_name(p)
|
||||
|
||||
#===========================================================================
|
||||
# notification system
|
||||
#===========================================================================
|
||||
def _parameters_changed_notification(self, which):
|
||||
self.parameters_changed()
|
||||
def _pass_through_notify_observers(self, which):
|
||||
self._notify_observers(which)
|
||||
|
||||
#===========================================================================
|
||||
# TODO: not working yet
|
||||
|
|
@ -487,16 +595,17 @@ class Parameterizable(OptimizationHandlable):
|
|||
|
||||
dc = dict()
|
||||
for k, v in self.__dict__.iteritems():
|
||||
if k not in ['_direct_parent_', '_parameters_', '_parent_index_'] + self.parameter_names():
|
||||
if k not in ['_direct_parent_', '_parameters_', '_parent_index_', '_observer_callables_'] + self.parameter_names():
|
||||
if isinstance(v, (Constrainable, ParameterIndexOperations, ParameterIndexOperationsView)):
|
||||
dc[k] = v.copy()
|
||||
else:
|
||||
dc[k] = copy.deepcopy(v)
|
||||
if k == '_parameters_':
|
||||
params = [p.copy() for p in v]
|
||||
|
||||
|
||||
dc['_direct_parent_'] = None
|
||||
dc['_parent_index_'] = None
|
||||
dc['_observer_callables_'] = []
|
||||
dc['_parameters_'] = ArrayList()
|
||||
dc['constraints'].clear()
|
||||
dc['priors'].clear()
|
||||
|
|
@ -506,6 +615,7 @@ class Parameterizable(OptimizationHandlable):
|
|||
s.__dict__ = dc
|
||||
|
||||
for p in params:
|
||||
import ipdb;ipdb.set_trace()
|
||||
s.add_parameter(p)
|
||||
|
||||
return s
|
||||
|
|
|
|||
|
|
@ -88,114 +88,6 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
|||
return G
|
||||
return node
|
||||
|
||||
|
||||
def add_parameter(self, param, index=None):
|
||||
"""
|
||||
:param parameters: the parameters to add
|
||||
:type parameters: list of or one :py:class:`GPy.core.param.Param`
|
||||
:param [index]: index of where to put parameters
|
||||
|
||||
|
||||
Add all parameters to this param class, you can insert parameters
|
||||
at any given index using the :func:`list.insert` syntax
|
||||
"""
|
||||
# if param.has_parent():
|
||||
# raise AttributeError, "parameter {} already in another model, create new object (or copy) for adding".format(param._short())
|
||||
if param in self._parameters_ and index is not None:
|
||||
self.remove_parameter(param)
|
||||
self.add_parameter(param, index)
|
||||
elif param not in self._parameters_:
|
||||
if param.has_parent():
|
||||
parent = param._direct_parent_
|
||||
while parent is not None:
|
||||
if parent is self:
|
||||
from GPy.core.parameterization.parameter_core import HierarchyError
|
||||
raise HierarchyError, "You cannot add a parameter twice into the hirarchy"
|
||||
parent = parent._direct_parent_
|
||||
param._direct_parent_.remove_parameter(param)
|
||||
# make sure the size is set
|
||||
if index is None:
|
||||
self.constraints.update(param.constraints, self.size)
|
||||
self.priors.update(param.priors, self.size)
|
||||
self._parameters_.append(param)
|
||||
else:
|
||||
start = sum(p.size for p in self._parameters_[:index])
|
||||
self.constraints.shift_right(start, param.size)
|
||||
self.priors.shift_right(start, param.size)
|
||||
self.constraints.update(param.constraints, start)
|
||||
self.priors.update(param.priors, start)
|
||||
self._parameters_.insert(index, param)
|
||||
|
||||
param.add_observer(self, self._pass_through_notify_observers, -np.inf)
|
||||
|
||||
self.size += param.size
|
||||
|
||||
self._connect_parameters()
|
||||
self._notify_parent_change()
|
||||
self._connect_fixes()
|
||||
else:
|
||||
raise RuntimeError, """Parameter exists already added and no copy made"""
|
||||
|
||||
|
||||
def add_parameters(self, *parameters):
|
||||
"""
|
||||
convenience method for adding several
|
||||
parameters without gradient specification
|
||||
"""
|
||||
[self.add_parameter(p) for p in parameters]
|
||||
|
||||
def remove_parameter(self, param):
|
||||
"""
|
||||
:param param: param object to remove from being a parameter of this parameterized object.
|
||||
"""
|
||||
if not param in self._parameters_:
|
||||
raise RuntimeError, "Parameter {} does not belong to this object, remove parameters directly from their respective parents".format(param._short())
|
||||
|
||||
start = sum([p.size for p in self._parameters_[:param._parent_index_]])
|
||||
self._remove_parameter_name(param)
|
||||
self.size -= param.size
|
||||
del self._parameters_[param._parent_index_]
|
||||
|
||||
param._disconnect_parent()
|
||||
param.remove_observer(self, self._pass_through_notify_observers)
|
||||
self.constraints.shift_left(start, param.size)
|
||||
|
||||
self._connect_fixes()
|
||||
self._connect_parameters()
|
||||
self._notify_parent_change()
|
||||
|
||||
parent = self._direct_parent_
|
||||
while parent is not None:
|
||||
parent._connect_fixes()
|
||||
parent._connect_parameters()
|
||||
parent._notify_parent_change()
|
||||
parent = parent._direct_parent_
|
||||
|
||||
def _connect_parameters(self):
|
||||
# connect parameterlist to this parameterized object
|
||||
# This just sets up the right connection for the params objects
|
||||
# to be used as parameters
|
||||
# it also sets the constraints for each parameter to the constraints
|
||||
# of their respective parents
|
||||
if not hasattr(self, "_parameters_") or len(self._parameters_) < 1:
|
||||
# no parameters for this class
|
||||
return
|
||||
sizes = [0]
|
||||
self._param_slices_ = []
|
||||
for i, p in enumerate(self._parameters_):
|
||||
p._direct_parent_ = self
|
||||
p._parent_index_ = i
|
||||
sizes.append(p.size + sizes[-1])
|
||||
self._param_slices_.append(slice(sizes[-2], sizes[-1]))
|
||||
self._add_parameter_name(p)
|
||||
|
||||
#===========================================================================
|
||||
# notification system
|
||||
#===========================================================================
|
||||
def _parameters_changed_notification(self, which):
|
||||
self.parameters_changed()
|
||||
def _pass_through_notify_observers(self, which):
|
||||
self._notify_observers(which)
|
||||
#===========================================================================
|
||||
# Pickling operations
|
||||
#===========================================================================
|
||||
|
|
@ -212,6 +104,11 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
|||
else:
|
||||
cPickle.dump(self, f, protocol)
|
||||
|
||||
def copy(self):
|
||||
c = super(Parameterized, self).copy()
|
||||
c.add_observer(c, c._parameters_changed_notification, -100)
|
||||
return c
|
||||
|
||||
def __getstate__(self):
|
||||
if self._has_get_set_state():
|
||||
return self._getstate()
|
||||
|
|
@ -332,9 +229,13 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
|||
return ParamConcatenation(paramlist)
|
||||
|
||||
def __setitem__(self, name, value, paramlist=None):
|
||||
try: param = self.__getitem__(name, paramlist)
|
||||
except AttributeError as a: raise a
|
||||
param[:] = value
|
||||
if isinstance(name, slice):
|
||||
self[''][name] = value
|
||||
else:
|
||||
try: param = self.__getitem__(name, paramlist)
|
||||
except AttributeError as a: raise a
|
||||
param[:] = value
|
||||
|
||||
def __setattr__(self, name, val):
|
||||
# override the default behaviour, if setting a param, so broadcasting can by used
|
||||
if hasattr(self, '_parameters_'):
|
||||
|
|
@ -379,7 +280,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
|||
cl = max([len(str(x)) if x else 0 for x in constrs + ["Constraint"]])
|
||||
tl = max([len(str(x)) if x else 0 for x in ts + ["Tied to"]])
|
||||
pl = max([len(str(x)) if x else 0 for x in prirs + ["Prior"]])
|
||||
format_spec = " \033[1m{{name:<{0}s}}\033[0;0m | {{desc:^{1}s}} | {{const:^{2}s}} | {{pri:^{3}s}} | {{t:^{4}s}}".format(nl, sl, cl, pl, tl)
|
||||
format_spec = " \033[1m{{name:<{0}s}}\033[0;0m | {{desc:>{1}s}} | {{const:^{2}s}} | {{pri:^{3}s}} | {{t:^{4}s}}".format(nl, sl, cl, pl, tl)
|
||||
to_print = []
|
||||
for n, d, c, t, p in itertools.izip(names, desc, constrs, ts, prirs):
|
||||
to_print.append(format_spec.format(name=n, desc=d, const=c, t=t, pri=p))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue