mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-11 15:15:15 +02:00
bias now looks in shape
This commit is contained in:
parent
2da256fa93
commit
ea5d19bb4e
1 changed files with 31 additions and 51 deletions
|
|
@ -2,80 +2,60 @@
|
|||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||
|
||||
|
||||
from kernpart import Kernpart
|
||||
from kern import Kern
|
||||
from ...core.parameterization import Param
|
||||
from ...core.parameterization.transformations import Logexp
|
||||
|
||||
class Bias(Kernpart):
|
||||
class Bias(Kern):
|
||||
def __init__(self,input_dim,variance=1.,name=None):
|
||||
"""
|
||||
:param input_dim: the number of input dimensions
|
||||
:type input_dim: int
|
||||
:param variance: the variance of the kernel
|
||||
:type variance: float
|
||||
"""
|
||||
super(Bias, self).__init__(input_dim, name)
|
||||
from ...core.parameterization.transformations import Logexp
|
||||
self.variance = Param("variance", variance, Logexp())
|
||||
self.add_parameter(self.variance)
|
||||
|
||||
def K(self,X,X2,target):
|
||||
target += self.variance
|
||||
def K(self, X, X2=None):
|
||||
shape = (X.shape[0], X.shape[0] if X2 is None else X2.shape[0])
|
||||
ret = np.empty(shape, dtype=np.float64)
|
||||
ret[:] = self.variance
|
||||
return ret
|
||||
|
||||
def Kdiag(self,X,target):
|
||||
target += self.variance
|
||||
def Kdiag(self,X):
|
||||
ret = np.empty((X.shape[0],), dtype=np.float64)
|
||||
ret[:] = self.variance
|
||||
return ret
|
||||
|
||||
#def dK_dtheta(self,dL_dKdiag,X,X2,target):
|
||||
#target += dL_dKdiag.sum()
|
||||
def update_gradients_full(self, dL_dK, X):
|
||||
def update_gradients_full(self, dL_dK, X, X2=None):
|
||||
self.variance.gradient = dL_dK.sum()
|
||||
|
||||
def dKdiag_dtheta(self,dL_dKdiag,X,target):
|
||||
target += dL_dKdiag.sum()
|
||||
def update_gradients_diag(self, dL_dKdiag, X):
|
||||
self.variance.gradient = dL_dK.sum()
|
||||
|
||||
def gradients_X(self, dL_dK,X, X2, target):
|
||||
pass
|
||||
return np.zeros(X.shape)
|
||||
|
||||
def dKdiag_dX(self,dL_dKdiag,X,target):
|
||||
pass
|
||||
def gradients_X_diag(self,dL_dKdiag,X,target):
|
||||
return np.zeros(X.shape)
|
||||
|
||||
|
||||
#---------------------------------------#
|
||||
# PSI statistics #
|
||||
#---------------------------------------#
|
||||
|
||||
def psi0(self, Z, mu, S, target):
|
||||
target += self.variance
|
||||
def psi0(self, Z, mu, S):
|
||||
return self.Kdiag(mu)
|
||||
|
||||
def psi1(self, Z, mu, S, target):
|
||||
self._psi1 = self.variance
|
||||
target += self._psi1
|
||||
|
||||
return self.K(mu, S)
|
||||
|
||||
def psi2(self, Z, mu, S, target):
|
||||
target += self.variance**2
|
||||
ret = np.empty((mu.shape[0], Z.shape[0], Z.shape[0]), dtype=np.float64)
|
||||
ret[:] = self.variance**2
|
||||
return ret
|
||||
|
||||
def dpsi0_dtheta(self, dL_dpsi0, Z, mu, S, target):
|
||||
target += dL_dpsi0.sum()
|
||||
def update_gradients_variational(self, dL_dKmm, dL_dpsi0, dL_dpsi1, dL_dpsi2, mu, S, Z):
|
||||
self.variance.gradient = dL_dKmm.sum() + dL_dpsi0.sum() + dL_dpsi1.sum() + 2.*self.variance*dL_dpsi2.sum()
|
||||
|
||||
def dpsi1_dtheta(self, dL_dpsi1, Z, mu, S, target):
|
||||
target += dL_dpsi1.sum()
|
||||
def gradients_Z_variational(self, dL_dKmm, dL_dpsi0, dL_dpsi1, dL_dpsi2, mu, S, Z):
|
||||
return np.zeros(Z.shape)
|
||||
|
||||
def dpsi2_dtheta(self, dL_dpsi2, Z, mu, S, target):
|
||||
target += 2.*self.variance*dL_dpsi2.sum()
|
||||
|
||||
def dpsi0_dZ(self, dL_dpsi0, Z, mu, S, target):
|
||||
pass
|
||||
|
||||
def dpsi0_dmuS(self, dL_dpsi0, Z, mu, S, target_mu, target_S):
|
||||
pass
|
||||
|
||||
def dpsi1_dZ(self, dL_dpsi1, Z, mu, S, target):
|
||||
pass
|
||||
|
||||
def dpsi1_dmuS(self, dL_dpsi1, Z, mu, S, target_mu, target_S):
|
||||
pass
|
||||
|
||||
def dpsi2_dZ(self, dL_dpsi2, Z, mu, S, target):
|
||||
pass
|
||||
|
||||
def dpsi2_dmuS(self, dL_dpsi2, Z, mu, S, target_mu, target_S):
|
||||
pass
|
||||
def gradients_muS_variational(self, dL_dKmm, dL_dpsi0, dL_dpsi1, dL_dpsi2, mu, S, Z):
|
||||
return np.zeros(mu.shape), np.zeros(S.shape)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue