psi_stat tests renamed

This commit is contained in:
Max Zwiessele 2013-05-22 12:39:49 +01:00
parent b7de50b5b3
commit 1f5a7d0053
3 changed files with 33 additions and 21 deletions

View file

@ -7,6 +7,6 @@ import unittest
import sys
def deepTest(reason):
if 'deep' in sys.argv:
if 'deep' in reason:
return lambda x:x
return unittest.skip("Not deep scanning, enable deepscan by adding 'deep' argument")

View file

@ -6,10 +6,9 @@ Created on 26 Apr 2013
import unittest
import GPy
import numpy as np
import sys
from .. import testing
from GPy import testing
__test__ = True
__test__ = False
np.random.seed(0)
def ard(p):
@ -20,7 +19,7 @@ def ard(p):
pass
return ""
@testing.deepTest
@testing.deepTest(__test__)
class Test(unittest.TestCase):
D = 9
M = 4
@ -29,13 +28,22 @@ class Test(unittest.TestCase):
def setUp(self):
self.kerns = (
GPy.kern.rbf(self.D), GPy.kern.rbf(self.D, ARD=True),
GPy.kern.linear(self.D, ARD=False), GPy.kern.linear(self.D, ARD=True),
GPy.kern.linear(self.D) + GPy.kern.bias(self.D),
GPy.kern.rbf(self.D) + GPy.kern.bias(self.D),
GPy.kern.linear(self.D) + GPy.kern.bias(self.D) + GPy.kern.white(self.D),
GPy.kern.rbf(self.D) + GPy.kern.bias(self.D) + GPy.kern.white(self.D),
GPy.kern.bias(self.D), GPy.kern.white(self.D),
# (GPy.kern.rbf(self.D, ARD=True) +
# GPy.kern.linear(self.D, ARD=True) +
# GPy.kern.bias(self.D) +
# GPy.kern.white(self.D)),
(GPy.kern.rbf(self.D, np.random.rand(), np.random.rand(self.D), ARD=True) +
GPy.kern.rbf(self.D, np.random.rand(), np.random.rand(self.D), ARD=True) +
GPy.kern.linear(self.D, np.random.rand(self.D), ARD=True) +
GPy.kern.bias(self.D) +
GPy.kern.white(self.D)),
# GPy.kern.rbf(self.D), GPy.kern.rbf(self.D, ARD=True),
# GPy.kern.linear(self.D, ARD=False), GPy.kern.linear(self.D, ARD=True),
# GPy.kern.linear(self.D) + GPy.kern.bias(self.D),
# GPy.kern.rbf(self.D) + GPy.kern.bias(self.D),
# GPy.kern.linear(self.D) + GPy.kern.bias(self.D) + GPy.kern.white(self.D),
# GPy.kern.rbf(self.D) + GPy.kern.bias(self.D) + GPy.kern.white(self.D),
# GPy.kern.bias(self.D), GPy.kern.white(self.D),
)
self.q_x_mean = np.random.randn(self.D)
self.q_x_variance = np.exp(np.random.randn(self.D))
@ -64,8 +72,9 @@ class Test(unittest.TestCase):
K_ /= self.Nsamples / Nsamples
msg = "psi1: " + "+".join([p.name + ard(p) for p in kern.parts])
try:
# pylab.figure(msg)
# pylab.plot(diffs)
import pylab
pylab.figure(msg)
pylab.plot(diffs)
self.assertTrue(np.allclose(psi1.squeeze(), K_,
rtol=1e-1, atol=.1),
msg=msg + ": not matching")
@ -90,8 +99,9 @@ class Test(unittest.TestCase):
K_ /= self.Nsamples / Nsamples
msg = "psi2: {}".format("+".join([p.name + ard(p) for p in kern.parts]))
try:
# pylab.figure(msg)
# pylab.plot(diffs)
import pylab
pylab.figure(msg)
pylab.plot(diffs)
self.assertTrue(np.allclose(psi2.squeeze(), K_,
rtol=1e-1, atol=.1),
msg=msg + ": not matching")
@ -104,9 +114,11 @@ class Test(unittest.TestCase):
pass
if __name__ == "__main__":
import sys;sys.argv = ['',
'Test.test_psi0',
'Test.test_psi1',
'Test.test_psi2',
]
import sys
__test__ = 'deep' in sys.argv
sys.argv = ['',
'Test.test_psi0',
'Test.test_psi1',
'Test.test_psi2',
]
unittest.main()