[kernel] fix #218 and #325

This commit is contained in:
mzwiessele 2016-03-10 10:21:17 +00:00
parent af76126ef1
commit 30c6fc90ff
3 changed files with 55 additions and 30 deletions

View file

@ -344,11 +344,14 @@ class KernelTestsMiscellaneous(unittest.TestCase):
N, D = 100, 10
self.X = np.linspace(-np.pi, +np.pi, N)[:,None] * np.random.uniform(-10,10,D)
self.rbf = GPy.kern.RBF(2, active_dims=np.arange(0,4,2))
self.rbf.randomize()
self.linear = GPy.kern.Linear(2, active_dims=(3,9))
self.linear.randomize()
self.matern = GPy.kern.Matern32(3, active_dims=np.array([1,7,9]))
self.matern.randomize()
self.sumkern = self.rbf + self.linear
self.sumkern += self.matern
self.sumkern.randomize()
#self.sumkern.randomize()
def test_which_parts(self):
self.assertTrue(np.allclose(self.sumkern.K(self.X, which_parts=[self.linear, self.matern]), self.linear.K(self.X)+self.matern.K(self.X)))
@ -358,6 +361,21 @@ class KernelTestsMiscellaneous(unittest.TestCase):
def test_active_dims(self):
np.testing.assert_array_equal(self.sumkern.active_dims, [0,1,2,3,7,9])
np.testing.assert_array_equal(self.sumkern._all_dims_active, range(10))
tmp = self.linear+self.rbf
np.testing.assert_array_equal(tmp.active_dims, [0,2,3,9])
np.testing.assert_array_equal(tmp._all_dims_active, range(10))
tmp = self.matern+self.rbf
np.testing.assert_array_equal(tmp.active_dims, [0,1,2,7,9])
np.testing.assert_array_equal(tmp._all_dims_active, range(10))
tmp = self.matern+self.rbf*self.linear
np.testing.assert_array_equal(tmp.active_dims, [0,1,2,3,7,9])
np.testing.assert_array_equal(tmp._all_dims_active, range(10))
tmp = self.matern+self.rbf+self.linear
np.testing.assert_array_equal(tmp.active_dims, [0,1,2,3,7,9])
np.testing.assert_array_equal(tmp._all_dims_active, range(10))
tmp = self.matern*self.rbf*self.linear
np.testing.assert_array_equal(tmp.active_dims, [0,1,2,3,7,9])
np.testing.assert_array_equal(tmp._all_dims_active, range(10))
class KernelTestsNonContinuous(unittest.TestCase):
def setUp(self):