mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-11 15:15:15 +02:00
migrate gp_tests to pytest
This commit is contained in:
parent
9263f572eb
commit
d0c65eaa28
1 changed files with 12 additions and 9 deletions
|
|
@ -3,13 +3,13 @@ Created on 4 Sep 2015
|
|||
|
||||
@author: maxz
|
||||
"""
|
||||
import unittest
|
||||
import numpy as np, GPy
|
||||
import numpy as np
|
||||
import GPy
|
||||
from GPy.core.parameterization.variational import NormalPosterior
|
||||
|
||||
|
||||
class Test(unittest.TestCase):
|
||||
def setUp(self):
|
||||
class TestGP:
|
||||
def setup(self):
|
||||
np.random.seed(12345)
|
||||
self.N = 20
|
||||
self.N_new = 50
|
||||
|
|
@ -19,6 +19,8 @@ class Test(unittest.TestCase):
|
|||
self.X_new = np.random.uniform(-3.0, 3.0, (self.N_new, 1))
|
||||
|
||||
def test_setxy_bgplvm(self):
|
||||
self.setup()
|
||||
|
||||
k = GPy.kern.RBF(1)
|
||||
m = GPy.models.BayesianGPLVM(self.Y, 1, kernel=k)
|
||||
mu, var = m.predict(m.X)
|
||||
|
|
@ -36,6 +38,8 @@ class Test(unittest.TestCase):
|
|||
np.testing.assert_allclose(var, var2)
|
||||
|
||||
def test_setxy_gplvm(self):
|
||||
self.setup()
|
||||
|
||||
k = GPy.kern.RBF(1)
|
||||
m = GPy.models.GPLVM(self.Y, 1, kernel=k)
|
||||
mu, var = m.predict(m.X)
|
||||
|
|
@ -53,6 +57,8 @@ class Test(unittest.TestCase):
|
|||
np.testing.assert_allclose(var, var2)
|
||||
|
||||
def test_setxy_gp(self):
|
||||
self.setup()
|
||||
|
||||
k = GPy.kern.RBF(1)
|
||||
m = GPy.models.GPRegression(self.X, self.Y, kernel=k)
|
||||
mu, var = m.predict(m.X)
|
||||
|
|
@ -72,6 +78,8 @@ class Test(unittest.TestCase):
|
|||
from GPy.core.parameterization.param import Param
|
||||
from GPy.core.mapping import Mapping
|
||||
|
||||
self.setup()
|
||||
|
||||
class Parabola(Mapping):
|
||||
def __init__(self, variance, degree=2, name="parabola"):
|
||||
super(Parabola, self).__init__(1, 1, name)
|
||||
|
|
@ -111,8 +119,3 @@ class Test(unittest.TestCase):
|
|||
m.randomize()
|
||||
assert m.checkgrad()
|
||||
_ = m.predict(m.X)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# import sys;sys.argv = ['', 'Test.testName']
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue