mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 12:02:38 +02:00
some simplification of the psi2_statistics in rbf
This commit is contained in:
parent
a03b0f1179
commit
a42ec441e4
1 changed files with 6 additions and 6 deletions
|
|
@ -153,7 +153,7 @@ class rbf(kernpart):
|
|||
"""Shape N,M,M,Ntheta"""
|
||||
self._psi_computations(Z,mu,S)
|
||||
d_var = 2.*self._psi2/self.variance
|
||||
d_length = self._psi2[:,:,:,None]*(0.5*self._psi2_Zdist_sq*self._psi2_denom + 2.*self._psi2_mudist_sq + 2.*S[:,None,None,:]/self.lengthscale2)/(self.lengthscale*self._psi2_denom)
|
||||
d_length = 2.*self._psi2[:,:,:,None]*(self._psi2_Zdist_sq*self._psi2_denom + self._psi2_mudist_sq + S[:,None,None,:]/self.lengthscale2)/(self.lengthscale*self._psi2_denom)
|
||||
|
||||
target[0] += np.sum(dL_dpsi2*d_var)
|
||||
dpsi2_dlength = d_length*dL_dpsi2[:,:,:,None]
|
||||
|
|
@ -164,7 +164,7 @@ class rbf(kernpart):
|
|||
|
||||
def dpsi2_dZ(self,dL_dpsi2,Z,mu,S,target):
|
||||
self._psi_computations(Z,mu,S)
|
||||
term1 = 0.5*self._psi2_Zdist/self.lengthscale2 # M, M, Q
|
||||
term1 = self._psi2_Zdist/self.lengthscale2 # M, M, Q
|
||||
term2 = self._psi2_mudist/self._psi2_denom/self.lengthscale2 # N, M, M, Q
|
||||
dZ = self._psi2[:,:,:,None] * (term1[None] + term2)
|
||||
target += (dL_dpsi2[:,:,:,None]*dZ).sum(0).sum(0)
|
||||
|
|
@ -200,8 +200,8 @@ class rbf(kernpart):
|
|||
if not np.all(Z==self._Z):
|
||||
#Z has changed, compute Z specific stuff
|
||||
self._psi2_Zhat = 0.5*(Z[:,None,:] +Z[None,:,:]) # M,M,Q
|
||||
self._psi2_Zdist = Z[:,None,:]-Z[None,:,:] # M,M,Q
|
||||
self._psi2_Zdist_sq = np.square(self._psi2_Zdist)/self.lengthscale2 # M,M,Q
|
||||
self._psi2_Zdist = 0.5*(Z[:,None,:]-Z[None,:,:]) # M,M,Q
|
||||
self._psi2_Zdist_sq = np.square(self._psi2_Zdist/self.lengthscale) # M,M,Q
|
||||
self._Z = Z
|
||||
|
||||
if not (np.all(Z==self._Z) and np.all(mu==self._mu) and np.all(S==self._S)):
|
||||
|
|
@ -219,7 +219,7 @@ class rbf(kernpart):
|
|||
self._psi2_mudist, self._psi2_mudist_sq, self._psi2_exponent, _ = self.weave_psi2(mu,self._psi2_Zhat)
|
||||
#self._psi2_mudist = mu[:,None,None,:]-self._psi2_Zhat #N,M,M,Q
|
||||
#self._psi2_mudist_sq = np.square(self._psi2_mudist)/(self.lengthscale2*self._psi2_denom)
|
||||
#self._psi2_exponent = np.sum(-self._psi2_Zdist_sq/4. -self._psi2_mudist_sq -0.5*np.log(self._psi2_denom),-1) #N,M,M
|
||||
#self._psi2_exponent = np.sum(-self._psi2_Zdist_sq -self._psi2_mudist_sq -0.5*np.log(self._psi2_denom),-1) #N,M,M
|
||||
self._psi2 = np.square(self.variance)*np.exp(self._psi2_exponent) # N,M,M
|
||||
|
||||
#store matrices for caching
|
||||
|
|
@ -266,7 +266,7 @@ class rbf(kernpart):
|
|||
mudist_sq(n,mm,m,q) = tmp;
|
||||
|
||||
//now psi2_exponent
|
||||
tmp = -psi2_Zdist_sq(m,mm,q)/4.0 - tmp - half_log_psi2_denom(n,q);
|
||||
tmp = -psi2_Zdist_sq(m,mm,q) - tmp - half_log_psi2_denom(n,q);
|
||||
psi2_exponent(n,mm,m) += tmp;
|
||||
if (m !=mm){
|
||||
psi2_exponent(n,m,mm) += tmp;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue