format on save

This commit is contained in:
Martin Bubel 2023-10-10 20:07:05 +02:00
parent e412745861
commit ae5e53a5b9

View file

@ -1,4 +1,4 @@
''' """
Copyright (c) 2015, Max Zwiessele Copyright (c) 2015, Max Zwiessele
All rights reserved. All rights reserved.
@ -26,38 +26,40 @@ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
''' """
import unittest import unittest
import GPy, numpy as np import GPy, numpy as np
class KLGrad(GPy.core.Model):
def __init__(self, Xvar, kl):
super(KLGrad, self).__init__(name="klgrad")
self.kl = kl
self.link_parameter(Xvar)
self.Xvar = Xvar
self._obj = 0
def parameters_changed(self):
self.Xvar.gradient[:] = 0
self.kl.update_gradients_KL(self.Xvar)
self._obj = self.kl.KL_divergence(self.Xvar)
def objective_function(self):
return self._obj
class Test(unittest.TestCase):
class KLGrad(GPy.core.Model):
def __init__(self, Xvar, kl):
super(KLGrad, self).__init__(name="klgrad")
self.kl = kl
self.link_parameter(Xvar)
self.Xvar = Xvar
self._obj = 0
def parameters_changed(self):
self.Xvar.gradient[:] = 0
self.kl.update_gradients_KL(self.Xvar)
self._obj = self.kl.KL_divergence(self.Xvar)
def objective_function(self):
return self._obj
class Test(unittest.TestCase):
def setUp(self): def setUp(self):
np.random.seed(12345) np.random.seed(12345)
self.Xvar = GPy.core.parameterization.variational.NormalPosterior( self.Xvar = GPy.core.parameterization.variational.NormalPosterior(
np.random.uniform(0,1,(10,3)), np.random.uniform(0, 1, (10, 3)), np.random.uniform(1e-5, 0.01, (10, 3))
np.random.uniform(1e-5,.01, (10,3)) )
)
def testNormal(self): def testNormal(self):
klgrad = KLGrad(self.Xvar, GPy.core.parameterization.variational.NormalPrior()) klgrad = KLGrad(self.Xvar, GPy.core.parameterization.variational.NormalPrior())
np.testing.assert_(klgrad.checkgrad()) np.testing.assert_(klgrad.checkgrad())
if __name__ == "__main__": if __name__ == "__main__":
#import sys;sys.argv = ['', 'Test.testNormal'] # import sys;sys.argv = ['', 'Test.testNormal']
unittest.main() unittest.main()