fix: pickle and deep copy classification models with EP

- Removed unnecessary code preventing pickling & copying
- Wrote a new test to check that pickling and copying works
This commit is contained in:
Olivier Lamarre 2025-01-03 23:33:00 -05:00
parent 9a31886085
commit ad9c507cde
2 changed files with 38 additions and 17 deletions

View file

@ -229,24 +229,17 @@ class EPBase(object):
v_diff = np.mean(np.square(ga_approx.v-self.ga_approx_old.v))
return ((tau_diff < self.epsilon) and (v_diff < self.epsilon))
def __setstate__(self, state):
    """Restore from the two-part state list produced by __getstate__.

    state[0] is the parent class's pickled state; state[1] holds the
    EP convergence parameters (epsilon, eta, delta).
    """
    parent_state, ep_params = state
    super(EPBase, self).__setstate__(parent_state)
    self.epsilon, self.eta, self.delta = ep_params
    # Rebuild the transient EP approximation state, which is not serialized.
    self.reset()
def __getstate__(self):
    """Return picklable state: [parent state, [epsilon, eta, delta]]."""
    parent_state = super(EPBase, self).__getstate__()
    ep_params = [self.epsilon, self.eta, self.delta]
    return [parent_state, ep_params]
def _save_to_input_dict(self):
input_dict = super(EPBase, self)._save_to_input_dict()
input_dict["epsilon"]=self.epsilon
input_dict["eta"]=self.eta
input_dict["delta"]=self.delta
input_dict["always_reset"]=self.always_reset
input_dict["max_iters"]=self.max_iters
input_dict["ep_mode"]=self.ep_mode
input_dict["parallel_updates"]=self.parallel_updates
input_dict["loading"]=True
input_dict = {
"epsilon": self.epsilon,
"eta": self.eta,
"delta": self.delta,
"always_reset": self.always_reset,
"max_iters": self.max_iters,
"ep_mode": self.ep_mode,
"parallel_updates": self.parallel_updates,
"loading": True
}
return input_dict
class EP(EPBase, ExactGaussianInference):

View file

@ -5,6 +5,8 @@
The test cases for various inference algorithms
"""
import copy
import pickle
import numpy as np
import GPy
@ -146,6 +148,32 @@ class TestInferenceGPEP:
< 1e6
)
def test_pickle_copy_EP(self):
    """Pickling and deep-copying a classification model employing EP"""
    # Tiny binary classification problem
    inputs = np.array([0, 1, 2, 3]).reshape(-1, 1)
    labels = np.array([0, 0, 1, 1]).reshape(-1, 1)
    # Build a Bernoulli-likelihood GP classifier using EP inference
    ep_inference = GPy.inference.latent_function_inference.expectation_propagation.EP(
        max_iters=30, delta=0.5
    )
    model = GPy.core.GP(
        X=inputs,
        Y=labels,
        kernel=GPy.kern.RBF(input_dim=1, variance=1.0, lengthscale=1.0),
        inference_method=ep_inference,
        likelihood=GPy.likelihoods.Bernoulli(),
        mean_function=None,
    )
    model.optimize()
    # Both serialization round-trips must succeed without raising
    restored = pickle.loads(pickle.dumps(model))
    assert restored is not None
    duplicated = copy.deepcopy(model)
    assert duplicated is not None
# NOTE: adding a test like above for parameterized likelihood- the above test is
# only for probit likelihood which does not have any tunable hyperparameter which is why
# the term in dictionary of gradients: dL_dthetaL will always be zero. So here we repeat tests for