diff --git a/GPy/testing/variational_tests.py b/GPy/testing/variational_tests.py index 89053b81..cd266f4d 100644 --- a/GPy/testing/variational_tests.py +++ b/GPy/testing/variational_tests.py @@ -1,4 +1,4 @@ -''' +""" Copyright (c) 2015, Max Zwiessele All rights reserved. @@ -26,38 +26,40 @@ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -''' +""" import unittest import GPy, numpy as np -class KLGrad(GPy.core.Model): - def __init__(self, Xvar, kl): - super(KLGrad, self).__init__(name="klgrad") - self.kl = kl - self.link_parameter(Xvar) - self.Xvar = Xvar - self._obj = 0 - def parameters_changed(self): - self.Xvar.gradient[:] = 0 - self.kl.update_gradients_KL(self.Xvar) - self._obj = self.kl.KL_divergence(self.Xvar) - def objective_function(self): - return self._obj - -class Test(unittest.TestCase): +class KLGrad(GPy.core.Model): + def __init__(self, Xvar, kl): + super(KLGrad, self).__init__(name="klgrad") + self.kl = kl + self.link_parameter(Xvar) + self.Xvar = Xvar + self._obj = 0 + + def parameters_changed(self): + self.Xvar.gradient[:] = 0 + self.kl.update_gradients_KL(self.Xvar) + self._obj = self.kl.KL_divergence(self.Xvar) + + def objective_function(self): + return self._obj + + +class Test(unittest.TestCase): def setUp(self): np.random.seed(12345) self.Xvar = GPy.core.parameterization.variational.NormalPosterior( - np.random.uniform(0,1,(10,3)), - np.random.uniform(1e-5,.01, (10,3)) - ) - + np.random.uniform(0, 1, (10, 3)), np.random.uniform(1e-5, 0.01, (10, 3)) + ) def testNormal(self): klgrad = KLGrad(self.Xvar, GPy.core.parameterization.variational.NormalPrior()) np.testing.assert_(klgrad.checkgrad()) + if __name__ == "__main__": - #import sys;sys.argv = ['', 'Test.testNormal'] - unittest.main() \ No newline at end of file + # import sys;sys.argv = ['', 'Test.testNormal'] + unittest.main()