[pickling] pickle error

This commit is contained in:
Max Zwiessele 2015-09-04 16:23:30 +01:00
parent 55d714f78c
commit 8d61fe632b
2 changed files with 5 additions and 19 deletions

View file

@ -20,6 +20,8 @@ from GPy.examples.dimensionality_reduction import mrd_simulation
from GPy.core.parameterization.variational import NormalPosterior from GPy.core.parameterization.variational import NormalPosterior
from GPy.models.gp_regression import GPRegression from GPy.models.gp_regression import GPRegression
from functools import reduce from functools import reduce
from GPy.util.caching import Cacher
from pickle import PicklingError
def toy_model(): def toy_model():
X = np.linspace(0,1,50)[:, None] X = np.linspace(0,1,50)[:, None]
@ -205,23 +207,6 @@ class Test(ListDictTestCase):
def _callback(self, what, which): def _callback(self, what, which):
what.count += 1 what.count += 1
@unittest.skip
def test_add_observer(self):
par = toy_model()
par.name = "original"
par.count = 0
par.add_observer(self, self._callback, 1)
pcopy = GPRegression(par.X.copy(), par.Y.copy(), kernel=par.kern.copy())
self.assertNotIn(par.observers[0], pcopy.observers)
pcopy = par.copy()
pcopy.name = "copy"
self.assertTrue(par.checkgrad())
self.assertTrue(pcopy.checkgrad())
self.assertTrue(pcopy.kern.checkgrad())
import ipdb;ipdb.set_trace()
self.assertIn(par.observers[0], pcopy.observers)
self.assertEqual(par.count, 3)
self.assertEqual(pcopy.count, 6) # 3 of each call to checkgrad
if __name__ == "__main__": if __name__ == "__main__":
#import sys;sys.argv = ['', 'Test.test_parameter_index_operations'] #import sys;sys.argv = ['', 'Test.test_parameter_index_operations']

View file

@ -3,6 +3,7 @@
from ..core.parameterization.observable import Observable from ..core.parameterization.observable import Observable
import collections, weakref import collections, weakref
from functools import reduce from functools import reduce
from pickle import PickleError
class Cacher(object): class Cacher(object):
def __init__(self, operation, limit=5, ignore_args=(), force_kwargs=()): def __init__(self, operation, limit=5, ignore_args=(), force_kwargs=()):
@ -149,10 +150,10 @@ class Cacher(object):
return Cacher(self.operation, self.limit, self.ignore_args, self.force_kwargs) return Cacher(self.operation, self.limit, self.ignore_args, self.force_kwargs)
def __getstate__(self, memo=None): def __getstate__(self, memo=None):
raise NotImplementedError("Trying to pickle Cacher object with function {}, pickling functions not possible.".format(str(self.operation))) raise PickleError("Trying to pickle Cacher object with function {}, pickling functions not possible.".format(str(self.operation)))
def __setstate__(self, memo=None): def __setstate__(self, memo=None):
raise NotImplementedError("Trying to pickle Cacher object with function {}, pickling functions not possible.".format(str(self.operation))) raise PickleError("Trying to pickle Cacher object with function {}, pickling functions not possible.".format(str(self.operation)))
@property @property
def __name__(self): def __name__(self):