Merge pull request #1108 from olamarre/bugfix/ep-pickling

fix: pickle and deep copy classification models with EP
This commit is contained in:
Martin Bubel 2025-01-15 19:22:47 +01:00 committed by GitHub
commit acdd03d3ed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 40 additions and 17 deletions

View file

@@ -1,6 +1,8 @@
# Changelog
## Unreleased
* fix pickle and deep copy for classification models inheriting from EP #1108 [olamarre]
* update prior `__new__` methods #1098 [MartinBubel]
* fix invalid escape sequence #1011 [janmayer]

View file

@@ -229,24 +229,17 @@ class EPBase(object):
v_diff = np.mean(np.square(ga_approx.v-self.ga_approx_old.v))
return ((tau_diff < self.epsilon) and (v_diff < self.epsilon))
def __setstate__(self, state):
    """Restore the EP inference object from its pickled state.

    ``state`` is the two-element sequence produced by ``__getstate__``:
    the parent-class state followed by the EP-specific parameters
    ``[epsilon, eta, delta]``.
    """
    parent_state, ep_params = state
    super(EPBase, self).__setstate__(parent_state)
    self.epsilon, self.eta, self.delta = ep_params
    # Cached EP approximations are not serialized; rebuild them fresh.
    self.reset()
def __getstate__(self):
    """Return picklable state: ``[parent state, [epsilon, eta, delta]]``."""
    ep_params = [self.epsilon, self.eta, self.delta]
    return [super(EPBase, self).__getstate__(), ep_params]
 def _save_to_input_dict(self):
-    input_dict = super(EPBase, self)._save_to_input_dict()
-    input_dict["epsilon"]=self.epsilon
-    input_dict["eta"]=self.eta
-    input_dict["delta"]=self.delta
-    input_dict["always_reset"]=self.always_reset
-    input_dict["max_iters"]=self.max_iters
-    input_dict["ep_mode"]=self.ep_mode
-    input_dict["parallel_updates"]=self.parallel_updates
-    input_dict["loading"]=True
+    input_dict = {
+        "epsilon": self.epsilon,
+        "eta": self.eta,
+        "delta": self.delta,
+        "always_reset": self.always_reset,
+        "max_iters": self.max_iters,
+        "ep_mode": self.ep_mode,
+        "parallel_updates": self.parallel_updates,
+        "loading": True
+    }
     return input_dict
 class EP(EPBase, ExactGaussianInference):

View file

@@ -5,6 +5,8 @@
The test cases for various inference algorithms
"""
import copy
import pickle
import numpy as np
import GPy
@@ -146,6 +148,32 @@ class TestInferenceGPEP:
< 1e6
)
def test_pickle_copy_EP(self):
    """Pickling and deep-copying a classification model employing EP.

    Regression test for #1108: models whose inference method inherits
    from ``EPBase`` used to fail on ``pickle`` / ``copy.deepcopy``.
    """
    # Dummy binary classification dataset.
    X = np.array([0, 1, 2, 3]).reshape(-1, 1)
    Y = np.array([0, 0, 1, 1]).reshape(-1, 1)
    # A small Bernoulli classification model using EP inference.
    inf = GPy.inference.latent_function_inference.expectation_propagation.EP(
        max_iters=30, delta=0.5
    )
    m = GPy.core.GP(
        X=X,
        Y=Y,
        kernel=GPy.kern.RBF(input_dim=1, variance=1.0, lengthscale=1.0),
        inference_method=inf,
        likelihood=GPy.likelihoods.Bernoulli(),
        mean_function=None,
    )
    m.optimize()
    # Round-trip through pickle: must succeed AND preserve the optimized
    # hyperparameters, not merely return something non-None.
    m_restored = pickle.loads(pickle.dumps(m))
    assert m_restored is not None
    assert np.allclose(m_restored.param_array, m.param_array)
    # Deep copy must likewise reproduce the model parameters.
    m_copy = copy.deepcopy(m)
    assert m_copy is not None
    assert np.allclose(m_copy.param_array, m.param_array)
# NOTE: adding a test like above for parameterized likelihood- the above test is
# only for probit likelihood which does not have any tunable hyperparameter which is why
# the term in dictionary of gradients: dL_dthetaL will always be zero. So here we repeat tests for