mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 12:32:40 +02:00
[GPU] psi1 after debug
This commit is contained in:
parent
e4d19120cd
commit
bc59cb8b22
5 changed files with 157 additions and 40 deletions
|
|
@ -11,6 +11,9 @@ from ..likelihoods import Gaussian
|
|||
from ..inference.optimization import SCG
|
||||
from ..util import linalg
|
||||
from ..core.parameterization.variational import SpikeAndSlabPrior, SpikeAndSlabPosterior
|
||||
from ..inference.latent_function_inference.var_dtc_parallel import update_gradients
|
||||
from ..inference.latent_function_inference.var_dtc_gpu import VarDTC_GPU
|
||||
|
||||
|
||||
class SSGPLVM(SparseGP):
|
||||
"""
|
||||
|
|
@ -64,8 +67,16 @@ class SSGPLVM(SparseGP):
|
|||
SparseGP.__init__(self, X, Y, Z, kernel, likelihood, inference_method, name, **kwargs)
|
||||
self.add_parameter(self.X, index=0)
|
||||
self.add_parameter(self.variational_prior)
|
||||
|
||||
def set_X_gradients(self, X, X_grad):
|
||||
"""Set the gradients of the posterior distribution of X in its specific form."""
|
||||
X.mean.gradient, X.variance.gradient, X.binary_prob.gradient = X_grad
|
||||
|
||||
def parameters_changed(self):
|
||||
if isinstance(self.inference_method, VarDTC_GPU):
|
||||
update_gradients(self)
|
||||
return
|
||||
|
||||
super(SSGPLVM, self).parameters_changed()
|
||||
self._log_marginal_likelihood -= self.variational_prior.KL_divergence(self.X)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue