From 5bab8ca97667e0ebbc416f7f87ad75fc1edbcf27 Mon Sep 17 00:00:00 2001 From: Martin Bubel Date: Tue, 10 Oct 2023 20:07:46 +0200 Subject: [PATCH] migrate variational_tests to pytest --- GPy/testing/variational_tests.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/GPy/testing/variational_tests.py b/GPy/testing/variational_tests.py index cd266f4d..33197d03 100644 --- a/GPy/testing/variational_tests.py +++ b/GPy/testing/variational_tests.py @@ -27,7 +27,6 @@ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ -import unittest import GPy, numpy as np @@ -48,18 +47,14 @@ class KLGrad(GPy.core.Model): return self._obj -class Test(unittest.TestCase): - def setUp(self): +class TestVariational: + def setup(self): np.random.seed(12345) self.Xvar = GPy.core.parameterization.variational.NormalPosterior( np.random.uniform(0, 1, (10, 3)), np.random.uniform(1e-5, 0.01, (10, 3)) ) - def testNormal(self): + def test_normal(self): + self.setup() klgrad = KLGrad(self.Xvar, GPy.core.parameterization.variational.NormalPrior()) np.testing.assert_(klgrad.checkgrad()) - - -if __name__ == "__main__": - # import sys;sys.argv = ['', 'Test.testNormal'] - unittest.main()