This commit is contained in:
Max Zwiessele 2014-03-14 11:31:16 +00:00
parent 26e6b29071
commit 7143180842
3 changed files with 49 additions and 18 deletions

View file

@ -17,24 +17,33 @@ class Test(unittest.TestCase):
self.param_index.add(one, [3])
self.param_index.add(two, [0,5])
self.param_index.add(three, [2,4,7])
self.view = ParameterIndexOperationsView(self.param_index, 2, 6)
def test_clear(self):
self.param_index.clear()
self.assertDictEqual(self.param_index._properties, {})
def test_remove(self):
self.param_index.remove(three, np.r_[3:10])
self.assertListEqual(self.param_index[three].tolist(), [2])
self.param_index.remove(one, [1])
self.assertListEqual(self.param_index[one].tolist(), [3])
self.assertListEqual(self.param_index[one].tolist(), [3])
self.assertListEqual(self.param_index.remove('not in there', []).tolist(), [])
self.param_index.remove(one, [9])
self.assertListEqual(self.param_index[one].tolist(), [3])
self.assertListEqual(self.param_index.remove('not in there', [2,3,4]).tolist(), [])
def test_shift_left(self):
self.param_index.shift_left(1, 2)
self.view.shift_left(0, 2)
self.assertListEqual(self.param_index[three].tolist(), [2,5])
self.assertListEqual(self.param_index[two].tolist(), [0,3])
self.assertListEqual(self.param_index[one].tolist(), [1])
self.assertListEqual(self.param_index[one].tolist(), [])
def test_shift_right(self):
self.param_index.shift_right(5, 2)
self.view.shift_right(3, 2)
self.assertListEqual(self.param_index[three].tolist(), [2,4,9])
self.assertListEqual(self.param_index[two].tolist(), [0,7])
self.assertListEqual(self.param_index[one].tolist(), [3])
self.assertListEqual(self.param_index[one].tolist(), [3])
def test_index_view(self):
#=======================================================================
@ -44,17 +53,17 @@ class Test(unittest.TestCase):
# three three three
# view: [0 1 2 3 4 5 ]
#=======================================================================
view = ParameterIndexOperationsView(self.param_index, 2, 6)
self.assertSetEqual(set(view.properties()), set([one, two, three]))
for v,p in zip(view.properties_for(np.r_[:6]), self.param_index.properties_for(np.r_[2:2+6])):
self.view = ParameterIndexOperationsView(self.param_index, 2, 6)
self.assertSetEqual(set(self.view.properties()), set([one, two, three]))
for v,p in zip(self.view.properties_for(np.r_[:6]), self.param_index.properties_for(np.r_[2:2+6])):
self.assertEqual(v, p)
self.assertSetEqual(set(view[two]), set([3]))
self.assertSetEqual(set(self.view[two]), set([3]))
self.assertSetEqual(set(self.param_index[two]), set([0, 5]))
view.add(two, np.array([0]))
self.assertSetEqual(set(view[two]), set([0,3]))
self.view.add(two, np.array([0]))
self.assertSetEqual(set(self.view[two]), set([0,3]))
self.assertSetEqual(set(self.param_index[two]), set([0, 2, 5]))
view.clear()
for v,p in zip(view.properties_for(np.r_[:6]), self.param_index.properties_for(np.r_[2:2+6])):
self.view.clear()
for v,p in zip(self.view.properties_for(np.r_[:6]), self.param_index.properties_for(np.r_[2:2+6])):
self.assertEqual(v, p)
self.assertEqual(v, [])
param_index = ParameterIndexOperations()
@ -62,11 +71,17 @@ class Test(unittest.TestCase):
param_index.add(two, [0,5])
param_index.add(three, [2,4,7])
view2 = ParameterIndexOperationsView(param_index, 2, 6)
view.update(view2)
self.view.update(view2)
for [i,v],[i2,v2] in zip(sorted(param_index.items()), sorted(self.param_index.items())):
self.assertEqual(i, i2)
self.assertTrue(np.all(v == v2))
def test_misc(self):
for k,v in self.param_index.copy()._properties.iteritems():
self.assertListEqual(self.param_index[k].tolist(), v.tolist())
self.assertEqual(self.param_index.size, 6)
self.assertEqual(self.view.size, 5)
if __name__ == "__main__":
#import sys;sys.argv = ['', 'Test.test_index_view']
unittest.main()

View file

@ -311,12 +311,12 @@ class KernelTestsNonContinuous(unittest.TestCase):
self.X_block[N:N+N1, D:D+D] = self.X2
self.X_block[0:N, -1] = 1
self.X_block[N:N+1, -1] = 2
self.X_block = self.X_block[self.X_block.argsort(-1)[:, -1], :]
self.X_block = self.X_block[self.X_block.argsort(0)[:, -1], :]
def test_IndependentOutputs(self):
k = GPy.kern.RBF(self.D)
kern = GPy.kern.IndependentOutputs(k, -1)
self.assertTrue(check_kernel_gradient_functions(kern, X=self.X, X2=self.X2, verbose=verbose))
self.assertTrue(check_kernel_gradient_functions(kern, X=self.X_block, X2=self.X_block, verbose=verbose))
if __name__ == "__main__":
print "Running unit tests, please be (very) patient..."

View file

@ -7,8 +7,24 @@ import unittest
import GPy
import numpy as np
from GPy.core.parameterization.parameter_core import HierarchyError
from GPy.core.parameterization.array_core import ObservableArray
class Test(unittest.TestCase):
class ArrayCoreTest(unittest.TestCase):
def setUp(self):
self.X = np.random.normal(1,1, size=(100,10))
self.obsX = ObservableArray(self.X)
def test_init(self):
X = ObservableArray(self.X)
X2 = ObservableArray(X)
self.assertIs(X, X2, "no new Observable array, when Observable is given")
def test_slice(self):
t1 = self.X[2:78]
t2 = self.obsX[2:78]
self.assertListEqual(t1.tolist(), t2.tolist(), "Slicing should be the exact same, as in ndarray")
class ParameterizedTest(unittest.TestCase):
def setUp(self):
self.rbf = GPy.kern.RBF(1)