migrate variational_tests to pytest

This commit is contained in:
Martin Bubel 2023-10-10 20:07:46 +02:00
parent ae5e53a5b9
commit 5bab8ca976

View file

@ -27,7 +27,6 @@ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
""" """
import unittest
import GPy, numpy as np import GPy, numpy as np
@ -48,18 +47,14 @@ class KLGrad(GPy.core.Model):
return self._obj return self._obj
class Test(unittest.TestCase): class TestVariational:
def setUp(self): def setup(self):
np.random.seed(12345) np.random.seed(12345)
self.Xvar = GPy.core.parameterization.variational.NormalPosterior( self.Xvar = GPy.core.parameterization.variational.NormalPosterior(
np.random.uniform(0, 1, (10, 3)), np.random.uniform(1e-5, 0.01, (10, 3)) np.random.uniform(0, 1, (10, 3)), np.random.uniform(1e-5, 0.01, (10, 3))
) )
def testNormal(self): def test_normal(self):
self.setup()
klgrad = KLGrad(self.Xvar, GPy.core.parameterization.variational.NormalPrior()) klgrad = KLGrad(self.Xvar, GPy.core.parameterization.variational.NormalPrior())
np.testing.assert_(klgrad.checkgrad()) np.testing.assert_(klgrad.checkgrad())
if __name__ == "__main__":
# import sys;sys.argv = ['', 'Test.testNormal']
unittest.main()