diff --git a/GPy/core/parameterization/array_core.py b/GPy/core/parameterization/array_core.py index a120f004..780367c8 100644 --- a/GPy/core/parameterization/array_core.py +++ b/GPy/core/parameterization/array_core.py @@ -1,7 +1,7 @@ # Copyright (c) 2012, GPy authors (see AUTHORS.txt). # Licensed under the BSD 3-clause license (see LICENSE.txt) -__updated__ = '2014-03-17' +__updated__ = '2014-03-21' import numpy as np from parameter_core import Observable @@ -55,7 +55,7 @@ class ObsAr(np.ndarray, Observable): def __setitem__(self, s, val): if self._s_not_empty(s): super(ObsAr, self).__setitem__(s, val) - self.notify_observers(self[s]) + self.notify_observers() def __getslice__(self, start, stop): return self.__getitem__(slice(start, stop)) diff --git a/GPy/kern/_src/kernel_slice_operations.py b/GPy/kern/_src/kernel_slice_operations.py index c355ccad..9beb40ab 100644 --- a/GPy/kern/_src/kernel_slice_operations.py +++ b/GPy/kern/_src/kernel_slice_operations.py @@ -126,7 +126,7 @@ def _slice_wrapper(kern, operation, diag=False, derivative=False, psi_stat=False kern._sliced_X -= 1 return ret x_slice_wrapper._operation = operation - x_slice_wrapper.__name__ = ("slicer("+operation.__name__ + x_slice_wrapper.__name__ = ("slicer("+str(operation) +(","+str(bool(diag)) if diag else'') +(','+str(bool(derivative)) if derivative else '') +')') diff --git a/GPy/testing/observable_tests.py b/GPy/testing/observable_tests.py index f8be4a48..6d463a91 100644 --- a/GPy/testing/observable_tests.py +++ b/GPy/testing/observable_tests.py @@ -46,7 +46,7 @@ class Test(unittest.TestCase): self._second = None def _trigger(self, which): - self._observer_triggered = float(which) + self._observer_triggered = which self._trigger_count += 1 if self._first is not None: self._second = self._trigger @@ -65,28 +65,28 @@ class Test(unittest.TestCase): self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param') self.p[0,1] = 3 # trigger observers - self.assertEqual(self._observer_triggered, 3, 'observer should have triggered') + self.assertIs(self._observer_triggered, self.p, 'observer should have triggered') self.assertEqual(self._trigger_count, 1, 'observer should have triggered once') self.assertEqual(self.par.params_changed_count, 1, 'params changed once') self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param') self.par.remove_observer(self) - self.p[2,1] = 4 - self.assertEqual(self._observer_triggered, 3, 'observer should not have triggered') + self.p[0,1] = 4 + self.assertIs(self._observer_triggered, self.p, 'observer should not have triggered') self.assertEqual(self._trigger_count, 1, 'observer should have triggered once') self.assertEqual(self.par.params_changed_count, 2, 'params changed second') self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param') self.par.add_observer(self, self._trigger, -1) - self.p[2,1] = 4 - self.assertEqual(self._observer_triggered, 4, 'observer should have triggered') + self.p[0,1] = 4 + self.assertIs(self._observer_triggered, self.p, 'observer should have triggered') self.assertEqual(self._trigger_count, 2, 'observer should have triggered once') self.assertEqual(self.par.params_changed_count, 3, 'params changed second') self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param') self.par.remove_observer(self, self._trigger) self.p[0,1] = 3 - self.assertEqual(self._observer_triggered, 4, 'observer should not have triggered') + self.assertIs(self._observer_triggered, self.p, 'observer should not have triggered') self.assertEqual(self._trigger_count, 2, 'observer should have triggered once') self.assertEqual(self.par.params_changed_count, 4, 'params changed second') self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param') diff --git a/GPy/util/caching.py b/GPy/util/caching.py index ea09292a..0b8039d6 100644 --- a/GPy/util/caching.py +++ b/GPy/util/caching.py @@ -76,9 +76,8 @@ class Cacher(object): [a.add_observer(self, self.on_cache_changed) for a in observable_args] return self.cached_outputs[-1]#return except: - raise - finally: self.reset() + raise def on_cache_changed(self, arg): """ @@ -98,6 +97,32 @@ class Cacher(object): self.cached_outputs = [] self.inputs_changed = [] + @property + def __name__(self): + return self.operation.__name__ + +from functools import wraps, partial + +class Cacher_wrap(object): + def __init__(self, f, limit, ignore_args, force_kwargs): + self.limit = limit + self.ignore_args = ignore_args + self.force_kwargs = force_kwargs + self.f = f + def __get__(self, obj, objtype=None): + return partial(self, obj) + def __call__(self, *args, **kwargs): + obj = args[0] + try: + caches = obj.__cachers + except AttributeError: + caches = obj.__cachers = {} + try: + cacher = caches[self.f] + except KeyError: + cacher = caches[self.f] = Cacher(self.f, self.limit, self.ignore_args, self.force_kwargs) + return cacher(*args, **kwargs) + class Cache_this(object): """ A decorator which can be applied to bound methods in order to cache them @@ -106,12 +131,5 @@ class Cache_this(object): self.limit = limit self.ignore_args = ignore_args self.force_args = force_kwargs - self.c = None def __call__(self, f): - def f_wrap(*args, **kw): - if self.c is None: - self.c = Cacher(f, self.limit, ignore_args=self.ignore_args, force_kwargs=self.force_args) - return self.c(*args, **kw) - f_wrap._cacher = self - f_wrap.__doc__ = "**cached**" + (f.__doc__ or "") - return f_wrap + return Cacher_wrap(f, self.limit, self.ignore_args, self.force_args)