test + code change in gradients_X

This commit is contained in:
Daniel Beck 2015-02-09 10:02:26 +11:00
parent fc8705104b
commit 98c743d157
2 changed files with 23 additions and 14 deletions

View file

@ -14,8 +14,9 @@ def numpy_invalid_op_as_exception(func):
"""
def func_wrapper(*args, **kwargs):
np.seterr(invalid='raise')
func(*args, **kwargs)
result = func(*args, **kwargs)
np.seterr(invalid='warn')
return result
return func_wrapper
@ -66,17 +67,23 @@ class Prod(CombinationKernel):
to_update = list(set(self.parts) - set(combination))[0]
to_update.update_gradients_full(dL_dK * prod, X, X2)
def update_gradients_diag(self, dL_dKdiag, X):
k = self.Kdiag(X)*dL_dKdiag
for p in self.parts:
p.update_gradients_diag(k/p.Kdiag(X),X)
@numpy_invalid_op_as_exception
def gradients_X(self, dL_dK, X, X2=None):
target = np.zeros(X.shape)
k = self.K(X,X2)*dL_dK
for p in self.parts:
target += p.gradients_X(k/p.K(X,X2),X,X2)
try:
for p in self.parts:
target += p.gradients_X(k/p.K(X,X2),X,X2)
except FloatingPointError:
for combination in itertools.combinations(self.parts, len(self.parts) - 1):
prod = reduce(np.multiply, [p.K(X, X2) for p in combination])
to_update = list(set(self.parts) - set(combination))[0]
target += to_update.gradients_X(dL_dK * prod, X, X2)
return target
def gradients_X_diag(self, dL_dKdiag, X):