mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-04 01:02:39 +02:00
Merge pull request #1108 from olamarre/bugfix/ep-pickling
fix: pickle and deep copy classification models with EP
This commit is contained in:
commit
acdd03d3ed
3 changed files with 40 additions and 17 deletions
|
|
@ -1,6 +1,8 @@
|
||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
## Unreleased
|
## Unreleased
|
||||||
|
* fix pickle and deep copy for classification models inheriting from EP #1108 [olamarre]
|
||||||
|
|
||||||
* update prior `__new__` methods #1098 [MartinBubel]
|
* update prior `__new__` methods #1098 [MartinBubel]
|
||||||
|
|
||||||
* fix invalid escape sequence #1011 [janmayer]
|
* fix invalid escape sequence #1011 [janmayer]
|
||||||
|
|
|
||||||
|
|
@ -229,24 +229,17 @@ class EPBase(object):
|
||||||
v_diff = np.mean(np.square(ga_approx.v-self.ga_approx_old.v))
|
v_diff = np.mean(np.square(ga_approx.v-self.ga_approx_old.v))
|
||||||
return ((tau_diff < self.epsilon) and (v_diff < self.epsilon))
|
return ((tau_diff < self.epsilon) and (v_diff < self.epsilon))
|
||||||
|
|
||||||
def __setstate__(self, state):
|
|
||||||
super(EPBase, self).__setstate__(state[0])
|
|
||||||
self.epsilon, self.eta, self.delta = state[1]
|
|
||||||
self.reset()
|
|
||||||
|
|
||||||
def __getstate__(self):
|
|
||||||
return [super(EPBase, self).__getstate__() , [self.epsilon, self.eta, self.delta]]
|
|
||||||
|
|
||||||
def _save_to_input_dict(self):
|
def _save_to_input_dict(self):
|
||||||
input_dict = super(EPBase, self)._save_to_input_dict()
|
input_dict = {
|
||||||
input_dict["epsilon"]=self.epsilon
|
"epsilon": self.epsilon,
|
||||||
input_dict["eta"]=self.eta
|
"eta": self.eta,
|
||||||
input_dict["delta"]=self.delta
|
"delta": self.delta,
|
||||||
input_dict["always_reset"]=self.always_reset
|
"always_reset": self.always_reset,
|
||||||
input_dict["max_iters"]=self.max_iters
|
"max_iters": self.max_iters,
|
||||||
input_dict["ep_mode"]=self.ep_mode
|
"ep_mode": self.ep_mode,
|
||||||
input_dict["parallel_updates"]=self.parallel_updates
|
"parallel_updates": self.parallel_updates,
|
||||||
input_dict["loading"]=True
|
"loading": True
|
||||||
|
}
|
||||||
return input_dict
|
return input_dict
|
||||||
|
|
||||||
class EP(EPBase, ExactGaussianInference):
|
class EP(EPBase, ExactGaussianInference):
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@
|
||||||
The test cases for various inference algorithms
|
The test cases for various inference algorithms
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import pickle
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import GPy
|
import GPy
|
||||||
|
|
||||||
|
|
@ -146,6 +148,32 @@ class TestInferenceGPEP:
|
||||||
< 1e6
|
< 1e6
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_pickle_copy_EP(self):
|
||||||
|
"""Pickling and deep-copying a classification model employing EP"""
|
||||||
|
|
||||||
|
# Dummy binary classification dataset
|
||||||
|
X = np.array([0, 1, 2, 3]).reshape(-1, 1)
|
||||||
|
Y = np.array([0, 0, 1, 1]).reshape(-1, 1)
|
||||||
|
|
||||||
|
# Some classification model
|
||||||
|
inf = GPy.inference.latent_function_inference.expectation_propagation.EP(
|
||||||
|
max_iters=30, delta=0.5
|
||||||
|
)
|
||||||
|
m = GPy.core.GP(
|
||||||
|
X=X,
|
||||||
|
Y=Y,
|
||||||
|
kernel=GPy.kern.RBF(input_dim=1, variance=1.0, lengthscale=1.0),
|
||||||
|
inference_method = inf,
|
||||||
|
likelihood=GPy.likelihoods.Bernoulli(),
|
||||||
|
mean_function=None
|
||||||
|
)
|
||||||
|
m.optimize()
|
||||||
|
|
||||||
|
m_pickled = pickle.dumps(m)
|
||||||
|
assert pickle.loads(m_pickled) is not None
|
||||||
|
|
||||||
|
assert copy.deepcopy(m) is not None
|
||||||
|
|
||||||
# NOTE: adding a test like above for parameterized likelihood- the above test is
|
# NOTE: adding a test like above for parameterized likelihood- the above test is
|
||||||
# only for probit likelihood which does not have any tunable hyperparameter which is why
|
# only for probit likelihood which does not have any tunable hyperparameter which is why
|
||||||
# the term in dictionary of gradients: dL_dthetaL will always be zero. So here we repeat tests for
|
# the term in dictionary of gradients: dL_dthetaL will always be zero. So here we repeat tests for
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue