diff --git a/GPy/testing/pickle_tests.py b/GPy/testing/pickle_tests.py index fd1bf93c..64357b39 100644 --- a/GPy/testing/pickle_tests.py +++ b/GPy/testing/pickle_tests.py @@ -20,6 +20,8 @@ from GPy.examples.dimensionality_reduction import mrd_simulation from GPy.core.parameterization.variational import NormalPosterior from GPy.models.gp_regression import GPRegression from functools import reduce +from GPy.util.caching import Cacher +from pickle import PicklingError def toy_model(): X = np.linspace(0,1,50)[:, None] @@ -205,23 +207,6 @@ class Test(ListDictTestCase): def _callback(self, what, which): what.count += 1 - @unittest.skip - def test_add_observer(self): - par = toy_model() - par.name = "original" - par.count = 0 - par.add_observer(self, self._callback, 1) - pcopy = GPRegression(par.X.copy(), par.Y.copy(), kernel=par.kern.copy()) - self.assertNotIn(par.observers[0], pcopy.observers) - pcopy = par.copy() - pcopy.name = "copy" - self.assertTrue(par.checkgrad()) - self.assertTrue(pcopy.checkgrad()) - self.assertTrue(pcopy.kern.checkgrad()) - import ipdb;ipdb.set_trace() - self.assertIn(par.observers[0], pcopy.observers) - self.assertEqual(par.count, 3) - self.assertEqual(pcopy.count, 6) # 3 of each call to checkgrad if __name__ == "__main__": #import sys;sys.argv = ['', 'Test.test_parameter_index_operations'] diff --git a/GPy/util/caching.py b/GPy/util/caching.py index 196ce343..cfc7e243 100644 --- a/GPy/util/caching.py +++ b/GPy/util/caching.py @@ -3,6 +3,7 @@ from ..core.parameterization.observable import Observable import collections, weakref from functools import reduce +from pickle import PickleError class Cacher(object): def __init__(self, operation, limit=5, ignore_args=(), force_kwargs=()): @@ -149,10 +150,10 @@ class Cacher(object): return Cacher(self.operation, self.limit, self.ignore_args, self.force_kwargs) def __getstate__(self, memo=None): - raise NotImplementedError("Trying to pickle Cacher object with function {}, pickling functions not possible.".format(str(self.operation))) + raise PickleError("Trying to pickle Cacher object with function {}, pickling functions not possible.".format(str(self.operation))) def __setstate__(self, memo=None): - raise NotImplementedError("Trying to pickle Cacher object with function {}, pickling functions not possible.".format(str(self.operation))) + raise PickleError("Trying to pickle Cacher object with function {}, pickling functions not possible.".format(str(self.operation))) @property def __name__(self):