format on save

This commit is contained in:
Martin Bubel 2023-10-10 20:07:05 +02:00
parent e412745861
commit ae5e53a5b9

View file

@ -1,4 +1,4 @@
'''
"""
Copyright (c) 2015, Max Zwiessele
All rights reserved.
@ -26,38 +26,40 @@ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
"""
import unittest
import GPy, numpy as np
class KLGrad(GPy.core.Model):
def __init__(self, Xvar, kl):
super(KLGrad, self).__init__(name="klgrad")
self.kl = kl
self.link_parameter(Xvar)
self.Xvar = Xvar
self._obj = 0
def parameters_changed(self):
self.Xvar.gradient[:] = 0
self.kl.update_gradients_KL(self.Xvar)
self._obj = self.kl.KL_divergence(self.Xvar)
def objective_function(self):
return self._obj
class Test(unittest.TestCase):
class KLGrad(GPy.core.Model):
def __init__(self, Xvar, kl):
super(KLGrad, self).__init__(name="klgrad")
self.kl = kl
self.link_parameter(Xvar)
self.Xvar = Xvar
self._obj = 0
def parameters_changed(self):
self.Xvar.gradient[:] = 0
self.kl.update_gradients_KL(self.Xvar)
self._obj = self.kl.KL_divergence(self.Xvar)
def objective_function(self):
return self._obj
class Test(unittest.TestCase):
def setUp(self):
np.random.seed(12345)
self.Xvar = GPy.core.parameterization.variational.NormalPosterior(
np.random.uniform(0,1,(10,3)),
np.random.uniform(1e-5,.01, (10,3))
)
np.random.uniform(0, 1, (10, 3)), np.random.uniform(1e-5, 0.01, (10, 3))
)
def testNormal(self):
klgrad = KLGrad(self.Xvar, GPy.core.parameterization.variational.NormalPrior())
np.testing.assert_(klgrad.checkgrad())
if __name__ == "__main__":
#import sys;sys.argv = ['', 'Test.testNormal']
unittest.main()
# import sys;sys.argv = ['', 'Test.testNormal']
unittest.main()