diff --git a/GPy/testing/index_operations_tests.py b/GPy/testing/index_operations_tests.py index 49cc844a..738f92b4 100644 --- a/GPy/testing/index_operations_tests.py +++ b/GPy/testing/index_operations_tests.py @@ -14,9 +14,9 @@ class Test(unittest.TestCase): def setUp(self): self.param_index = ParameterIndexOperations() - self.param_index.add(one, [3]) + self.param_index.add(one, [3,9]) self.param_index.add(two, [0,5]) - self.param_index.add(three, [2,4,7]) + self.param_index.add(three, [2,4,7,10]) self.view = ParameterIndexOperationsView(self.param_index, 2, 6) def test_clear(self): @@ -24,37 +24,61 @@ class Test(unittest.TestCase): self.assertDictEqual(self.param_index._properties, {}) def test_remove(self): - removed = self.param_index.remove(three, np.r_[3:10]) - self.assertListEqual(removed.tolist(), [4, 7]) + removed = self.param_index.remove(three, np.r_[3:13]) + self.assertListEqual(removed.tolist(), [4,7,10]) self.assertListEqual(self.param_index[three].tolist(), [2]) removed = self.param_index.remove(one, [1]) self.assertListEqual(removed.tolist(), []) - self.assertListEqual(self.param_index[one].tolist(), [3]) + self.assertListEqual(self.param_index[one].tolist(), [3,9]) self.assertListEqual(self.param_index.remove('not in there', []).tolist(), []) removed = self.param_index.remove(one, [9]) + self.assertListEqual(removed.tolist(), [9]) self.assertListEqual(self.param_index[one].tolist(), [3]) self.assertListEqual(self.param_index.remove('not in there', [2,3,4]).tolist(), []) - self.assertListEqual(self.view.remove('not in there', [2,3,4]).tolist(), []) def test_shift_left(self): self.view.shift_left(0, 2) - self.assertListEqual(self.param_index[three].tolist(), [2,5]) + self.assertListEqual(self.param_index[three].tolist(), [2,5,8]) + self.assertListEqual(self.param_index[two].tolist(), [0,3]) + self.assertListEqual(self.param_index[one].tolist(), [7]) + #======================================================================= + # 0 1 2 3 4 5 6 7 8 9 10 + # one + # two two + # three three three + # view: [0 1 2 3 4 5 ] + #======================================================================= + self.assertListEqual(self.view[three].tolist(), [0,3]) + self.assertListEqual(self.view[two].tolist(), [1]) + self.assertListEqual(self.view[one].tolist(), [5]) + self.param_index.shift_left(7, 1) + #======================================================================= + # 0 1 2 3 4 5 6 7 8 9 10 + # + # two two + # three three three + # view: [0 1 2 3 4 5 ] + #======================================================================= + self.assertListEqual(self.param_index[three].tolist(), [2,5,7]) self.assertListEqual(self.param_index[two].tolist(), [0,3]) self.assertListEqual(self.param_index[one].tolist(), []) + self.assertListEqual(self.view[three].tolist(), [0,3,5]) + self.assertListEqual(self.view[two].tolist(), [1]) + self.assertListEqual(self.view[one].tolist(), []) def test_shift_right(self): self.view.shift_right(3, 2) - self.assertListEqual(self.param_index[three].tolist(), [2,4,9]) + self.assertListEqual(self.param_index[three].tolist(), [2,4,9,12]) self.assertListEqual(self.param_index[two].tolist(), [0,7]) - self.assertListEqual(self.param_index[one].tolist(), [3]) + self.assertListEqual(self.param_index[one].tolist(), [3,11]) def test_index_view(self): #======================================================================= - # 0 1 2 3 4 5 6 7 8 9 - # one + # 0 1 2 3 4 5 6 7 8 9 10 + # one one # two two - # three three three + # three three three three # view: [0 1 2 3 4 5 ] #======================================================================= self.view = ParameterIndexOperationsView(self.param_index, 2, 6) @@ -71,26 +95,37 @@ class Test(unittest.TestCase): self.assertEqual(v, p) self.assertEqual(v, []) param_index = ParameterIndexOperations() - param_index.add(one, [3]) + param_index.add(one, [3,9]) param_index.add(two, [0,5]) - param_index.add(three, [2,4,7]) - view2 = ParameterIndexOperationsView(param_index, 2, 6) + param_index.add(three, [2,4,7,10]) + view2 = ParameterIndexOperationsView(param_index, 2, 8) self.view.update(view2) for [i,v],[i2,v2] in zip(sorted(param_index.items()), sorted(self.param_index.items())): self.assertEqual(i, i2) - self.assertTrue(np.all(v == v2)) + np.testing.assert_equal(v, v2) + + def test_view_of_view(self): + #======================================================================= + # 0 1 2 3 4 5 6 7 8 9 10 + # one one + # two two + # three three three three + # view: [0 1 2 3 4 5 ] + # view2: [0 1 2 3 4 5 ] + #======================================================================= + view2 = ParameterIndexOperationsView(self.view, 2, 6) + view2.shift_right(0, 2) def test_indexview_remove(self): removed = self.view.remove(two, [3]) - self.assertListEqual(removed.tolist(), [3]) + self.assertListEqual(removed.tolist(), [3]) removed = self.view.remove(three, np.r_[:5]) - self.assertListEqual(removed.tolist(), [0, 2]) - + self.assertListEqual(removed.tolist(), [0, 2]) def test_misc(self): for k,v in self.param_index.copy()._properties.iteritems(): self.assertListEqual(self.param_index[k].tolist(), v.tolist()) - self.assertEqual(self.param_index.size, 6) + self.assertEqual(self.param_index.size, 8) self.assertEqual(self.view.size, 5) def test_print(self):