first attempt

This commit is contained in:
Daniel Beck 2015-02-06 19:39:46 +11:00
parent 98f632e92e
commit 1d2cbfe44a
2 changed files with 28 additions and 2 deletions

View file

@ -5,6 +5,7 @@ import numpy as np
from kern import CombinationKernel from kern import CombinationKernel
from ...util.caching import Cache_this from ...util.caching import Cache_this
import itertools import itertools
import operator
class Prod(CombinationKernel): class Prod(CombinationKernel):
""" """
@ -42,9 +43,19 @@ class Prod(CombinationKernel):
return reduce(np.multiply, (p.Kdiag(X) for p in which_parts)) return reduce(np.multiply, (p.Kdiag(X) for p in which_parts))
def update_gradients_full(self, dL_dK, X, X2=None): def update_gradients_full(self, dL_dK, X, X2=None):
np.seterr(invalid='raise')
k = self.K(X,X2)*dL_dK k = self.K(X,X2)*dL_dK
for p in self.parts: try:
p.update_gradients_full(k/p.K(X,X2),X,X2) for p in self.parts:
p.update_gradients_full(k/p.K(X,X2),X,X2)
except FloatingPointError:
np.seterr(invalid='warn')
print "Gradient warning: falling back to slow version due to zero-valued kernel"
for combination in itertools.combinations(self.parts, len(self.parts) - 1):
prod = reduce(operator.mul, [p.K(X, X2) for p in combination])
to_update = list(set(self.parts) - set(combination))[0]
to_update.update_gradients_full(dL_dK * prod, X, X2)
def update_gradients_diag(self, dL_dKdiag, X): def update_gradients_diag(self, dL_dKdiag, X):
k = self.Kdiag(X)*dL_dKdiag k = self.Kdiag(X)*dL_dKdiag

View file

@ -400,12 +400,27 @@ class Coregionalize_weave_test(unittest.TestCase):
#reset the weave state for any other tests #reset the weave state for any other tests
GPy.util.config.config.set('weave', 'working', 'False') GPy.util.config.config.set('weave', 'working', 'False')
class KernelTestsProductWithZeroValues(unittest.TestCase):
def test_zero_valued_kernel(self):
X = np.array([[0,1],[1,0]])
Y = np.array([[1],[10]])
lin = GPy.kern.Linear(2)
bias = GPy.kern.Bias(2)
k = lin * bias
#k = lin
m = GPy.models.GPRegression(X, Y, kernel=k)
#m['mul.bias.variance'].constrain_fixed(0)
m.optimize(messages=False)
if __name__ == "__main__": if __name__ == "__main__":
print "Running unit tests, please be (very) patient..." print "Running unit tests, please be (very) patient..."
unittest.main() unittest.main()
#suite = unittest.TestLoader().loadTestsFromTestCase(KernelTestsProductWithZeroValues)
#unittest.TextTestRunner().run(suite)
# np.random.seed(0) # np.random.seed(0)
# N0 = 3 # N0 = 3
# N1 = 9 # N1 = 9