mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-08 15:05:15 +02:00
[pickling] have the pickling test against a model, which is now being shipped with the distro
This commit is contained in:
parent
29921e1c69
commit
b2b88ae8b8
2 changed files with 29 additions and 13 deletions
|
|
@ -40,18 +40,28 @@ def load(file_or_path):
|
|||
|
||||
:param file_name: path/to/file.pickle
|
||||
"""
|
||||
# This is the pickling pain when changing _src -> src
|
||||
try:
|
||||
import cPickle as pickle
|
||||
if isinstance(file_or_path, basestring):
|
||||
with open(file_or_path, 'rb') as f:
|
||||
m = pickle.load(f)
|
||||
else:
|
||||
m = pickle.load(file_or_path)
|
||||
except:
|
||||
import pickle
|
||||
if isinstance(file_or_path, str):
|
||||
with open(file_or_path, 'rb') as f:
|
||||
m = pickle.load(f)
|
||||
else:
|
||||
m = pickle.load(file_or_path)
|
||||
try:
|
||||
import cPickle as pickle
|
||||
if isinstance(file_or_path, basestring):
|
||||
with open(file_or_path, 'rb') as f:
|
||||
m = pickle.load(f)
|
||||
else:
|
||||
m = pickle.load(file_or_path)
|
||||
except:
|
||||
import pickle
|
||||
if isinstance(file_or_path, str):
|
||||
with open(file_or_path, 'rb') as f:
|
||||
m = pickle.load(f)
|
||||
else:
|
||||
m = pickle.load(file_or_path)
|
||||
except ImportError:
|
||||
import sys
|
||||
import inspect
|
||||
sys.modules['GPy.kern._src'] = kern.src
|
||||
for name, module in inspect.getmembers(kern.src):
|
||||
if not name.startswith('_'):
|
||||
sys.modules['GPy.kern._src.{}'.format(name)] = module
|
||||
m = load(file_or_path)
|
||||
return m
|
||||
|
|
|
|||
|
|
@ -42,6 +42,12 @@ class ListDictTestCase(unittest.TestCase):
|
|||
np.testing.assert_array_equal(a1, a2)
|
||||
|
||||
class Test(ListDictTestCase):
|
||||
def test_load_pickle(self):
|
||||
import os
|
||||
m = GPy.load(os.path.join(os.path.abspath(os.path.split(__file__)[0]), 'pickle_test.pickle'))
|
||||
self.assertTrue(m.checkgrad())
|
||||
self.assertEqual(m.log_likelihood(), -4.7351019830022087)
|
||||
|
||||
def test_parameter_index_operations(self):
|
||||
pio = ParameterIndexOperations(dict(test1=np.array([4,3,1,6,4]), test2=np.r_[2:130]))
|
||||
piov = ParameterIndexOperationsView(pio, 20, 250)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue