mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-05 14:55:15 +02:00
mirgrate inference_tests to pytest
This commit is contained in:
parent
0b92d3a57c
commit
03fcf7311d
1 changed files with 22 additions and 24 deletions
|
|
@ -12,8 +12,8 @@ import GPy
|
||||||
# np.seterr(invalid='raise')
|
# np.seterr(invalid='raise')
|
||||||
|
|
||||||
|
|
||||||
class InferenceXTestCase(unittest.TestCase):
|
class TestInferenceXCase:
|
||||||
def genData(self):
|
def get_data(self):
|
||||||
np.random.seed(1111)
|
np.random.seed(1111)
|
||||||
Ylist = GPy.examples.dimensionality_reduction._simulate_matern(
|
Ylist = GPy.examples.dimensionality_reduction._simulate_matern(
|
||||||
5, 1, 1, 10, 3, False
|
5, 1, 1, 10, 3, False
|
||||||
|
|
@ -21,7 +21,7 @@ class InferenceXTestCase(unittest.TestCase):
|
||||||
return Ylist[0]
|
return Ylist[0]
|
||||||
|
|
||||||
def test_inferenceX_BGPLVM_Linear(self):
|
def test_inferenceX_BGPLVM_Linear(self):
|
||||||
Ys = self.genData()
|
Ys = self.get_data()
|
||||||
m = GPy.models.BayesianGPLVM(Ys, 3, kernel=GPy.kern.Linear(3, ARD=True))
|
m = GPy.models.BayesianGPLVM(Ys, 3, kernel=GPy.kern.Linear(3, ARD=True))
|
||||||
m.optimize()
|
m.optimize()
|
||||||
x, mi = m.infer_newX(m.Y, optimize=True)
|
x, mi = m.infer_newX(m.Y, optimize=True)
|
||||||
|
|
@ -29,34 +29,34 @@ class InferenceXTestCase(unittest.TestCase):
|
||||||
np.testing.assert_array_almost_equal(m.X.variance, mi.X.variance, decimal=2)
|
np.testing.assert_array_almost_equal(m.X.variance, mi.X.variance, decimal=2)
|
||||||
|
|
||||||
def test_inferenceX_BGPLVM_RBF(self):
|
def test_inferenceX_BGPLVM_RBF(self):
|
||||||
Ys = self.genData()
|
Ys = self.get_data()
|
||||||
m = GPy.models.BayesianGPLVM(Ys, 3, kernel=GPy.kern.RBF(3, ARD=True))
|
m = GPy.models.BayesianGPLVM(Ys, 3, kernel=GPy.kern.RBF(3, ARD=True))
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
m.optimize()
|
m.optimize()
|
||||||
x, mi = m.infer_newX(m.Y, optimize=True)
|
_x, mi = m.infer_newX(m.Y, optimize=True)
|
||||||
np.testing.assert_array_almost_equal(m.X.mean, mi.X.mean, decimal=2)
|
np.testing.assert_array_almost_equal(m.X.mean, mi.X.mean, decimal=2)
|
||||||
np.testing.assert_array_almost_equal(m.X.variance, mi.X.variance, decimal=2)
|
np.testing.assert_array_almost_equal(m.X.variance, mi.X.variance, decimal=2)
|
||||||
|
|
||||||
def test_inferenceX_GPLVM_Linear(self):
|
def test_inferenceX_GPLVM_Linear(self):
|
||||||
Ys = self.genData()
|
Ys = self.get_data()
|
||||||
m = GPy.models.GPLVM(Ys, 3, kernel=GPy.kern.Linear(3, ARD=True))
|
m = GPy.models.GPLVM(Ys, 3, kernel=GPy.kern.Linear(3, ARD=True))
|
||||||
m.optimize()
|
m.optimize()
|
||||||
x, mi = m.infer_newX(m.Y, optimize=True)
|
_x, mi = m.infer_newX(m.Y, optimize=True)
|
||||||
np.testing.assert_array_almost_equal(m.X, mi.X, decimal=2)
|
np.testing.assert_array_almost_equal(m.X, mi.X, decimal=2)
|
||||||
|
|
||||||
def test_inferenceX_GPLVM_RBF(self):
|
def test_inferenceX_GPLVM_RBF(self):
|
||||||
Ys = self.genData()
|
Ys = self.get_data()
|
||||||
m = GPy.models.GPLVM(Ys, 3, kernel=GPy.kern.RBF(3, ARD=True))
|
m = GPy.models.GPLVM(Ys, 3, kernel=GPy.kern.RBF(3, ARD=True))
|
||||||
m.optimize()
|
m.optimize()
|
||||||
x, mi = m.infer_newX(m.Y, optimize=True)
|
_x, mi = m.infer_newX(m.Y, optimize=True)
|
||||||
np.testing.assert_array_almost_equal(m.X, mi.X, decimal=2)
|
np.testing.assert_array_almost_equal(m.X, mi.X, decimal=2)
|
||||||
|
|
||||||
|
|
||||||
class InferenceGPEP(unittest.TestCase):
|
class TestInferenceGPEP:
|
||||||
def genData(self):
|
def get_data(self):
|
||||||
np.random.seed(1)
|
np.random.seed(1)
|
||||||
k = GPy.kern.RBF(1, variance=7.0, lengthscale=0.2)
|
k = GPy.kern.RBF(1, variance=7.0, lengthscale=0.2)
|
||||||
X = np.random.rand(200, 1)
|
X = np.random.rand(200, 1)
|
||||||
|
|
@ -64,11 +64,11 @@ class InferenceGPEP(unittest.TestCase):
|
||||||
np.zeros(200), k.K(X) + 1e-5 * np.eye(X.shape[0])
|
np.zeros(200), k.K(X) + 1e-5 * np.eye(X.shape[0])
|
||||||
)
|
)
|
||||||
lik = GPy.likelihoods.Bernoulli()
|
lik = GPy.likelihoods.Bernoulli()
|
||||||
p = lik.gp_link.transf(f) # squash the latent function
|
_p = lik.gp_link.transf(f) # squash the latent function
|
||||||
Y = lik.samples(f).reshape(-1, 1)
|
Y = lik.samples(f).reshape(-1, 1)
|
||||||
return X, Y
|
return X, Y
|
||||||
|
|
||||||
def genNoisyData(self):
|
def get_noisy_data(self):
|
||||||
np.random.seed(1)
|
np.random.seed(1)
|
||||||
X = np.random.rand(100, 1)
|
X = np.random.rand(100, 1)
|
||||||
self.real_std = 0.1
|
self.real_std = 0.1
|
||||||
|
|
@ -83,7 +83,7 @@ class InferenceGPEP(unittest.TestCase):
|
||||||
def test_inference_EP(self):
|
def test_inference_EP(self):
|
||||||
from paramz import ObsAr
|
from paramz import ObsAr
|
||||||
|
|
||||||
X, Y = self.genData()
|
X, Y = self.get_data()
|
||||||
lik = GPy.likelihoods.Bernoulli()
|
lik = GPy.likelihoods.Bernoulli()
|
||||||
k = GPy.kern.RBF(1, variance=7.0, lengthscale=0.2)
|
k = GPy.kern.RBF(1, variance=7.0, lengthscale=0.2)
|
||||||
inf = GPy.inference.latent_function_inference.expectation_propagation.EP(
|
inf = GPy.inference.latent_function_inference.expectation_propagation.EP(
|
||||||
|
|
@ -158,7 +158,7 @@ class InferenceGPEP(unittest.TestCase):
|
||||||
def test_inference_EP_non_classification(self):
|
def test_inference_EP_non_classification(self):
|
||||||
from paramz import ObsAr
|
from paramz import ObsAr
|
||||||
|
|
||||||
X, Y, Y_extra_noisy = self.genNoisyData()
|
X, _Y, Y_extra_noisy = self.get_noisy_data()
|
||||||
deg_freedom = 5.0
|
deg_freedom = 5.0
|
||||||
init_noise_var = 0.08
|
init_noise_var = 0.08
|
||||||
lik_studentT = GPy.likelihoods.StudentT(
|
lik_studentT = GPy.likelihoods.StudentT(
|
||||||
|
|
@ -234,7 +234,7 @@ class InferenceGPEP(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class VarDtcTest(unittest.TestCase):
|
class TestVarDtc:
|
||||||
def test_var_dtc_inference_with_mean(self):
|
def test_var_dtc_inference_with_mean(self):
|
||||||
"""Check dL_dm in var_dtc is calculated correctly"""
|
"""Check dL_dm in var_dtc is calculated correctly"""
|
||||||
np.random.seed(1)
|
np.random.seed(1)
|
||||||
|
|
@ -243,10 +243,10 @@ class VarDtcTest(unittest.TestCase):
|
||||||
m = GPy.models.SparseGPRegression(
|
m = GPy.models.SparseGPRegression(
|
||||||
x, y, mean_function=GPy.mappings.Linear(input_dim=1, output_dim=1)
|
x, y, mean_function=GPy.mappings.Linear(input_dim=1, output_dim=1)
|
||||||
)
|
)
|
||||||
self.assertTrue(m.checkgrad())
|
assert m.checkgrad()
|
||||||
|
|
||||||
|
|
||||||
class HMCSamplerTest(unittest.TestCase):
|
class TestHMCSampler:
|
||||||
def test_sampling(self):
|
def test_sampling(self):
|
||||||
np.random.seed(1)
|
np.random.seed(1)
|
||||||
x = np.linspace(0.0, 2 * np.pi, 100)[:, None]
|
x = np.linspace(0.0, 2 * np.pi, 100)[:, None]
|
||||||
|
|
@ -258,10 +258,11 @@ class HMCSamplerTest(unittest.TestCase):
|
||||||
m.likelihood.variance.set_prior(GPy.priors.Gamma.from_EV(1.0, 10.0))
|
m.likelihood.variance.set_prior(GPy.priors.Gamma.from_EV(1.0, 10.0))
|
||||||
|
|
||||||
hmc = GPy.inference.mcmc.HMC(m, stepsize=1e-2)
|
hmc = GPy.inference.mcmc.HMC(m, stepsize=1e-2)
|
||||||
s = hmc.sample(num_samples=3)
|
_s = hmc.sample(num_samples=3)
|
||||||
|
# TODO: seems like there is no test here?
|
||||||
|
|
||||||
|
|
||||||
class MCMCSamplerTest(unittest.TestCase):
|
class TestMCMCSampler:
|
||||||
def test_sampling(self):
|
def test_sampling(self):
|
||||||
np.random.seed(1)
|
np.random.seed(1)
|
||||||
x = np.linspace(0.0, 2 * np.pi, 100)[:, None]
|
x = np.linspace(0.0, 2 * np.pi, 100)[:, None]
|
||||||
|
|
@ -274,7 +275,4 @@ class MCMCSamplerTest(unittest.TestCase):
|
||||||
|
|
||||||
mcmc = GPy.inference.mcmc.Metropolis_Hastings(m)
|
mcmc = GPy.inference.mcmc.Metropolis_Hastings(m)
|
||||||
mcmc.sample(Ntotal=100, Nburn=10)
|
mcmc.sample(Ntotal=100, Nburn=10)
|
||||||
|
# TODO: seems like there is no test here?
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue