mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-14 14:32:37 +02:00
psi_stat tests renamed
This commit is contained in:
parent
b7de50b5b3
commit
1f5a7d0053
3 changed files with 33 additions and 21 deletions
|
|
@ -7,6 +7,6 @@ import unittest
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
def deepTest(reason):
|
def deepTest(reason):
|
||||||
if 'deep' in sys.argv:
|
if 'deep' in reason:
|
||||||
return lambda x:x
|
return lambda x:x
|
||||||
return unittest.skip("Not deep scanning, enable deepscan by adding 'deep' argument")
|
return unittest.skip("Not deep scanning, enable deepscan by adding 'deep' argument")
|
||||||
|
|
|
||||||
|
|
@ -6,10 +6,9 @@ Created on 26 Apr 2013
|
||||||
import unittest
|
import unittest
|
||||||
import GPy
|
import GPy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import sys
|
from GPy import testing
|
||||||
from .. import testing
|
|
||||||
|
|
||||||
__test__ = True
|
__test__ = False
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
|
|
||||||
def ard(p):
|
def ard(p):
|
||||||
|
|
@ -20,7 +19,7 @@ def ard(p):
|
||||||
pass
|
pass
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
@testing.deepTest
|
@testing.deepTest(__test__)
|
||||||
class Test(unittest.TestCase):
|
class Test(unittest.TestCase):
|
||||||
D = 9
|
D = 9
|
||||||
M = 4
|
M = 4
|
||||||
|
|
@ -29,13 +28,22 @@ class Test(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.kerns = (
|
self.kerns = (
|
||||||
GPy.kern.rbf(self.D), GPy.kern.rbf(self.D, ARD=True),
|
# (GPy.kern.rbf(self.D, ARD=True) +
|
||||||
GPy.kern.linear(self.D, ARD=False), GPy.kern.linear(self.D, ARD=True),
|
# GPy.kern.linear(self.D, ARD=True) +
|
||||||
GPy.kern.linear(self.D) + GPy.kern.bias(self.D),
|
# GPy.kern.bias(self.D) +
|
||||||
GPy.kern.rbf(self.D) + GPy.kern.bias(self.D),
|
# GPy.kern.white(self.D)),
|
||||||
GPy.kern.linear(self.D) + GPy.kern.bias(self.D) + GPy.kern.white(self.D),
|
(GPy.kern.rbf(self.D, np.random.rand(), np.random.rand(self.D), ARD=True) +
|
||||||
GPy.kern.rbf(self.D) + GPy.kern.bias(self.D) + GPy.kern.white(self.D),
|
GPy.kern.rbf(self.D, np.random.rand(), np.random.rand(self.D), ARD=True) +
|
||||||
GPy.kern.bias(self.D), GPy.kern.white(self.D),
|
GPy.kern.linear(self.D, np.random.rand(self.D), ARD=True) +
|
||||||
|
GPy.kern.bias(self.D) +
|
||||||
|
GPy.kern.white(self.D)),
|
||||||
|
# GPy.kern.rbf(self.D), GPy.kern.rbf(self.D, ARD=True),
|
||||||
|
# GPy.kern.linear(self.D, ARD=False), GPy.kern.linear(self.D, ARD=True),
|
||||||
|
# GPy.kern.linear(self.D) + GPy.kern.bias(self.D),
|
||||||
|
# GPy.kern.rbf(self.D) + GPy.kern.bias(self.D),
|
||||||
|
# GPy.kern.linear(self.D) + GPy.kern.bias(self.D) + GPy.kern.white(self.D),
|
||||||
|
# GPy.kern.rbf(self.D) + GPy.kern.bias(self.D) + GPy.kern.white(self.D),
|
||||||
|
# GPy.kern.bias(self.D), GPy.kern.white(self.D),
|
||||||
)
|
)
|
||||||
self.q_x_mean = np.random.randn(self.D)
|
self.q_x_mean = np.random.randn(self.D)
|
||||||
self.q_x_variance = np.exp(np.random.randn(self.D))
|
self.q_x_variance = np.exp(np.random.randn(self.D))
|
||||||
|
|
@ -64,8 +72,9 @@ class Test(unittest.TestCase):
|
||||||
K_ /= self.Nsamples / Nsamples
|
K_ /= self.Nsamples / Nsamples
|
||||||
msg = "psi1: " + "+".join([p.name + ard(p) for p in kern.parts])
|
msg = "psi1: " + "+".join([p.name + ard(p) for p in kern.parts])
|
||||||
try:
|
try:
|
||||||
# pylab.figure(msg)
|
import pylab
|
||||||
# pylab.plot(diffs)
|
pylab.figure(msg)
|
||||||
|
pylab.plot(diffs)
|
||||||
self.assertTrue(np.allclose(psi1.squeeze(), K_,
|
self.assertTrue(np.allclose(psi1.squeeze(), K_,
|
||||||
rtol=1e-1, atol=.1),
|
rtol=1e-1, atol=.1),
|
||||||
msg=msg + ": not matching")
|
msg=msg + ": not matching")
|
||||||
|
|
@ -90,8 +99,9 @@ class Test(unittest.TestCase):
|
||||||
K_ /= self.Nsamples / Nsamples
|
K_ /= self.Nsamples / Nsamples
|
||||||
msg = "psi2: {}".format("+".join([p.name + ard(p) for p in kern.parts]))
|
msg = "psi2: {}".format("+".join([p.name + ard(p) for p in kern.parts]))
|
||||||
try:
|
try:
|
||||||
# pylab.figure(msg)
|
import pylab
|
||||||
# pylab.plot(diffs)
|
pylab.figure(msg)
|
||||||
|
pylab.plot(diffs)
|
||||||
self.assertTrue(np.allclose(psi2.squeeze(), K_,
|
self.assertTrue(np.allclose(psi2.squeeze(), K_,
|
||||||
rtol=1e-1, atol=.1),
|
rtol=1e-1, atol=.1),
|
||||||
msg=msg + ": not matching")
|
msg=msg + ": not matching")
|
||||||
|
|
@ -104,9 +114,11 @@ class Test(unittest.TestCase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys;sys.argv = ['',
|
import sys
|
||||||
'Test.test_psi0',
|
__test__ = 'deep' in sys.argv
|
||||||
'Test.test_psi1',
|
sys.argv = ['',
|
||||||
'Test.test_psi2',
|
'Test.test_psi0',
|
||||||
]
|
'Test.test_psi1',
|
||||||
|
'Test.test_psi2',
|
||||||
|
]
|
||||||
unittest.main()
|
unittest.main()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue