mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-04-24 20:36:23 +02:00
fix: pickle and deep copy classification models with EP
- Removed unecessary code preventing pickling & copying - Wrote a new test to check that pickling and copying works
This commit is contained in:
parent
9a31886085
commit
ad9c507cde
2 changed files with 38 additions and 17 deletions
|
|
@ -229,24 +229,17 @@ class EPBase(object):
|
|||
v_diff = np.mean(np.square(ga_approx.v-self.ga_approx_old.v))
|
||||
return ((tau_diff < self.epsilon) and (v_diff < self.epsilon))
|
||||
|
||||
def __setstate__(self, state):
|
||||
super(EPBase, self).__setstate__(state[0])
|
||||
self.epsilon, self.eta, self.delta = state[1]
|
||||
self.reset()
|
||||
|
||||
def __getstate__(self):
|
||||
return [super(EPBase, self).__getstate__() , [self.epsilon, self.eta, self.delta]]
|
||||
|
||||
def _save_to_input_dict(self):
|
||||
input_dict = super(EPBase, self)._save_to_input_dict()
|
||||
input_dict["epsilon"]=self.epsilon
|
||||
input_dict["eta"]=self.eta
|
||||
input_dict["delta"]=self.delta
|
||||
input_dict["always_reset"]=self.always_reset
|
||||
input_dict["max_iters"]=self.max_iters
|
||||
input_dict["ep_mode"]=self.ep_mode
|
||||
input_dict["parallel_updates"]=self.parallel_updates
|
||||
input_dict["loading"]=True
|
||||
input_dict = {
|
||||
"epsilon": self.epsilon,
|
||||
"eta": self.eta,
|
||||
"delta": self.delta,
|
||||
"always_reset": self.always_reset,
|
||||
"max_iters": self.max_iters,
|
||||
"ep_mode": self.ep_mode,
|
||||
"parallel_updates": self.parallel_updates,
|
||||
"loading": True
|
||||
}
|
||||
return input_dict
|
||||
|
||||
class EP(EPBase, ExactGaussianInference):
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
The test cases for various inference algorithms
|
||||
"""
|
||||
|
||||
import copy
|
||||
import pickle
|
||||
import numpy as np
|
||||
import GPy
|
||||
|
||||
|
|
@ -146,6 +148,32 @@ class TestInferenceGPEP:
|
|||
< 1e6
|
||||
)
|
||||
|
||||
def test_pickle_copy_EP(self):
|
||||
"""Pickling and deep-copying a classification model employing EP"""
|
||||
|
||||
# Dummy binary classification dataset
|
||||
X = np.array([0, 1, 2, 3]).reshape(-1, 1)
|
||||
Y = np.array([0, 0, 1, 1]).reshape(-1, 1)
|
||||
|
||||
# Some classification model
|
||||
inf = GPy.inference.latent_function_inference.expectation_propagation.EP(
|
||||
max_iters=30, delta=0.5
|
||||
)
|
||||
m = GPy.core.GP(
|
||||
X=X,
|
||||
Y=Y,
|
||||
kernel=GPy.kern.RBF(input_dim=1, variance=1.0, lengthscale=1.0),
|
||||
inference_method = inf,
|
||||
likelihood=GPy.likelihoods.Bernoulli(),
|
||||
mean_function=None
|
||||
)
|
||||
m.optimize()
|
||||
|
||||
m_pickled = pickle.dumps(m)
|
||||
assert pickle.loads(m_pickled) is not None
|
||||
|
||||
assert copy.deepcopy(m) is not None
|
||||
|
||||
# NOTE: adding a test like above for parameterized likelihood- the above test is
|
||||
# only for probit likelihood which does not have any tunable hyperparameter which is why
|
||||
# the term in dictionary of gradients: dL_dthetaL will always be zero. So here we repeat tests for
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue