diff --git a/GPy/kern/_src/psi_comp/linear_psi_comp.py b/GPy/kern/_src/psi_comp/linear_psi_comp.py index f3262d3b..87ceaf66 100644 --- a/GPy/kern/_src/psi_comp/linear_psi_comp.py +++ b/GPy/kern/_src/psi_comp/linear_psi_comp.py @@ -21,11 +21,12 @@ def psicomputations(variance, Z, variational_posterior, return_psi2_n=False): S = variational_posterior.variance psi0 = (variance*(np.square(mu)+S)).sum(axis=1) - psi1 = np.dot(mu,(variance*Z).T) - if not return_psi2_n: - psi2 = np.dot(S.sum(axis=0)*np.square(variance)*Z,Z.T)+ tdot(psi1.T) + Zv = variance * Z + psi1 = np.dot(mu,Zv.T) + if return_psi2_n: + psi2 = psi1[:,:,None] * psi1[:,None,:] + np.dot(S[:,None,:] * Zv[None,:,:], Zv.T) else: - raise NotImplementedError + psi2 = np.dot(S.sum(axis=0) * Zv, Zv.T) + tdot(psi1.T) return psi0, psi1, psi2