Ensure numpy version is used in coregionalize cython test

This commit is contained in:
Jayanth Koushik 2018-02-15 21:26:32 -05:00
parent 64c125573e
commit 08801c5554

View file

@ -636,7 +636,14 @@ class Coregionalize_cython_test(unittest.TestCase):
grads_cython = self.k.gradient.copy()
K_numpy = self.k._K_numpy(self.X)
# Nasty hack to ensure the numpy version is used for update_gradients
# If this test is running, cython is working, so override the cython
# function with the numpy function
_gradient_reduce_cython = self.k._gradient_reduce_cython
self.k._gradient_reduce_cython = self.k._gradient_reduce_numpy
self.k.update_gradients_full(dL_dK, self.X)
# Undo hack
self.k._gradient_reduce_cython = _gradient_reduce_cython
grads_numpy = self.k.gradient.copy()
self.assertTrue(np.allclose(K_numpy, K_cython))
@ -651,7 +658,12 @@ class Coregionalize_cython_test(unittest.TestCase):
K_numpy = self.k._K_numpy(self.X, self.X2)
self.k.gradient = 0.
# Same hack as in test_sym (Line 639)
_gradient_reduce_cython = self.k._gradient_reduce_cython
self.k._gradient_reduce_cython = self.k._gradient_reduce_numpy
self.k.update_gradients_full(dL_dK, self.X, self.X2)
# Undo hack
self.k._gradient_reduce_cython = _gradient_reduce_cython
grads_numpy = self.k.gradient.copy()
self.assertTrue(np.allclose(K_numpy, K_cython))