automatic slicing

This commit is contained in:
Max Zwiessele 2014-03-11 16:24:09 +00:00
parent e078bb47e1
commit 01f5d789c5
3 changed files with 72 additions and 144 deletions

View file

@ -56,28 +56,28 @@ class RBF(Stationary):
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
_, _dpsi1_dvariance, _, _, _, _, _dpsi1_dlengthscale = ssrbf_psi_comp._psi1computations(self.variance, self.lengthscale, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)
_, _dpsi2_dvariance, _, _, _, _, _dpsi2_dlengthscale = ssrbf_psi_comp._psi2computations(self.variance, self.lengthscale, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)
#contributions from psi0:
self.variance.gradient = np.sum(dL_dpsi0)
#from psi1
self.variance.gradient += np.sum(dL_dpsi1 * _dpsi1_dvariance)
if self.ARD:
self.lengthscale.gradient = (dL_dpsi1[:,:,None]*_dpsi1_dlengthscale).reshape(-1,self.input_dim).sum(axis=0)
else:
self.lengthscale.gradient = (dL_dpsi1[:,:,None]*_dpsi1_dlengthscale).sum()
#from psi2
self.variance.gradient += (dL_dpsi2 * _dpsi2_dvariance).sum()
if self.ARD:
self.lengthscale.gradient += (dL_dpsi2[:,:,:,None] * _dpsi2_dlengthscale).reshape(-1,self.input_dim).sum(axis=0)
else:
self.lengthscale.gradient += (dL_dpsi2[:,:,:,None] * _dpsi2_dlengthscale).sum()
elif isinstance(variational_posterior, variational.NormalPosterior):
l2 = self.lengthscale **2
l2 = self.lengthscale**2
if l2.size != self.input_dim:
l2 = l2*np.ones(self.input_dim)
#contributions from psi0:
self.variance.gradient = np.sum(dL_dpsi0)
@ -92,11 +92,9 @@ class RBF(Stationary):
else:
self.lengthscale.gradient += dpsi1_dlength.sum()
self.variance.gradient += np.sum(dL_dpsi1 * psi1) / self.variance
#from psi2
S = variational_posterior.variance
_, Zdist_sq, _, mudist_sq, psi2 = self._psi2computations(Z, variational_posterior)
if not self.ARD:
self.lengthscale.gradient += self._weave_psi2_lengthscale_grads(dL_dpsi2, psi2, Zdist_sq, S, mudist_sq, l2).sum()
else:
@ -112,17 +110,16 @@ class RBF(Stationary):
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
_, _, _, _, _, _dpsi1_dZ, _ = ssrbf_psi_comp._psi1computations(self.variance, self.lengthscale, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)
_, _, _, _, _, _dpsi2_dZ, _ = ssrbf_psi_comp._psi2computations(self.variance, self.lengthscale, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)
#psi1
grad = (dL_dpsi1[:, :, None] * _dpsi1_dZ).sum(axis=0)
#psi2
grad += (dL_dpsi2[:, :, :, None] * _dpsi2_dZ).sum(axis=0).sum(axis=1)
return grad
elif isinstance(variational_posterior, variational.NormalPosterior):
l2 = self.lengthscale **2
#psi1
@ -145,10 +142,10 @@ class RBF(Stationary):
# Spike-and-Slab GPLVM
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
ndata = variational_posterior.mean.shape[0]
_, _, _dpsi1_dgamma, _dpsi1_dmu, _dpsi1_dS, _, _ = ssrbf_psi_comp._psi1computations(self.variance, self.lengthscale, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)
_, _, _dpsi2_dgamma, _dpsi2_dmu, _dpsi2_dS, _, _ = ssrbf_psi_comp._psi2computations(self.variance, self.lengthscale, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)
#psi1
grad_mu = (dL_dpsi1[:, :, None] * _dpsi1_dmu).sum(axis=1)
grad_S = (dL_dpsi1[:, :, None] * _dpsi1_dS).sum(axis=1)
@ -157,11 +154,11 @@ class RBF(Stationary):
grad_mu += (dL_dpsi2[:, :, :, None] * _dpsi2_dmu).reshape(ndata,-1,self.input_dim).sum(axis=1)
grad_S += (dL_dpsi2[:, :, :, None] * _dpsi2_dS).reshape(ndata,-1,self.input_dim).sum(axis=1)
grad_gamma += (dL_dpsi2[:,:,:, None] * _dpsi2_dgamma).reshape(ndata,-1,self.input_dim).sum(axis=1)
return grad_mu, grad_S, grad_gamma
elif isinstance(variational_posterior, variational.NormalPosterior):
l2 = self.lengthscale **2
#psi1
denom, dist, dist_sq, psi1 = self._psi1computations(Z, variational_posterior)