mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-27 14:25:16 +02:00
observer pattern has a handle to trigger only > min_priority observers
This commit is contained in:
parent
058ab679e7
commit
2771e3f71f
8 changed files with 181 additions and 90 deletions
|
|
@ -60,20 +60,6 @@ class Model(Parameterized):
|
||||||
self.priors = state.pop()
|
self.priors = state.pop()
|
||||||
Parameterized._setstate(self, state)
|
Parameterized._setstate(self, state)
|
||||||
|
|
||||||
def randomize(self):
|
|
||||||
"""
|
|
||||||
Randomize the model.
|
|
||||||
Make this draw from the prior if one exists, else draw from N(0,1)
|
|
||||||
"""
|
|
||||||
# first take care of all parameters (from N(0,1))
|
|
||||||
# x = self._get_params_transformed()
|
|
||||||
x = np.random.randn(self.size_transformed)
|
|
||||||
x = self._untransform_params(x)
|
|
||||||
# now draw from prior where possible
|
|
||||||
[np.put(x, ind, p.rvs(ind.size)) for p, ind in self.priors.iteritems() if not p is None]
|
|
||||||
self._set_params(x)
|
|
||||||
# self._set_params_transformed(self._get_params_transformed()) # makes sure all of the tied parameters get the same init (since there's only one prior object...)
|
|
||||||
|
|
||||||
def optimize_restarts(self, num_restarts=10, robust=False, verbose=True, parallel=False, num_processes=None, **kwargs):
|
def optimize_restarts(self, num_restarts=10, robust=False, verbose=True, parallel=False, num_processes=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Perform random restarts of the model, and set the model to the best
|
Perform random restarts of the model, and set the model to the best
|
||||||
|
|
|
||||||
|
|
@ -66,6 +66,7 @@ class ObservableArray(np.ndarray, Observable):
|
||||||
|
|
||||||
def __getslice__(self, start, stop):
|
def __getslice__(self, start, stop):
|
||||||
return self.__getitem__(slice(start, stop))
|
return self.__getitem__(slice(start, stop))
|
||||||
|
|
||||||
def __setslice__(self, start, stop, val):
|
def __setslice__(self, start, stop, val):
|
||||||
return self.__setitem__(slice(start, stop), val)
|
return self.__setitem__(slice(start, stop), val)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,8 @@
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import numpy
|
import numpy
|
||||||
from parameter_core import Constrainable, Gradcheckable, Indexable, Parentable, adjust_name_for_printing
|
from parameter_core import OptimizationHandlable, Gradcheckable, adjust_name_for_printing
|
||||||
from array_core import ObservableArray, ParamList
|
from array_core import ObservableArray
|
||||||
|
|
||||||
###### printing
|
###### printing
|
||||||
__constraints_name__ = "Constraint"
|
__constraints_name__ = "Constraint"
|
||||||
|
|
@ -15,7 +15,7 @@ __precision__ = numpy.get_printoptions()['precision'] # numpy printing precision
|
||||||
__print_threshold__ = 5
|
__print_threshold__ = 5
|
||||||
######
|
######
|
||||||
|
|
||||||
class Param(Constrainable, ObservableArray, Gradcheckable):
|
class Param(OptimizationHandlable, ObservableArray, Gradcheckable):
|
||||||
"""
|
"""
|
||||||
Parameter object for GPy models.
|
Parameter object for GPy models.
|
||||||
|
|
||||||
|
|
@ -148,8 +148,11 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# get/set parameters
|
# get/set parameters
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
def _set_params(self, param, update=True):
|
def _set_params(self, param, trigger_parent=True):
|
||||||
self.flat = param
|
self.flat = param
|
||||||
|
if trigger_parent: min_priority = None
|
||||||
|
else: min_priority = -numpy.inf
|
||||||
|
self._notify_observers(None, min_priority)
|
||||||
|
|
||||||
def _get_params(self):
|
def _get_params(self):
|
||||||
return self.flat
|
return self.flat
|
||||||
|
|
@ -175,9 +178,6 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
|
||||||
|
|
||||||
def __setitem__(self, s, val):
|
def __setitem__(self, s, val):
|
||||||
super(Param, self).__setitem__(s, val)
|
super(Param, self).__setitem__(s, val)
|
||||||
if self.has_parent():
|
|
||||||
self._direct_parent_._notify_parameters_changed()
|
|
||||||
#self._notify_observers()
|
|
||||||
|
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# Index Operations:
|
# Index Operations:
|
||||||
|
|
@ -205,6 +205,7 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
|
||||||
ind = self._indices(slice_index)
|
ind = self._indices(slice_index)
|
||||||
if ind.ndim < 2: ind = ind[:, None]
|
if ind.ndim < 2: ind = ind[:, None]
|
||||||
return numpy.asarray(numpy.apply_along_axis(lambda x: numpy.sum(extended_realshape * x), 1, ind), dtype=int)
|
return numpy.asarray(numpy.apply_along_axis(lambda x: numpy.sum(extended_realshape * x), 1, ind), dtype=int)
|
||||||
|
|
||||||
def _expand_index(self, slice_index=None):
|
def _expand_index(self, slice_index=None):
|
||||||
# this calculates the full indexing arrays from the slicing objects given by get_item for _real..._ attributes
|
# this calculates the full indexing arrays from the slicing objects given by get_item for _real..._ attributes
|
||||||
# it basically translates slices to their respective index arrays and turns negative indices around
|
# it basically translates slices to their respective index arrays and turns negative indices around
|
||||||
|
|
@ -346,6 +347,7 @@ class ParamConcatenation(object):
|
||||||
See :py:class:`GPy.core.parameter.Param` for more details on constraining.
|
See :py:class:`GPy.core.parameter.Param` for more details on constraining.
|
||||||
"""
|
"""
|
||||||
# self.params = params
|
# self.params = params
|
||||||
|
from lists_and_dicts import ParamList
|
||||||
self.params = ParamList([])
|
self.params = ParamList([])
|
||||||
for p in params:
|
for p in params:
|
||||||
for p in p.flattened_parameters:
|
for p in p.flattened_parameters:
|
||||||
|
|
|
||||||
|
|
@ -31,10 +31,24 @@ class Observable(object):
|
||||||
for r in to_remove:
|
for r in to_remove:
|
||||||
self._observer_callables_.remove(r)
|
self._observer_callables_.remove(r)
|
||||||
|
|
||||||
def _notify_observers(self, which=None):
|
def _notify_observers(self, which=None, min_priority=None):
|
||||||
|
"""
|
||||||
|
Notifies all observers. Which is the element, which kicked off this
|
||||||
|
notification loop.
|
||||||
|
|
||||||
|
NOTE: notifies only observers with priority p > min_priority!
|
||||||
|
^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
:param which: object, which started this notification loop
|
||||||
|
:param min_priority: only notify observers with priority > min_priority
|
||||||
|
if min_priority is None, notify all observers in order
|
||||||
|
"""
|
||||||
if which is None:
|
if which is None:
|
||||||
which = self
|
which = self
|
||||||
[callble(which) for _, _, callble in heapq.nlargest(len(self._observer_callables_), self._observer_callables_)]
|
if min_priority is None:
|
||||||
|
[callble(which) for _, _, callble in heapq.nlargest(len(self._observer_callables_), self._observer_callables_)]
|
||||||
|
else:
|
||||||
|
[callble(which) for p, _, callble in heapq.nlargest(len(self._observer_callables_), self._observer_callables_) if p > min_priority]
|
||||||
|
|
||||||
class Pickleable(object):
|
class Pickleable(object):
|
||||||
def _getstate(self):
|
def _getstate(self):
|
||||||
|
|
@ -210,9 +224,9 @@ class Constrainable(Nameable, Indexable):
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# Prior Operations
|
# Prior Operations
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
def set_prior(self, prior, warning=True, update=True):
|
def set_prior(self, prior, warning=True):
|
||||||
repriorized = self.unset_priors()
|
repriorized = self.unset_priors()
|
||||||
self._add_to_index_operations(self.priors, repriorized, prior, warning, update)
|
self._add_to_index_operations(self.priors, repriorized, prior, warning)
|
||||||
|
|
||||||
def unset_priors(self, *priors):
|
def unset_priors(self, *priors):
|
||||||
return self._remove_from_index_operations(self.priors, priors)
|
return self._remove_from_index_operations(self.priors, priors)
|
||||||
|
|
@ -238,7 +252,7 @@ class Constrainable(Nameable, Indexable):
|
||||||
# Constrain operations -> done
|
# Constrain operations -> done
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
|
|
||||||
def constrain(self, transform, warning=True, update=True):
|
def constrain(self, transform, warning=True):
|
||||||
"""
|
"""
|
||||||
:param transform: the :py:class:`GPy.core.transformations.Transformation`
|
:param transform: the :py:class:`GPy.core.transformations.Transformation`
|
||||||
to constrain the this parameter to.
|
to constrain the this parameter to.
|
||||||
|
|
@ -248,9 +262,9 @@ class Constrainable(Nameable, Indexable):
|
||||||
:py:class:`GPy.core.transformations.Transformation`.
|
:py:class:`GPy.core.transformations.Transformation`.
|
||||||
"""
|
"""
|
||||||
if isinstance(transform, Transformation):
|
if isinstance(transform, Transformation):
|
||||||
self._set_params(transform.initialize(self._get_params()), update=False)
|
self._set_params(transform.initialize(self._get_params()), trigger_parent=True)
|
||||||
reconstrained = self.unconstrain()
|
reconstrained = self.unconstrain()
|
||||||
self._add_to_index_operations(self.constraints, reconstrained, transform, warning, update)
|
self._add_to_index_operations(self.constraints, reconstrained, transform, warning)
|
||||||
|
|
||||||
def unconstrain(self, *transforms):
|
def unconstrain(self, *transforms):
|
||||||
"""
|
"""
|
||||||
|
|
@ -261,30 +275,30 @@ class Constrainable(Nameable, Indexable):
|
||||||
"""
|
"""
|
||||||
return self._remove_from_index_operations(self.constraints, transforms)
|
return self._remove_from_index_operations(self.constraints, transforms)
|
||||||
|
|
||||||
def constrain_positive(self, warning=True, update=True):
|
def constrain_positive(self, warning=True):
|
||||||
"""
|
"""
|
||||||
:param warning: print a warning if re-constraining parameters.
|
:param warning: print a warning if re-constraining parameters.
|
||||||
|
|
||||||
Constrain this parameter to the default positive constraint.
|
Constrain this parameter to the default positive constraint.
|
||||||
"""
|
"""
|
||||||
self.constrain(Logexp(), warning=warning, update=update)
|
self.constrain(Logexp(), warning=warning)
|
||||||
|
|
||||||
def constrain_negative(self, warning=True, update=True):
|
def constrain_negative(self, warning=True):
|
||||||
"""
|
"""
|
||||||
:param warning: print a warning if re-constraining parameters.
|
:param warning: print a warning if re-constraining parameters.
|
||||||
|
|
||||||
Constrain this parameter to the default negative constraint.
|
Constrain this parameter to the default negative constraint.
|
||||||
"""
|
"""
|
||||||
self.constrain(NegativeLogexp(), warning=warning, update=update)
|
self.constrain(NegativeLogexp(), warning=warning)
|
||||||
|
|
||||||
def constrain_bounded(self, lower, upper, warning=True, update=True):
|
def constrain_bounded(self, lower, upper, warning=True):
|
||||||
"""
|
"""
|
||||||
:param lower, upper: the limits to bound this parameter to
|
:param lower, upper: the limits to bound this parameter to
|
||||||
:param warning: print a warning if re-constraining parameters.
|
:param warning: print a warning if re-constraining parameters.
|
||||||
|
|
||||||
Constrain this parameter to lie within the given range.
|
Constrain this parameter to lie within the given range.
|
||||||
"""
|
"""
|
||||||
self.constrain(Logistic(lower, upper), warning=warning, update=update)
|
self.constrain(Logistic(lower, upper), warning=warning)
|
||||||
|
|
||||||
def unconstrain_positive(self):
|
def unconstrain_positive(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -314,12 +328,10 @@ class Constrainable(Nameable, Indexable):
|
||||||
for p in self._parameters_:
|
for p in self._parameters_:
|
||||||
p._parent_changed(parent)
|
p._parent_changed(parent)
|
||||||
|
|
||||||
def _add_to_index_operations(self, which, reconstrained, transform, warning, update):
|
def _add_to_index_operations(self, which, reconstrained, transform, warning):
|
||||||
if warning and reconstrained.size > 0:
|
if warning and reconstrained.size > 0:
|
||||||
print "WARNING: reconstraining parameters {}".format(self.parameter_names() or self.name)
|
print "WARNING: reconstraining parameters {}".format(self.parameter_names() or self.name)
|
||||||
which.add(transform, self._raveled_index())
|
which.add(transform, self._raveled_index())
|
||||||
if update:
|
|
||||||
self._notify_observers()
|
|
||||||
|
|
||||||
def _remove_from_index_operations(self, which, transforms):
|
def _remove_from_index_operations(self, which, transforms):
|
||||||
if len(transforms) == 0:
|
if len(transforms) == 0:
|
||||||
|
|
@ -334,8 +346,69 @@ class Constrainable(Nameable, Indexable):
|
||||||
|
|
||||||
return removed
|
return removed
|
||||||
|
|
||||||
|
class OptimizationHandlable(Constrainable, Observable):
|
||||||
|
def _get_params_transformed(self):
|
||||||
|
# transformed parameters (apply transformation rules)
|
||||||
|
p = self._get_params()
|
||||||
|
[np.put(p, ind, c.finv(p[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__]
|
||||||
|
if self._has_fixes():
|
||||||
|
return p[self._fixes_]
|
||||||
|
return p
|
||||||
|
|
||||||
|
def _set_params_transformed(self, p):
|
||||||
|
# inverse apply transformations for parameters and set the resulting parameters
|
||||||
|
self._set_params(self._untransform_params(p))
|
||||||
|
|
||||||
|
def _untransform_params(self, p):
|
||||||
|
p = p.copy()
|
||||||
|
if self._has_fixes(): tmp = self._get_params(); tmp[self._fixes_] = p; p = tmp; del tmp
|
||||||
|
[np.put(p, ind, c.f(p[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__]
|
||||||
|
return p
|
||||||
|
|
||||||
|
def _get_params(self):
|
||||||
|
# don't overwrite this anymore!
|
||||||
|
if not self.size:
|
||||||
|
return np.empty(shape=(0,), dtype=np.float64)
|
||||||
|
return np.hstack([x._get_params() for x in self._parameters_ if x.size > 0])
|
||||||
|
|
||||||
class Parameterizable(Constrainable, Observable):
|
def _set_params(self, params, trigger_parent=True):
|
||||||
|
# don't overwrite this anymore!
|
||||||
|
raise NotImplementedError, "This needs to be implemented seperately"
|
||||||
|
|
||||||
|
#===========================================================================
|
||||||
|
# Optimization handles:
|
||||||
|
#===========================================================================
|
||||||
|
def _get_param_names(self):
|
||||||
|
n = np.array([p.hirarchy_name() + '[' + str(i) + ']' for p in self.flattened_parameters for i in p._indices()])
|
||||||
|
return n
|
||||||
|
def _get_param_names_transformed(self):
|
||||||
|
n = self._get_param_names()
|
||||||
|
if self._has_fixes():
|
||||||
|
return n[self._fixes_]
|
||||||
|
return n
|
||||||
|
|
||||||
|
#===========================================================================
|
||||||
|
# Randomizeable
|
||||||
|
#===========================================================================
|
||||||
|
def randomize(self):
|
||||||
|
"""
|
||||||
|
Randomize the model.
|
||||||
|
Make this draw from the prior if one exists, else draw from N(0,1)
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
# first take care of all parameters (from N(0,1))
|
||||||
|
# x = self._get_params_transformed()
|
||||||
|
x = np.random.randn(self.size_transformed)
|
||||||
|
x = self._untransform_params(x)
|
||||||
|
# now draw from prior where possible
|
||||||
|
[np.put(x, ind, p.rvs(ind.size)) for p, ind in self.priors.iteritems() if not p is None]
|
||||||
|
self._set_params(x)
|
||||||
|
# self._set_params_transformed(self._get_params_transformed()) # makes sure all of the tied parameters get the same init (since there's only one prior object...)
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class Parameterizable(OptimizationHandlable):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Parameterizable, self).__init__(*args, **kwargs)
|
super(Parameterizable, self).__init__(*args, **kwargs)
|
||||||
from GPy.core.parameterization.lists_and_dicts import ParamList
|
from GPy.core.parameterization.lists_and_dicts import ParamList
|
||||||
|
|
@ -382,23 +455,21 @@ class Parameterizable(Constrainable, Observable):
|
||||||
import itertools
|
import itertools
|
||||||
[p._collect_gradient(target[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
[p._collect_gradient(target[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
||||||
|
|
||||||
|
def _set_params(self, params, trigger_parent=True):
|
||||||
|
import itertools
|
||||||
|
[p._set_params(params[s], trigger_parent=False) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
||||||
|
if trigger_parent: min_priority = None
|
||||||
|
else: min_priority = -np.inf
|
||||||
|
self._notify_observers(None, min_priority)
|
||||||
|
|
||||||
def _set_gradient(self, g):
|
def _set_gradient(self, g):
|
||||||
import itertools
|
import itertools
|
||||||
[p._set_gradient(g[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
[p._set_gradient(g[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
||||||
|
|
||||||
def _get_params(self):
|
|
||||||
import numpy as np
|
#===========================================================================
|
||||||
# don't overwrite this anymore!
|
# TODO: not working yet
|
||||||
if not self.size:
|
#===========================================================================
|
||||||
return np.empty(shape=(0,), dtype=np.float64)
|
|
||||||
return np.hstack([x._get_params() for x in self._parameters_ if x.size > 0])
|
|
||||||
|
|
||||||
def _set_params(self, params, update=True):
|
|
||||||
# don't overwrite this anymore!
|
|
||||||
import itertools
|
|
||||||
[p._set_params(params[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
|
||||||
self._notify_parameters_changed()
|
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
"""Returns a (deep) copy of the current model"""
|
"""Returns a (deep) copy of the current model"""
|
||||||
import copy
|
import copy
|
||||||
|
|
@ -429,11 +500,6 @@ class Parameterizable(Constrainable, Observable):
|
||||||
s.add_parameter(p)
|
s.add_parameter(p)
|
||||||
|
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def _notify_parameters_changed(self):
|
|
||||||
self.parameters_changed()
|
|
||||||
if self.has_parent():
|
|
||||||
self._direct_parent_._notify_parameters_changed()
|
|
||||||
|
|
||||||
def parameters_changed(self):
|
def parameters_changed(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -58,6 +58,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
||||||
self._in_init_ = True
|
self._in_init_ = True
|
||||||
self._parameters_ = ParamList()
|
self._parameters_ = ParamList()
|
||||||
self.size = sum(p.size for p in self._parameters_)
|
self.size = sum(p.size for p in self._parameters_)
|
||||||
|
self.add_observer(self, self._parameters_changed_notification, -100)
|
||||||
if not self._has_fixes():
|
if not self._has_fixes():
|
||||||
self._fixes_ = None
|
self._fixes_ = None
|
||||||
self._param_slices_ = []
|
self._param_slices_ = []
|
||||||
|
|
@ -65,7 +66,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
||||||
del self._in_init_
|
del self._in_init_
|
||||||
|
|
||||||
def build_pydot(self, G=None):
|
def build_pydot(self, G=None):
|
||||||
import pydot
|
import pydot # @UnresolvedImport
|
||||||
iamroot = False
|
iamroot = False
|
||||||
if G is None:
|
if G is None:
|
||||||
G = pydot.Dot(graph_type='digraph')
|
G = pydot.Dot(graph_type='digraph')
|
||||||
|
|
@ -116,7 +117,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
||||||
self.constraints.update(param.constraints, start)
|
self.constraints.update(param.constraints, start)
|
||||||
self.priors.update(param.priors, start)
|
self.priors.update(param.priors, start)
|
||||||
self._parameters_.insert(index, param)
|
self._parameters_.insert(index, param)
|
||||||
param.add_observer(self, self._pass_through_notify, -1)
|
param.add_observer(self, self._pass_through_notify_observers, -np.inf)
|
||||||
self.size += param.size
|
self.size += param.size
|
||||||
else:
|
else:
|
||||||
raise RuntimeError, """Parameter exists already added and no copy made"""
|
raise RuntimeError, """Parameter exists already added and no copy made"""
|
||||||
|
|
@ -173,9 +174,10 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# notification system
|
# notification system
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
def _pass_through_notify(self, which):
|
def _parameters_changed_notification(self, which):
|
||||||
|
self.parameters_changed()
|
||||||
|
def _pass_through_notify_observers(self, which):
|
||||||
self._notify_observers(which)
|
self._notify_observers(which)
|
||||||
|
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# Pickling operations
|
# Pickling operations
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
|
|
@ -244,32 +246,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
||||||
g[self._offset_for(p) + numpy.array(list(i))] += g[self._raveled_index_for(t)]
|
g[self._offset_for(p) + numpy.array(list(i))] += g[self._raveled_index_for(t)]
|
||||||
if self._has_fixes(): return g[self._fixes_]
|
if self._has_fixes(): return g[self._fixes_]
|
||||||
return g
|
return g
|
||||||
#===========================================================================
|
|
||||||
# Optimization handles:
|
|
||||||
#===========================================================================
|
|
||||||
def _get_param_names(self):
|
|
||||||
n = numpy.array([p.hirarchy_name() + '[' + str(i) + ']' for p in self.flattened_parameters for i in p._indices()])
|
|
||||||
return n
|
|
||||||
def _get_param_names_transformed(self):
|
|
||||||
n = self._get_param_names()
|
|
||||||
if self._has_fixes():
|
|
||||||
return n[self._fixes_]
|
|
||||||
return n
|
|
||||||
def _get_params_transformed(self):
|
|
||||||
# transformed parameters (apply transformation rules)
|
|
||||||
p = self._get_params()
|
|
||||||
[numpy.put(p, ind, c.finv(p[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__]
|
|
||||||
if self._has_fixes():
|
|
||||||
return p[self._fixes_]
|
|
||||||
return p
|
|
||||||
def _set_params_transformed(self, p):
|
|
||||||
# inverse apply transformations for parameters and set the resulting parameters
|
|
||||||
self._set_params(self._untransform_params(p))
|
|
||||||
def _untransform_params(self, p):
|
|
||||||
p = p.copy()
|
|
||||||
if self._has_fixes(): tmp = self._get_params(); tmp[self._fixes_] = p; p = tmp; del tmp
|
|
||||||
[numpy.put(p, ind, c.f(p[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__]
|
|
||||||
return p
|
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# Indexable Handling
|
# Indexable Handling
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
|
|
@ -304,6 +281,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
||||||
this is not in the global view of things!
|
this is not in the global view of things!
|
||||||
"""
|
"""
|
||||||
return numpy.r_[:self.size]
|
return numpy.r_[:self.size]
|
||||||
|
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# Fixing parameters:
|
# Fixing parameters:
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
|
|
@ -311,6 +289,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
||||||
if self._has_fixes():
|
if self._has_fixes():
|
||||||
return self._fixes_[self._raveled_index_for(param)]
|
return self._fixes_[self._raveled_index_for(param)]
|
||||||
return numpy.ones(self.size, dtype=bool)[self._raveled_index_for(param)]
|
return numpy.ones(self.size, dtype=bool)[self._raveled_index_for(param)]
|
||||||
|
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# Convenience for fixed, tied checking of param:
|
# Convenience for fixed, tied checking of param:
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
|
|
|
||||||
|
|
@ -64,6 +64,36 @@ class Gaussian(Prior):
|
||||||
return np.random.randn(n) * self.sigma + self.mu
|
return np.random.randn(n) * self.sigma + self.mu
|
||||||
|
|
||||||
|
|
||||||
|
class Uniform(Prior):
|
||||||
|
domain = _REAL
|
||||||
|
_instances = []
|
||||||
|
def __new__(cls, lower, upper): # Singleton:
|
||||||
|
if cls._instances:
|
||||||
|
cls._instances[:] = [instance for instance in cls._instances if instance()]
|
||||||
|
for instance in cls._instances:
|
||||||
|
if instance().lower == lower and instance().upper == upper:
|
||||||
|
return instance()
|
||||||
|
o = super(Prior, cls).__new__(cls, lower, upper)
|
||||||
|
cls._instances.append(weakref.ref(o))
|
||||||
|
return cls._instances[-1]()
|
||||||
|
|
||||||
|
def __init__(self, lower, upper):
|
||||||
|
self.lower = float(lower)
|
||||||
|
self.upper = float(upper)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "[" + str(np.round(self.lower)) + ', ' + str(np.round(self.upper)) + ']'
|
||||||
|
|
||||||
|
def lnpdf(self, x):
|
||||||
|
region = (x>=self.lower) * (x<=self.upper)
|
||||||
|
return region
|
||||||
|
|
||||||
|
def lnpdf_grad(self, x):
|
||||||
|
return np.zeros(x.shape)
|
||||||
|
|
||||||
|
def rvs(self, n):
|
||||||
|
return np.random.uniform(self.lower, self.upper, size=n)
|
||||||
|
|
||||||
class LogGaussian(Prior):
|
class LogGaussian(Prior):
|
||||||
"""
|
"""
|
||||||
Implementation of the univariate *log*-Gaussian probability function, coupled with random variables.
|
Implementation of the univariate *log*-Gaussian probability function, coupled with random variables.
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,11 @@ import numpy as np
|
||||||
from domains import _POSITIVE,_NEGATIVE, _BOUNDED
|
from domains import _POSITIVE,_NEGATIVE, _BOUNDED
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
|
import sys
|
||||||
|
#_lim_val = -np.log(sys.float_info.epsilon)
|
||||||
|
|
||||||
_exp_lim_val = np.finfo(np.float64).max
|
_exp_lim_val = np.finfo(np.float64).max
|
||||||
_lim_val = np.log(_exp_lim_val)#-np.log(sys.float_info.epsilon)
|
_lim_val = np.log(_exp_lim_val)#
|
||||||
|
|
||||||
#===============================================================================
|
#===============================================================================
|
||||||
# Fixing constants
|
# Fixing constants
|
||||||
|
|
@ -35,7 +38,6 @@ class Transformation(object):
|
||||||
""" produce a sensible initial value for f(x)"""
|
""" produce a sensible initial value for f(x)"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
def plot(self, xlabel=r'transformed $\theta$', ylabel=r'$\theta$', axes=None, *args,**kw):
|
def plot(self, xlabel=r'transformed $\theta$', ylabel=r'$\theta$', axes=None, *args,**kw):
|
||||||
import sys
|
|
||||||
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from ...plotting.matplot_dep import base_plots
|
from ...plotting.matplot_dep import base_plots
|
||||||
|
|
@ -52,7 +54,7 @@ class Transformation(object):
|
||||||
class Logexp(Transformation):
|
class Logexp(Transformation):
|
||||||
domain = _POSITIVE
|
domain = _POSITIVE
|
||||||
def f(self, x):
|
def f(self, x):
|
||||||
return np.where(x>_lim_val, x, np.log(1. + np.exp(np.clip(x, -np.inf, _lim_val))))
|
return np.where(x>_lim_val, x, np.log(1. + np.exp(np.clip(x, -_lim_val, _lim_val))))
|
||||||
#raises overflow warning: return np.where(x>_lim_val, x, np.log(1. + np.exp(x)))
|
#raises overflow warning: return np.where(x>_lim_val, x, np.log(1. + np.exp(x)))
|
||||||
def finv(self, f):
|
def finv(self, f):
|
||||||
return np.where(f>_lim_val, f, np.log(np.exp(f) - 1.))
|
return np.where(f>_lim_val, f, np.log(np.exp(f) - 1.))
|
||||||
|
|
|
||||||
|
|
@ -18,16 +18,26 @@ class ParameterizedTest(Parameterized):
|
||||||
params_changed_count = 0
|
params_changed_count = 0
|
||||||
def parameters_changed(self):
|
def parameters_changed(self):
|
||||||
self.params_changed_count += 1
|
self.params_changed_count += 1
|
||||||
|
def _set_params(self, params, trigger_parent=True):
|
||||||
|
Parameterized._set_params(self, params, trigger_parent=trigger_parent)
|
||||||
|
|
||||||
class Test(unittest.TestCase):
|
class Test(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.parent = ParamTestParent('test parent')
|
self.parent = ParamTestParent('test parent')
|
||||||
self.par = ParameterizedTest('test model')
|
self.par = ParameterizedTest('test model')
|
||||||
|
self.par2 = ParameterizedTest('test model 2')
|
||||||
self.p = Param('test parameter', numpy.random.normal(1,2,(10,3)))
|
self.p = Param('test parameter', numpy.random.normal(1,2,(10,3)))
|
||||||
|
|
||||||
self.par.add_parameter(self.p)
|
self.par.add_parameter(self.p)
|
||||||
|
self.par.add_parameter(Param('test1', numpy.random.normal(0,1,(1,))))
|
||||||
|
self.par.add_parameter(Param('test2', numpy.random.normal(0,1,(1,))))
|
||||||
|
|
||||||
|
self.par2.add_parameter(Param('par2 test1', numpy.random.normal(0,1,(1,))))
|
||||||
|
self.par2.add_parameter(Param('par2 test2', numpy.random.normal(0,1,(1,))))
|
||||||
|
|
||||||
self.parent.add_parameter(self.par)
|
self.parent.add_parameter(self.par)
|
||||||
|
self.parent.add_parameter(self.par2)
|
||||||
|
|
||||||
self._observer_triggered = None
|
self._observer_triggered = None
|
||||||
self._trigger_count = 0
|
self._trigger_count = 0
|
||||||
|
|
@ -84,7 +94,22 @@ class Test(unittest.TestCase):
|
||||||
self.assertEqual(self.par.params_changed_count, 0, 'no params changed yet')
|
self.assertEqual(self.par.params_changed_count, 0, 'no params changed yet')
|
||||||
self.par._set_params(numpy.ones(self.par.size))
|
self.par._set_params(numpy.ones(self.par.size))
|
||||||
self.assertEqual(self.par.params_changed_count, 1, 'now params changed')
|
self.assertEqual(self.par.params_changed_count, 1, 'now params changed')
|
||||||
self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')
|
self.assertEqual(self.parent.parent_changed_count, self.par.params_changed_count)
|
||||||
|
|
||||||
|
self.parent._set_params(numpy.ones(self.parent.size) * 2)
|
||||||
|
self.assertEqual(self.par.params_changed_count, 2, 'now params changed')
|
||||||
|
self.assertEqual(self.parent.parent_changed_count, self.par.params_changed_count)
|
||||||
|
|
||||||
|
|
||||||
|
def test_priority_notify(self):
|
||||||
|
self.assertEqual(self.par.params_changed_count, 0)
|
||||||
|
self.par._notify_observers(0, None)
|
||||||
|
self.assertEqual(self.par.params_changed_count, 1)
|
||||||
|
self.assertEqual(self.parent.parent_changed_count, self.par.params_changed_count)
|
||||||
|
|
||||||
|
self.par._notify_observers(0, -numpy.inf)
|
||||||
|
self.assertEqual(self.par.params_changed_count, 2)
|
||||||
|
self.assertEqual(self.parent.parent_changed_count, 1)
|
||||||
|
|
||||||
def test_priority(self):
|
def test_priority(self):
|
||||||
self.par.add_observer(self, self._trigger, -1)
|
self.par.add_observer(self, self._trigger, -1)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue