add sample_W to SSGPLVM

This commit is contained in:
Zhenwen Dai 2015-11-03 15:09:59 +00:00
parent 2bf1be5500
commit 92248b62e6

View file

@ -190,4 +190,38 @@ class SSGPLVM(SparseGP_MPI):
if self.kern.ARD:
return self.kern.input_sensitivity()
else:
return self.variational_prior.pi
return self.variational_prior.pi
def sample_W(self, nSamples, raw_samples=False):
"""
Sample the loading matrix if the kernel is linear.
"""
assert isinstance(self.kern, kern.Linear)
from ..util.linalg import pdinv
N, D = self.Y.shape
Q = self.X.shape[1]
noise_var = self.likelihood.variance.values
# Draw samples for X
Xs = np.random.randn(*((nSamples,)+self.X.shape))*np.sqrt(self.X.variance.values)+self.X.mean.values
b = np.random.rand(*((nSamples,)+self.X.shape))
Xs[b>self.X.gamma.values] = 0
invcov = (Xs[:,:,:,None]*Xs[:,:,None,:]).sum(1)/noise_var+np.eye(Q)
cov = np.array([pdinv(invcov[s_idx])[0] for s_idx in xrange(invcov.shape[0])])
Ws = np.empty((nSamples, Q, D))
tmp = (np.transpose(Xs, (0,2,1)).reshape(nSamples*Q,N).dot(self.Y)).reshape(nSamples,Q,D)
mean = (cov[:,:,:,None]*tmp[:,None,:,:]).sum(2)/noise_var
zeros = np.zeros((Q,))
for s_idx in xrange(Xs.shape[0]):
Ws[s_idx] = (np.random.multivariate_normal(mean=zeros,cov=cov[s_idx],size=(D,))).T+mean[s_idx]
if raw_samples:
return Ws
else:
return Ws.mean(0), Ws.std(0)