diff --git a/GPy/kern/rbf.py b/GPy/kern/rbf.py index d473a5cc..b774db1c 100644 --- a/GPy/kern/rbf.py +++ b/GPy/kern/rbf.py @@ -217,8 +217,8 @@ class rbf(kernpart): #psi2 self._psi2_denom = 2.*S[:,None,None,:]/self.lengthscale2+1. # N,M,M,Q - self._psi2_mudist = mu[:,None,None,:]-self._psi2_Zhat #N,M,M,Q - self._psi2_mudist_sq, self._psi2_exponent, _ = self.weave_stuff() + self._psi2_mudist, self._psi2_mudist_sq, self._psi2_exponent, _ = self.weave_psi2(mu,self._psi2_Zhat) + #self._psi2_mudist = mu[:,None,None,:]-self._psi2_Zhat #N,M,M,Q #self._psi2_mudist_sq = np.square(self._psi2_mudist)/(self.lengthscale2*self._psi2_denom) #self._psi2_exponent = np.sum(-self._psi2_Zdist_sq/4. -self._psi2_mudist_sq -0.5*np.log(self._psi2_denom),-1) #N,M,M self._psi2 = np.square(self.variance)*np.exp(self._psi2_exponent) # N,M,M @@ -226,14 +226,18 @@ class rbf(kernpart): #store matrices for caching self._Z, self._mu, self._S = Z, mu,S - def weave_psi2(self): + def weave_psi2(self,mu,Zhat): weave_options = {'extra_compile_args': ['-O3']} - N,M,M,Q = self._psi2_mudist.shape - mudist = self._psi2_mudist - psi2_Zdist_sq = self._psi2_Zdist_sq + + N,Q = mu.shape + M = Zhat.shape[0] + + mudist = np.empty((N,M,M,Q)) mudist_sq = np.empty((N,M,M,Q)) psi2_exponent = np.zeros((N,M,M)) psi2 = np.empty((N,M,M)) + + psi2_Zdist_sq = self._psi2_Zdist_sq half_log_psi2_denom = 0.5*np.log(self._psi2_denom).squeeze() variance_sq = float(np.square(self.variance)) if self.ARD: @@ -247,14 +251,23 @@ class rbf(kernpart): for (int m=0; m