Mirror of https://github.com/SheffieldML/GPy.git (synced 2026-05-04 17:22:39 +02:00)
observer pattern has a handle to trigger only observers with priority > min_priority
This commit is contained in:
parent 058ab679e7
commit 2771e3f71f
8 changed files with 181 additions and 90 deletions
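The core change gives `Observable._notify_observers` a `min_priority` cut-off, so a caller can skip low-priority observers (such as the parent pass-through hooks registered at `-np.inf`). Before the diff itself, here is a minimal standalone sketch of that heapq-based, priority-filtered notification pattern; the class and method names are invented for illustration and are not GPy's actual API.

# Illustrative sketch (not part of the commit): a minimal priority-filtered observable.
import heapq

class MiniObservable(object):
    def __init__(self):
        self._observers = []   # entries are (priority, insertion_id, callable)
        self._counter = 0

    def add_observer(self, callble, priority=0):
        # the counter keeps heapq from ever comparing the callables themselves
        heapq.heappush(self._observers, (priority, self._counter, callble))
        self._counter += 1

    def notify(self, which=None, min_priority=None):
        # highest priority first; min_priority=None means "notify everyone"
        for p, _, callble in heapq.nlargest(len(self._observers), self._observers):
            if min_priority is None or p > min_priority:
                callble(which if which is not None else self)

def high_priority(which):
    print("high-priority observer saw %r" % (which,))

def parent_pass_through(which):
    print("parent pass-through saw %r" % (which,))

obs = MiniObservable()
obs.add_observer(high_priority, priority=0)
obs.add_observer(parent_pass_through, priority=-float('inf'))
obs.notify("update")                              # triggers both observers
obs.notify("update", min_priority=-float('inf'))  # strict ">" skips the -inf observer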
@@ -60,20 +60,6 @@ class Model(Parameterized):
        self.priors = state.pop()
        Parameterized._setstate(self, state)

    def randomize(self):
        """
        Randomize the model.
        Make this draw from the prior if one exists, else draw from N(0,1)
        """
        # first take care of all parameters (from N(0,1))
        # x = self._get_params_transformed()
        x = np.random.randn(self.size_transformed)
        x = self._untransform_params(x)
        # now draw from prior where possible
        [np.put(x, ind, p.rvs(ind.size)) for p, ind in self.priors.iteritems() if not p is None]
        self._set_params(x)
        # self._set_params_transformed(self._get_params_transformed()) # makes sure all of the tied parameters get the same init (since there's only one prior object...)

    def optimize_restarts(self, num_restarts=10, robust=False, verbose=True, parallel=False, num_processes=None, **kwargs):
        """
        Perform random restarts of the model, and set the model to the best
@@ -66,6 +66,7 @@ class ObservableArray(np.ndarray, Observable):

    def __getslice__(self, start, stop):
        return self.__getitem__(slice(start, stop))

    def __setslice__(self, start, stop, val):
        return self.__setitem__(slice(start, stop), val)
@@ -3,8 +3,8 @@
import itertools
import numpy
from parameter_core import Constrainable, Gradcheckable, Indexable, Parentable, adjust_name_for_printing
from array_core import ObservableArray, ParamList
from parameter_core import OptimizationHandlable, Gradcheckable, adjust_name_for_printing
from array_core import ObservableArray

###### printing
__constraints_name__ = "Constraint"
@@ -15,7 +15,7 @@ __precision__ = numpy.get_printoptions()['precision'] # numpy printing precision
__print_threshold__ = 5
######

class Param(Constrainable, ObservableArray, Gradcheckable):
class Param(OptimizationHandlable, ObservableArray, Gradcheckable):
    """
    Parameter object for GPy models.
@@ -148,8 +148,11 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
    #===========================================================================
    # get/set parameters
    #===========================================================================
    def _set_params(self, param, update=True):
    def _set_params(self, param, trigger_parent=True):
        self.flat = param
        if trigger_parent: min_priority = None
        else: min_priority = -numpy.inf
        self._notify_observers(None, min_priority)

    def _get_params(self):
        return self.flat
@@ -175,9 +178,6 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
    def __setitem__(self, s, val):
        super(Param, self).__setitem__(s, val)
        if self.has_parent():
            self._direct_parent_._notify_parameters_changed()
        #self._notify_observers()

    #===========================================================================
    # Index Operations:
@@ -205,6 +205,7 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
        ind = self._indices(slice_index)
        if ind.ndim < 2: ind = ind[:, None]
        return numpy.asarray(numpy.apply_along_axis(lambda x: numpy.sum(extended_realshape * x), 1, ind), dtype=int)

    def _expand_index(self, slice_index=None):
        # this calculates the full indexing arrays from the slicing objects given by get_item for _real..._ attributes
        # it basically translates slices to their respective index arrays and turns negative indices around
@@ -346,6 +347,7 @@ class ParamConcatenation(object):
        See :py:class:`GPy.core.parameter.Param` for more details on constraining.
        """
        # self.params = params
        from lists_and_dicts import ParamList
        self.params = ParamList([])
        for p in params:
            for p in p.flattened_parameters:
@@ -31,10 +31,24 @@ class Observable(object):
        for r in to_remove:
            self._observer_callables_.remove(r)

    def _notify_observers(self, which=None):
    def _notify_observers(self, which=None, min_priority=None):
        """
        Notifies all observers. Which is the element, which kicked off this
        notification loop.

        NOTE: notifies only observers with priority p > min_priority!
                                                    ^^^^^^^^^^^^^^^^

        :param which: object, which started this notification loop
        :param min_priority: only notify observers with priority > min_priority
            if min_priority is None, notify all observers in order
        """
        if which is None:
            which = self
        [callble(which) for _, _, callble in heapq.nlargest(len(self._observer_callables_), self._observer_callables_)]
        if min_priority is None:
            [callble(which) for _, _, callble in heapq.nlargest(len(self._observer_callables_), self._observer_callables_)]
        else:
            [callble(which) for p, _, callble in heapq.nlargest(len(self._observer_callables_), self._observer_callables_) if p > min_priority]

class Pickleable(object):
    def _getstate(self):
@@ -210,9 +224,9 @@ class Constrainable(Nameable, Indexable):
    #===========================================================================
    # Prior Operations
    #===========================================================================
    def set_prior(self, prior, warning=True, update=True):
    def set_prior(self, prior, warning=True):
        repriorized = self.unset_priors()
        self._add_to_index_operations(self.priors, repriorized, prior, warning, update)
        self._add_to_index_operations(self.priors, repriorized, prior, warning)

    def unset_priors(self, *priors):
        return self._remove_from_index_operations(self.priors, priors)
@@ -238,7 +252,7 @@ class Constrainable(Nameable, Indexable):
    # Constrain operations -> done
    #===========================================================================

    def constrain(self, transform, warning=True, update=True):
    def constrain(self, transform, warning=True):
        """
        :param transform: the :py:class:`GPy.core.transformations.Transformation`
            to constrain the this parameter to.
@@ -248,9 +262,9 @@ class Constrainable(Nameable, Indexable):
            :py:class:`GPy.core.transformations.Transformation`.
        """
        if isinstance(transform, Transformation):
            self._set_params(transform.initialize(self._get_params()), update=False)
            self._set_params(transform.initialize(self._get_params()), trigger_parent=True)
        reconstrained = self.unconstrain()
        self._add_to_index_operations(self.constraints, reconstrained, transform, warning, update)
        self._add_to_index_operations(self.constraints, reconstrained, transform, warning)

    def unconstrain(self, *transforms):
        """
@@ -261,30 +275,30 @@ class Constrainable(Nameable, Indexable):
        """
        return self._remove_from_index_operations(self.constraints, transforms)

    def constrain_positive(self, warning=True, update=True):
    def constrain_positive(self, warning=True):
        """
        :param warning: print a warning if re-constraining parameters.

        Constrain this parameter to the default positive constraint.
        """
        self.constrain(Logexp(), warning=warning, update=update)
        self.constrain(Logexp(), warning=warning)

    def constrain_negative(self, warning=True, update=True):
    def constrain_negative(self, warning=True):
        """
        :param warning: print a warning if re-constraining parameters.

        Constrain this parameter to the default negative constraint.
        """
        self.constrain(NegativeLogexp(), warning=warning, update=update)
        self.constrain(NegativeLogexp(), warning=warning)

    def constrain_bounded(self, lower, upper, warning=True, update=True):
    def constrain_bounded(self, lower, upper, warning=True):
        """
        :param lower, upper: the limits to bound this parameter to
        :param warning: print a warning if re-constraining parameters.

        Constrain this parameter to lie within the given range.
        """
        self.constrain(Logistic(lower, upper), warning=warning, update=update)
        self.constrain(Logistic(lower, upper), warning=warning)

    def unconstrain_positive(self):
        """
@@ -314,12 +328,10 @@ class Constrainable(Nameable, Indexable):
        for p in self._parameters_:
            p._parent_changed(parent)

    def _add_to_index_operations(self, which, reconstrained, transform, warning, update):
    def _add_to_index_operations(self, which, reconstrained, transform, warning):
        if warning and reconstrained.size > 0:
            print "WARNING: reconstraining parameters {}".format(self.parameter_names() or self.name)
        which.add(transform, self._raveled_index())
        if update:
            self._notify_observers()

    def _remove_from_index_operations(self, which, transforms):
        if len(transforms) == 0:
@@ -334,8 +346,69 @@ class Constrainable(Nameable, Indexable):

        return removed

class OptimizationHandlable(Constrainable, Observable):
    def _get_params_transformed(self):
        # transformed parameters (apply transformation rules)
        p = self._get_params()
        [np.put(p, ind, c.finv(p[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__]
        if self._has_fixes():
            return p[self._fixes_]
        return p

    def _set_params_transformed(self, p):
        # inverse apply transformations for parameters and set the resulting parameters
        self._set_params(self._untransform_params(p))

    def _untransform_params(self, p):
        p = p.copy()
        if self._has_fixes(): tmp = self._get_params(); tmp[self._fixes_] = p; p = tmp; del tmp
        [np.put(p, ind, c.f(p[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__]
        return p

    def _get_params(self):
        # don't overwrite this anymore!
        if not self.size:
            return np.empty(shape=(0,), dtype=np.float64)
        return np.hstack([x._get_params() for x in self._parameters_ if x.size > 0])

class Parameterizable(Constrainable, Observable):
    def _set_params(self, params, trigger_parent=True):
        # don't overwrite this anymore!
        raise NotImplementedError, "This needs to be implemented seperately"

    #===========================================================================
    # Optimization handles:
    #===========================================================================
    def _get_param_names(self):
        n = np.array([p.hirarchy_name() + '[' + str(i) + ']' for p in self.flattened_parameters for i in p._indices()])
        return n
    def _get_param_names_transformed(self):
        n = self._get_param_names()
        if self._has_fixes():
            return n[self._fixes_]
        return n

    #===========================================================================
    # Randomizeable
    #===========================================================================
    def randomize(self):
        """
        Randomize the model.
        Make this draw from the prior if one exists, else draw from N(0,1)
        """
        import numpy as np
        # first take care of all parameters (from N(0,1))
        # x = self._get_params_transformed()
        x = np.random.randn(self.size_transformed)
        x = self._untransform_params(x)
        # now draw from prior where possible
        [np.put(x, ind, p.rvs(ind.size)) for p, ind in self.priors.iteritems() if not p is None]
        self._set_params(x)
        # self._set_params_transformed(self._get_params_transformed()) # makes sure all of the tied parameters get the same init (since there's only one prior object...)


import numpy as np

class Parameterizable(OptimizationHandlable):
    def __init__(self, *args, **kwargs):
        super(Parameterizable, self).__init__(*args, **kwargs)
        from GPy.core.parameterization.lists_and_dicts import ParamList
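The new `OptimizationHandlable` collects the transformed-parameter plumbing in one place: `_get_params_transformed` maps constrained values into unconstrained optimizer space with each constraint's `finv`, and `_untransform_params` maps optimizer values back with `f` before `_set_params_transformed` writes them. A rough standalone illustration of that round trip, using a softplus-style positive constraint with invented helper names (not GPy's API):

# Sketch of the transform round trip (assumed names, not GPy's actual API).
# f maps optimizer space -> constrained space, finv is its inverse.
import numpy as np

def f(x):       # softplus, keeps parameters positive
    return np.log1p(np.exp(x))

def finv(p):    # inverse softplus
    return np.log(np.expm1(p))

params = np.array([0.5, 2.0, 3.0])   # constrained (positive) values
opt_space = finv(params)             # what the optimizer sees and updates
restored = f(opt_space)              # what _untransform_params would hand to _set_params
print(np.allclose(params, restored)) # True, up to floating point error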
@@ -382,23 +455,21 @@ class Parameterizable(Constrainable, Observable):
        import itertools
        [p._collect_gradient(target[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]

    def _set_params(self, params, trigger_parent=True):
        import itertools
        [p._set_params(params[s], trigger_parent=False) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
        if trigger_parent: min_priority = None
        else: min_priority = -np.inf
        self._notify_observers(None, min_priority)

    def _set_gradient(self, g):
        import itertools
        [p._set_gradient(g[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]

    def _get_params(self):
        import numpy as np
        # don't overwrite this anymore!
        if not self.size:
            return np.empty(shape=(0,), dtype=np.float64)
        return np.hstack([x._get_params() for x in self._parameters_ if x.size > 0])

    def _set_params(self, params, update=True):
        # don't overwrite this anymore!
        import itertools
        [p._set_params(params[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
        self._notify_parameters_changed()


    #===========================================================================
    # TODO: not working yet
    #===========================================================================
    def copy(self):
        """Returns a (deep) copy of the current model"""
        import copy
@@ -429,11 +500,6 @@ class Parameterizable(Constrainable, Observable):
            s.add_parameter(p)

        return s

    def _notify_parameters_changed(self):
        self.parameters_changed()
        if self.has_parent():
            self._direct_parent_._notify_parameters_changed()

    def parameters_changed(self):
        """
@@ -58,6 +58,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
        self._in_init_ = True
        self._parameters_ = ParamList()
        self.size = sum(p.size for p in self._parameters_)
        self.add_observer(self, self._parameters_changed_notification, -100)
        if not self._has_fixes():
            self._fixes_ = None
        self._param_slices_ = []
@@ -65,7 +66,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
        del self._in_init_

    def build_pydot(self, G=None):
        import pydot
        import pydot # @UnresolvedImport
        iamroot = False
        if G is None:
            G = pydot.Dot(graph_type='digraph')
@@ -116,7 +117,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
            self.constraints.update(param.constraints, start)
            self.priors.update(param.priors, start)
            self._parameters_.insert(index, param)
            param.add_observer(self, self._pass_through_notify, -1)
            param.add_observer(self, self._pass_through_notify_observers, -np.inf)
            self.size += param.size
        else:
            raise RuntimeError, """Parameter exists already added and no copy made"""
@@ -173,9 +174,10 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
    #===========================================================================
    # notification system
    #===========================================================================
    def _pass_through_notify(self, which):
    def _parameters_changed_notification(self, which):
        self.parameters_changed()
    def _pass_through_notify_observers(self, which):
        self._notify_observers(which)

    #===========================================================================
    # Pickling operations
    #===========================================================================
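Taken together with the hunks above, the priorities encode who gets skipped: `__init__` registers `_parameters_changed_notification` at priority `-100`, `add_parameter` registers the parent's pass-through observer at `-np.inf`, and a call to `_set_params(..., trigger_parent=False)` notifies with `min_priority=-inf`, so the strict `p > min_priority` test drops only the pass-through to the parent. A condensed, self-contained sketch of that wiring, with invented names rather than the real classes:

# Sketch of the trigger_parent wiring (assumed names, not the real GPy classes).
observers = []   # (priority, callback) pairs held by a child object

child_hits, parent_hits = [], []
observers.append((-100, lambda which: child_hits.append(which)))            # like the -100 parameters_changed hook
observers.append((-float('inf'), lambda which: parent_hits.append(which)))  # like the -np.inf parent pass-through

def notify(which, min_priority=None):
    # highest priority first, strict ">" filter as in the diff
    for priority, callback in sorted(observers, key=lambda t: t[0], reverse=True):
        if min_priority is None or priority > min_priority:
            callback(which)

notify("child", min_priority=None)            # trigger_parent=True: parent is informed too
notify("child", min_priority=-float('inf'))   # trigger_parent=False: only the child's own hook runs
print((len(child_hits), len(parent_hits)))    # (2, 1)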
@@ -244,32 +246,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
            g[self._offset_for(p) + numpy.array(list(i))] += g[self._raveled_index_for(t)]
        if self._has_fixes(): return g[self._fixes_]
        return g
    #===========================================================================
    # Optimization handles:
    #===========================================================================
    def _get_param_names(self):
        n = numpy.array([p.hirarchy_name() + '[' + str(i) + ']' for p in self.flattened_parameters for i in p._indices()])
        return n
    def _get_param_names_transformed(self):
        n = self._get_param_names()
        if self._has_fixes():
            return n[self._fixes_]
        return n
    def _get_params_transformed(self):
        # transformed parameters (apply transformation rules)
        p = self._get_params()
        [numpy.put(p, ind, c.finv(p[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__]
        if self._has_fixes():
            return p[self._fixes_]
        return p
    def _set_params_transformed(self, p):
        # inverse apply transformations for parameters and set the resulting parameters
        self._set_params(self._untransform_params(p))
    def _untransform_params(self, p):
        p = p.copy()
        if self._has_fixes(): tmp = self._get_params(); tmp[self._fixes_] = p; p = tmp; del tmp
        [numpy.put(p, ind, c.f(p[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__]
        return p

    #===========================================================================
    # Indexable Handling
    #===========================================================================
@@ -304,6 +281,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
        this is not in the global view of things!
        """
        return numpy.r_[:self.size]

    #===========================================================================
    # Fixing parameters:
    #===========================================================================
@@ -311,6 +289,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
        if self._has_fixes():
            return self._fixes_[self._raveled_index_for(param)]
        return numpy.ones(self.size, dtype=bool)[self._raveled_index_for(param)]

    #===========================================================================
    # Convenience for fixed, tied checking of param:
    #===========================================================================
@@ -64,6 +64,36 @@ class Gaussian(Prior):
        return np.random.randn(n) * self.sigma + self.mu


class Uniform(Prior):
    domain = _REAL
    _instances = []
    def __new__(cls, lower, upper): # Singleton:
        if cls._instances:
            cls._instances[:] = [instance for instance in cls._instances if instance()]
            for instance in cls._instances:
                if instance().lower == lower and instance().upper == upper:
                    return instance()
        o = super(Prior, cls).__new__(cls, lower, upper)
        cls._instances.append(weakref.ref(o))
        return cls._instances[-1]()

    def __init__(self, lower, upper):
        self.lower = float(lower)
        self.upper = float(upper)

    def __str__(self):
        return "[" + str(np.round(self.lower)) + ', ' + str(np.round(self.upper)) + ']'

    def lnpdf(self, x):
        region = (x>=self.lower) * (x<=self.upper)
        return region

    def lnpdf_grad(self, x):
        return np.zeros(x.shape)

    def rvs(self, n):
        return np.random.uniform(self.lower, self.upper, size=n)

class LogGaussian(Prior):
    """
    Implementation of the univariate *log*-Gaussian probability function, coupled with random variables.
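The new `Uniform` prior follows the same convention as the other priors in this file: `__new__` keeps weak references to existing instances and hands back an already-built prior when one with the same `lower`/`upper` exists, so identical priors are shared rather than duplicated. A simplified standalone sketch of that instance-reuse idea; the class below is hypothetical and is not the GPy implementation:

# Simplified sketch of per-parameter instance reuse via weak references
# (hypothetical UniformSketch class, not GPy's Uniform prior).
import weakref
import numpy as np

class UniformSketch(object):
    _instances = []   # weak references to live instances

    def __new__(cls, lower, upper):
        # drop dead references, then reuse a live instance with the same bounds
        cls._instances[:] = [r for r in cls._instances if r() is not None]
        for r in cls._instances:
            if r().lower == lower and r().upper == upper:
                return r()
        o = super(UniformSketch, cls).__new__(cls)
        cls._instances.append(weakref.ref(o))
        return o

    def __init__(self, lower, upper):
        self.lower, self.upper = float(lower), float(upper)

    def rvs(self, n):
        return np.random.uniform(self.lower, self.upper, size=n)

a = UniformSketch(0., 1.)
b = UniformSketch(0., 1.)
print(a is b)           # True: same bounds, same shared object
print(a.rvs(3).shape)   # (3,)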
@@ -6,8 +6,11 @@ import numpy as np
from domains import _POSITIVE,_NEGATIVE, _BOUNDED
import weakref

import sys
#_lim_val = -np.log(sys.float_info.epsilon)

_exp_lim_val = np.finfo(np.float64).max
_lim_val = np.log(_exp_lim_val)#-np.log(sys.float_info.epsilon)
_lim_val = np.log(_exp_lim_val)#

#===============================================================================
# Fixing constants
@@ -35,7 +38,6 @@ class Transformation(object):
        """ produce a sensible initial value for f(x)"""
        raise NotImplementedError
    def plot(self, xlabel=r'transformed $\theta$', ylabel=r'$\theta$', axes=None, *args,**kw):
        import sys
        assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
        import matplotlib.pyplot as plt
        from ...plotting.matplot_dep import base_plots
@@ -52,7 +54,7 @@
class Logexp(Transformation):
    domain = _POSITIVE
    def f(self, x):
        return np.where(x>_lim_val, x, np.log(1. + np.exp(np.clip(x, -np.inf, _lim_val))))
        return np.where(x>_lim_val, x, np.log(1. + np.exp(np.clip(x, -_lim_val, _lim_val))))
        #raises overflow warning: return np.where(x>_lim_val, x, np.log(1. + np.exp(x)))
    def finv(self, f):
        return np.where(f>_lim_val, f, np.log(np.exp(f) - 1.))
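The `Logexp` change clips the argument of `exp` to `[-_lim_val, _lim_val]` instead of `[-inf, _lim_val]`, so the softplus `f(x) = log(1 + e^x)` never evaluates `exp` on an input that underflows or overflows float64. A small standalone numerical check of the same idea, with assumed helper names rather than the GPy class:

# Standalone check of the clipped softplus idea (assumed helper names).
import numpy as np

_exp_lim_val = np.finfo(np.float64).max
_lim_val = np.log(_exp_lim_val)   # ~709.78 for float64

def logexp(x):
    # identity for very large x, clipped softplus otherwise (as in the diff)
    return np.where(x > _lim_val, x, np.log(1. + np.exp(np.clip(x, -_lim_val, _lim_val))))

def logexp_inv(f):
    return np.where(f > _lim_val, f, np.log(np.exp(f) - 1.))

x = np.array([-1000., -1., 0., 1., 1000.])
fx = logexp(x)                                    # stays finite across the whole range
print(np.all(np.isfinite(fx)))                    # True
print(np.allclose(logexp_inv(fx[1:4]), x[1:4]))   # round trip holds away from the extremes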
@@ -18,16 +18,26 @@ class ParameterizedTest(Parameterized):
    params_changed_count = 0
    def parameters_changed(self):
        self.params_changed_count += 1
    def _set_params(self, params, trigger_parent=True):
        Parameterized._set_params(self, params, trigger_parent=trigger_parent)

class Test(unittest.TestCase):

    def setUp(self):
        self.parent = ParamTestParent('test parent')
        self.par = ParameterizedTest('test model')
        self.par2 = ParameterizedTest('test model 2')
        self.p = Param('test parameter', numpy.random.normal(1,2,(10,3)))

        self.par.add_parameter(self.p)
        self.par.add_parameter(Param('test1', numpy.random.normal(0,1,(1,))))
        self.par.add_parameter(Param('test2', numpy.random.normal(0,1,(1,))))

        self.par2.add_parameter(Param('par2 test1', numpy.random.normal(0,1,(1,))))
        self.par2.add_parameter(Param('par2 test2', numpy.random.normal(0,1,(1,))))

        self.parent.add_parameter(self.par)
        self.parent.add_parameter(self.par2)

        self._observer_triggered = None
        self._trigger_count = 0
@@ -84,7 +94,22 @@ class Test(unittest.TestCase):
        self.assertEqual(self.par.params_changed_count, 0, 'no params changed yet')
        self.par._set_params(numpy.ones(self.par.size))
        self.assertEqual(self.par.params_changed_count, 1, 'now params changed')
        self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')
        self.assertEqual(self.parent.parent_changed_count, self.par.params_changed_count)

        self.parent._set_params(numpy.ones(self.parent.size) * 2)
        self.assertEqual(self.par.params_changed_count, 2, 'now params changed')
        self.assertEqual(self.parent.parent_changed_count, self.par.params_changed_count)


    def test_priority_notify(self):
        self.assertEqual(self.par.params_changed_count, 0)
        self.par._notify_observers(0, None)
        self.assertEqual(self.par.params_changed_count, 1)
        self.assertEqual(self.parent.parent_changed_count, self.par.params_changed_count)

        self.par._notify_observers(0, -numpy.inf)
        self.assertEqual(self.par.params_changed_count, 2)
        self.assertEqual(self.parent.parent_changed_count, 1)

    def test_priority(self):
        self.par.add_observer(self, self._trigger, -1)