fix sparse_gplvm

This commit is contained in:
Zhenwen Dai 2014-07-02 14:21:27 +01:00
parent 1c165db845
commit f16db00438
2 changed files with 8 additions and 45 deletions

View file

@ -11,7 +11,7 @@ from GPy.models.gplvm import GPLVM
# from ..core import model
# from ..util.linalg import pdinv, PCA
class SparseGPLVM(SparseGPRegression, GPLVM):
class SparseGPLVM(SparseGPRegression):
"""
Sparse Gaussian Process Latent Variable Model
@ -23,40 +23,12 @@ class SparseGPLVM(SparseGPRegression, GPLVM):
:type init: 'PCA'|'random'
"""
def __init__(self, Y, input_dim, kernel=None, init='PCA', num_inducing=10):
X = self.initialise_latent(init, input_dim, Y)
def __init__(self, Y, input_dim, X=None, kernel=None, init='PCA', num_inducing=10):
if X is None:
from ..util.initialization import initialize_latent
X, fracs = initialize_latent(init, input_dim, Y)
SparseGPRegression.__init__(self, X, Y, kernel=kernel, num_inducing=num_inducing)
self.ensure_default_constraints()
def _get_param_names(self):
return (sum([['X_%i_%i' % (n, q) for q in range(self.input_dim)] for n in range(self.num_data)], [])
+ SparseGPRegression._get_param_names(self))
def _get_params(self):
return np.hstack((self.X.flatten(), SparseGPRegression._get_params(self)))
def _set_params(self, x):
self.X = x[:self.X.size].reshape(self.num_data, self.input_dim).copy()
SparseGPRegression._set_params(self, x[self.X.size:])
def log_likelihood(self):
return SparseGPRegression.log_likelihood(self)
def dL_dX(self):
dL_dX = self.kern.dKdiag_dX(self.dL_dpsi0, self.X)
dL_dX += self.kern.gradients_X(self.dL_dpsi1, self.X, self.Z)
return dL_dX
def _log_likelihood_gradients(self):
return np.hstack((self.dL_dX().flatten(), SparseGPRegression._log_likelihood_gradients(self)))
def plot(self):
GPLVM.plot(self)
# passing Z without a small amout of jitter will induce the white kernel where we don;t want it!
mu, var, upper, lower = SparseGPRegression.predict(self, self.Z + np.random.randn(*self.Z.shape) * 0.0001)
pb.plot(mu[:, 0] , mu[:, 1], 'ko')
def plot_latent(self, *args, **kwargs):
input_1, input_2 = GPLVM.plot_latent(*args, **kwargs)
pb.plot(m.Z[:, input_1], m.Z[:, input_2], '^w')
def parameters_changed(self):
super(SparseGPLVM, self).parameters_changed()
self.X.gradient = self.kern.gradients_X(self.grad_dict['dL_dKnm'], self.X, self.Z)