mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-04-27 13:56:23 +02:00
[linear] einsums
This commit is contained in:
parent
22de3ab676
commit
3972b4bd9a
1 changed files with 9 additions and 6 deletions
|
|
@ -51,7 +51,7 @@ class Linear(Kern):
|
|||
self.variances = Param('variances', variances, Logexp())
|
||||
self.add_parameter(self.variances)
|
||||
self.psicomp = PSICOMP_Linear()
|
||||
|
||||
|
||||
@Cache_this(limit=2)
|
||||
def K(self, X, X2=None):
|
||||
if self.ARD:
|
||||
|
|
@ -76,10 +76,12 @@ class Linear(Kern):
|
|||
def update_gradients_full(self, dL_dK, X, X2=None):
|
||||
if self.ARD:
|
||||
if X2 is None:
|
||||
self.variances.gradient = np.array([np.sum(dL_dK * tdot(X[:, i:i + 1])) for i in range(self.input_dim)])
|
||||
#self.variances.gradient = np.array([np.sum(dL_dK * tdot(X[:, i:i + 1])) for i in range(self.input_dim)])
|
||||
self.variances.gradient = np.einsum('ij,iq,jq->q', dL_dK, X, X)
|
||||
else:
|
||||
product = X[:, None, :] * X2[None, :, :]
|
||||
self.variances.gradient = (dL_dK[:, :, None] * product).sum(0).sum(0)
|
||||
#product = X[:, None, :] * X2[None, :, :]
|
||||
#self.variances.gradient = (dL_dK[:, :, None] * product).sum(0).sum(0)
|
||||
self.variances.gradient = np.einsum('ij,iq,jq->q', dL_dK, X, X2)
|
||||
else:
|
||||
self.variances.gradient = np.sum(self._dot_product(X, X2) * dL_dK)
|
||||
|
||||
|
|
@ -93,9 +95,10 @@ class Linear(Kern):
|
|||
|
||||
def gradients_X(self, dL_dK, X, X2=None):
|
||||
if X2 is None:
|
||||
return np.einsum('mq,nm->nq',X*self.variances,dL_dK)+np.einsum('nq,nm->mq',X*self.variances,dL_dK)
|
||||
return np.einsum('jq,q,ij->iq', X, 2*self.variances, dL_dK)
|
||||
else:
|
||||
return (((X2[None,:, :] * self.variances)) * dL_dK[:, :, None]).sum(1)
|
||||
#return (((X2[None,:, :] * self.variances)) * dL_dK[:, :, None]).sum(1)
|
||||
return np.einsum('jq,q,ij->iq', X2, self.variances, dL_dK)
|
||||
|
||||
def gradients_X_diag(self, dL_dKdiag, X):
|
||||
return 2.*self.variances*dL_dKdiag[:,None]*X
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue