[psi2] implement RBF cpu

This commit is contained in:
Zhenwen Dai 2014-05-21 10:34:51 +01:00
parent 001db6b089
commit a2203179f6
5 changed files with 92 additions and 75 deletions

View file

@ -107,5 +107,3 @@ class SparseGP(GP):
psi2 = kern.psi2(self.Z, Xnew)
var = Kxx - np.sum(np.sum(psi2 * Kmmi_LmiBLmi[None, :, :], 1), 1)
return mu, var

View file

@ -7,7 +7,7 @@ The package for the psi statistics computation
import numpy as np
from . import PSICOMP
from GPy.util.caching import Cache_this,Cacher
from GPy.util.caching import Cache_this
class PSICOMP_SSRBF(PSICOMP):

View file

@ -8,8 +8,7 @@ from ...util.misc import param_to_array
from stationary import Stationary
from GPy.util.caching import Cache_this
from ...core.parameterization import variational
from psi_comp import ssrbf_psi_comp
from psi_comp import ssrbf_psi_gpucomp
from psi_comp import ssrbf_psi_comp,ssrbf_psi_gpucomp,rbf_psi_comp
from ...util.config import *
class RBF(Stationary):
@ -26,6 +25,7 @@ class RBF(Stationary):
super(RBF, self).__init__(input_dim, variance, lengthscale, ARD, active_dims, name, useGPU=useGPU)
self.weave_options = {}
self.group_spike_prob = False
self.psicomp = rbf_psi_comp.PSICOMP_RBF()
def set_for_SpikeAndSlab(self):
if self.useGPU:
@ -50,7 +50,8 @@ class RBF(Stationary):
else:
return self.psicomp.psicomputations(self.variance, self.lengthscale, Z, variational_posterior)[0]
else:
return self.Kdiag(variational_posterior.mean)
# return self.Kdiag(variational_posterior.mean)
return self.psicomp.psicomputations(self.variance, self.lengthscale, Z, variational_posterior)[0]
def psi1(self, Z, variational_posterior):
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
@ -59,8 +60,7 @@ class RBF(Stationary):
else:
return self.psicomp.psicomputations(self.variance, self.lengthscale, Z, variational_posterior)[1]
else:
_, _, _, psi1 = self._psi1computations(Z, variational_posterior)
return psi1
return self.psicomp.psicomputations(self.variance, self.lengthscale, Z, variational_posterior)[1]
def psi2(self, Z, variational_posterior):
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
@ -69,8 +69,7 @@ class RBF(Stationary):
else:
return self.psicomp.psicomputations(self.variance, self.lengthscale, Z, variational_posterior)[2]
else:
_, _, _, _, psi2 = self._psi2computations(Z, variational_posterior)
return psi2
return self.psicomp.psicomputations(self.variance, self.lengthscale, Z, variational_posterior)[2]
def update_gradients_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
# Spike-and-Slab GPLVM
@ -83,32 +82,37 @@ class RBF(Stationary):
self.lengthscale.gradient = dL_dlengscale
elif isinstance(variational_posterior, variational.NormalPosterior):
l2 = self.lengthscale**2
if l2.size != self.input_dim:
l2 = l2*np.ones(self.input_dim)
dL_dvar, dL_dlengscale, _, _, _ = self.psicomp.psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, self.variance, self.lengthscale, Z, variational_posterior)
self.variance.gradient = dL_dvar
self.lengthscale.gradient = dL_dlengscale
#contributions from psi0:
self.variance.gradient = np.sum(dL_dpsi0)
self.lengthscale.gradient = 0.
#from psi1
denom, _, dist_sq, psi1 = self._psi1computations(Z, variational_posterior)
d_length = psi1[:,:,None] * ((dist_sq - 1.)/(self.lengthscale*denom) +1./self.lengthscale)
dpsi1_dlength = d_length * dL_dpsi1[:, :, None]
if self.ARD:
self.lengthscale.gradient += dpsi1_dlength.sum(0).sum(0)
else:
self.lengthscale.gradient += dpsi1_dlength.sum()
self.variance.gradient += np.sum(dL_dpsi1 * psi1) / self.variance
#from psi2
S = variational_posterior.variance
_, Zdist_sq, _, mudist_sq, psi2 = self._psi2computations(Z, variational_posterior)
if not self.ARD:
self.lengthscale.gradient += self._weave_psi2_lengthscale_grads(dL_dpsi2, psi2, Zdist_sq, S, mudist_sq, l2).sum()
else:
self.lengthscale.gradient += self._weave_psi2_lengthscale_grads(dL_dpsi2, psi2, Zdist_sq, S, mudist_sq, l2)
self.variance.gradient += 2.*np.sum(dL_dpsi2 * psi2)/self.variance
# l2 = self.lengthscale**2
# if l2.size != self.input_dim:
# l2 = l2*np.ones(self.input_dim)
# #contributions from psi0:
# self.variance.gradient = np.sum(dL_dpsi0)
# self.lengthscale.gradient = 0.
#
# # #from psi1
# denom, _, dist_sq, psi1 = self._psi1computations(Z, variational_posterior)
# d_length = psi1[:,:,None] * ((dist_sq - 1.)/(self.lengthscale*denom) +1./self.lengthscale)
# dpsi1_dlength = d_length * dL_dpsi1[:, :, None]
# print dpsi1_dlength.sum(0).sum(0)
# if self.ARD:
# self.lengthscale.gradient += dpsi1_dlength.sum(0).sum(0)
# else:
# self.lengthscale.gradient += dpsi1_dlength.sum()
# self.variance.gradient += np.sum(dL_dpsi1 * psi1) / self.variance
# #from psi2
# S = variational_posterior.variance
# _, Zdist_sq, _, mudist_sq, psi2 = self._psi2computations(Z, variational_posterior)
# if not self.ARD:
# self.lengthscale.gradient += self._weave_psi2_lengthscale_grads(dL_dpsi2, psi2, Zdist_sq, S, mudist_sq, l2).sum()
# else:
# self.lengthscale.gradient += self._weave_psi2_lengthscale_grads(dL_dpsi2, psi2, Zdist_sq, S, mudist_sq, l2)
# # print self._weave_psi2_lengthscale_grads(dL_dpsi2, psi2, Zdist_sq, S, mudist_sq, l2)
#
# self.variance.gradient += 2.*np.sum(dL_dpsi2 * psi2)/self.variance
else:
raise ValueError, "unknown distriubtion received for psi-statistics"
@ -123,21 +127,24 @@ class RBF(Stationary):
return dL_dZ
elif isinstance(variational_posterior, variational.NormalPosterior):
l2 = self.lengthscale **2
#psi1
denom, dist, dist_sq, psi1 = self._psi1computations(Z, variational_posterior)
grad = np.einsum('ij,ij,ijk,ijk->jk', dL_dpsi1, psi1, dist, -1./(denom*l2))
#psi2
Zdist, Zdist_sq, mudist, mudist_sq, psi2 = self._psi2computations(Z, variational_posterior)
term1 = Zdist / l2 # M, M, Q
S = variational_posterior.variance
term2 = mudist / (2.*S[:,None,None,:] + l2) # N, M, M, Q
grad += 2.*np.einsum('ijk,ijk,ijkl->kl', dL_dpsi2, psi2, term1[None,:,:,:] + term2)
return grad
_, _, dL_dZ, _, _ = self.psicomp.psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, self.variance, self.lengthscale, Z, variational_posterior)
return dL_dZ
#
# l2 = self.lengthscale **2
#
# #psi1
# denom, dist, dist_sq, psi1 = self._psi1computations(Z, variational_posterior)
# grad = np.einsum('ij,ij,ijk,ijk->jk', dL_dpsi1, psi1, dist, -1./(denom*l2))
#
# #psi2
# Zdist, Zdist_sq, mudist, mudist_sq, psi2 = self._psi2computations(Z, variational_posterior)
# term1 = Zdist / l2 # M, M, Q
# S = variational_posterior.variance
# term2 = mudist / (2.*S[:,None,None,:] + l2) # N, M, M, Q
#
# grad += 2.*np.einsum('ijk,ijk,ijkl->kl', dL_dpsi2, psi2, term1[None,:,:,:] + term2)
#
# return grad
else:
raise ValueError, "unknown distriubtion received for psi-statistics"
@ -151,24 +158,27 @@ class RBF(Stationary):
return dL_dmu, dL_dS, dL_dgamma
elif isinstance(variational_posterior, variational.NormalPosterior):
l2 = self.lengthscale **2
#psi1
denom, dist, dist_sq, psi1 = self._psi1computations(Z, variational_posterior)
tmp = psi1[:, :, None] / l2 / denom
grad_mu = np.sum(dL_dpsi1[:, :, None] * tmp * dist, 1)
grad_S = np.sum(dL_dpsi1[:, :, None] * 0.5 * tmp * (dist_sq - 1), 1)
#psi2
_, _, mudist, mudist_sq, psi2 = self._psi2computations(Z, variational_posterior)
S = variational_posterior.variance
tmp = psi2[:, :, :, None] / (2.*S[:,None,None,:] + l2)
grad_mu += -2.*np.einsum('ijk,ijkl,ijkl->il', dL_dpsi2, tmp , mudist)
grad_S += np.einsum('ijk,ijkl,ijkl->il', dL_dpsi2 , tmp , (2.*mudist_sq - 1))
_, _, _, dL_dmu, dL_dS = self.psicomp.psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, self.variance, self.lengthscale, Z, variational_posterior)
# l2 = self.lengthscale **2
# #psi1
# denom, dist, dist_sq, psi1 = self._psi1computations(Z, variational_posterior)
# tmp = psi1[:, :, None] / l2 / denom
# grad_mu = np.sum(dL_dpsi1[:, :, None] * tmp * dist, 1)
# grad_S = np.sum(dL_dpsi1[:, :, None] * 0.5 * tmp * (dist_sq - 1), 1)
# #psi2
# _, _, mudist, mudist_sq, psi2 = self._psi2computations(Z, variational_posterior)
# S = variational_posterior.variance
# tmp = psi2[:, :, :, None] / (2.*S[:,None,None,:] + l2)
# grad_mu += -2.*np.einsum('jk,ijkl,ijkl->il', dL_dpsi2, tmp , mudist)
# grad_S += np.einsum('jk,ijkl,ijkl->il', dL_dpsi2 , tmp , (2.*mudist_sq - 1))
return dL_dmu, dL_dS
else:
raise ValueError, "unknown distriubtion received for psi-statistics"
return grad_mu, grad_S
#return grad_mu, grad_S
#---------------------------------------#
# Precomputations #

