mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-07 11:02:38 +02:00
first attempt
This commit is contained in:
parent
98f632e92e
commit
1d2cbfe44a
2 changed files with 28 additions and 2 deletions
|
|
@ -5,6 +5,7 @@ import numpy as np
|
|||
from kern import CombinationKernel
|
||||
from ...util.caching import Cache_this
|
||||
import itertools
|
||||
import operator
|
||||
|
||||
class Prod(CombinationKernel):
|
||||
"""
|
||||
|
|
@ -42,9 +43,19 @@ class Prod(CombinationKernel):
|
|||
return reduce(np.multiply, (p.Kdiag(X) for p in which_parts))
|
||||
|
||||
def update_gradients_full(self, dL_dK, X, X2=None):
|
||||
np.seterr(invalid='raise')
|
||||
k = self.K(X,X2)*dL_dK
|
||||
for p in self.parts:
|
||||
p.update_gradients_full(k/p.K(X,X2),X,X2)
|
||||
try:
|
||||
for p in self.parts:
|
||||
p.update_gradients_full(k/p.K(X,X2),X,X2)
|
||||
except FloatingPointError:
|
||||
np.seterr(invalid='warn')
|
||||
print "Gradient warning: falling back to slow version due to zero-valued kernel"
|
||||
for combination in itertools.combinations(self.parts, len(self.parts) - 1):
|
||||
prod = reduce(operator.mul, [p.K(X, X2) for p in combination])
|
||||
to_update = list(set(self.parts) - set(combination))[0]
|
||||
to_update.update_gradients_full(dL_dK * prod, X, X2)
|
||||
|
||||
|
||||
def update_gradients_diag(self, dL_dKdiag, X):
|
||||
k = self.Kdiag(X)*dL_dKdiag
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue