diff --git a/GPy/__init__.py b/GPy/__init__.py index 32f0c1c4..d044b2c0 100644 --- a/GPy/__init__.py +++ b/GPy/__init__.py @@ -40,18 +40,28 @@ def load(file_or_path): :param file_name: path/to/file.pickle """ + # This is the pickling pain when changing _src -> src try: - import cPickle as pickle - if isinstance(file_or_path, basestring): - with open(file_or_path, 'rb') as f: - m = pickle.load(f) - else: - m = pickle.load(file_or_path) - except: - import pickle - if isinstance(file_or_path, str): - with open(file_or_path, 'rb') as f: - m = pickle.load(f) - else: - m = pickle.load(file_or_path) + try: + import cPickle as pickle + if isinstance(file_or_path, basestring): + with open(file_or_path, 'rb') as f: + m = pickle.load(f) + else: + m = pickle.load(file_or_path) + except: + import pickle + if isinstance(file_or_path, str): + with open(file_or_path, 'rb') as f: + m = pickle.load(f) + else: + m = pickle.load(file_or_path) + except ImportError: + import sys + import inspect + sys.modules['GPy.kern._src'] = kern.src + for name, module in inspect.getmembers(kern.src): + if not name.startswith('_'): + sys.modules['GPy.kern._src.{}'.format(name)] = module + m = load(file_or_path) return m diff --git a/GPy/testing/pickle_tests.py b/GPy/testing/pickle_tests.py index 59818de7..575496e1 100644 --- a/GPy/testing/pickle_tests.py +++ b/GPy/testing/pickle_tests.py @@ -42,6 +42,12 @@ class ListDictTestCase(unittest.TestCase): np.testing.assert_array_equal(a1, a2) class Test(ListDictTestCase): + def test_load_pickle(self): + import os + m = GPy.load(os.path.join(os.path.abspath(os.path.split(__file__)[0]), 'pickle_test.pickle')) + self.assertTrue(m.checkgrad()) + self.assertEqual(m.log_likelihood(), -4.7351019830022087) + def test_parameter_index_operations(self): pio = ParameterIndexOperations(dict(test1=np.array([4,3,1,6,4]), test2=np.r_[2:130])) piov = ParameterIndexOperationsView(pio, 20, 250)