View file

@ -8,7 +8,7 @@ from ..likelihoods import Gaussian
from ..inference.optimization import SCG
from ..util import linalg
from ..core.parameterization.variational import NormalPosterior, NormalPrior, VariationalPosterior
from ..inference.latent_function_inference.var_dtc_parallel import update_gradients
from ..inference.latent_function_inference.var_dtc_parallel import update_gradients, VarDTC_minibatch
from ..inference.latent_function_inference.var_dtc_gpu import VarDTC_GPU
class BayesianGPLVM(SparseGP):
@ -67,7 +67,7 @@ class BayesianGPLVM(SparseGP):
X.mean.gradient, X.variance.gradient = X_grad
def parameters_changed(self):
if isinstance(self.inference_method, VarDTC_GPU):
if isinstance(self.inference_method, VarDTC_GPU) or isinstance(self.inference_method, VarDTC_minibatch):
update_gradients(self)
return

View file

@ -89,7 +89,7 @@ class vector_show(matplotlib_show):
class lvm(matplotlib_show):
def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]):
def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1], disable_drag=False):
"""Visualize a latent variable model
:param model: the latent variable model to visualize.
@ -108,12 +108,14 @@ class lvm(matplotlib_show):
if isinstance(latent_axes,mpl.axes.Axes):
self.cid = latent_axes.figure.canvas.mpl_connect('button_press_event', self.on_click)
self.cid = latent_axes.figure.canvas.mpl_connect('motion_notify_event', self.on_move)
if not disable_drag:
self.cid = latent_axes.figure.canvas.mpl_connect('motion_notify_event', self.on_move)
self.cid = latent_axes.figure.canvas.mpl_connect('axes_leave_event', self.on_leave)
self.cid = latent_axes.figure.canvas.mpl_connect('axes_enter_event', self.on_enter)
else:
self.cid = latent_axes[0].figure.canvas.mpl_connect('button_press_event', self.on_click)
self.cid = latent_axes[0].figure.canvas.mpl_connect('motion_notify_event', self.on_move)
if not disable_drag:
self.cid = latent_axes[0].figure.canvas.mpl_connect('motion_notify_event', self.on_move)
self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_leave_event', self.on_leave)
self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_enter_event', self.on_enter)
@ -125,6 +127,7 @@ class lvm(matplotlib_show):
self.move_on = False
self.latent_index = latent_index
self.latent_dim = model.input_dim
self.disable_drag = disable_drag
# The red cross which shows current latent point.
self.latent_values = vals
@ -149,8 +152,13 @@ class lvm(matplotlib_show):
def on_click(self, event):
print 'click!'
if event.inaxes!=self.latent_axes: return
self.move_on = not self.move_on
self.called = True
if self.disable_drag:
self.move_on = True
self.called = True
self.on_move(event)
else:
self.move_on = not self.move_on
self.called = True
def on_move(self, event):
if event.inaxes!=self.latent_axes: return
@ -400,7 +408,7 @@ class mocap_data_show(matplotlib_show):
def __init__(self, vals, axes=None, connect=None):
if axes==None:
fig = plt.figure()
axes = fig.add_subplot(111, projection='3d',aspect='equal')
axes = fig.add_subplot(111, projection='3d', aspect='equal')
matplotlib_show.__init__(self, vals, axes)
self.connect = connect
@ -438,6 +446,7 @@ class mocap_data_show(matplotlib_show):
self.process_values()
self.initialize_axes_modify()
self.draw_vertices()
self.initialize_axes()
self.finalize_axes_modify()
self.draw_edges()
self.axes.figure.canvas.draw()
@ -460,10 +469,10 @@ class mocap_data_show(matplotlib_show):
self.axes.set_xlim(self.x_lim)
self.axes.set_ylim(self.y_lim)
self.axes.set_zlim(self.z_lim)
self.axes.auto_scale_xyz([-1., 1.], [-1., 1.], [-1.5, 1.5])
self.axes.auto_scale_xyz([-1., 1.], [-1., 1.], [-1., 1.])
#self.axes.set_aspect('equal')
self.axes.autoscale(enable=False)
# self.axes.set_aspect('equal')
# self.axes.autoscale(enable=False)
def finalize_axes_modify(self):
self.axes.set_xlim(self.x_lim)