mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 20:42:39 +02:00
RBF for SSGPLVM gpu implemented
This commit is contained in:
parent
3f36a245d1
commit
9c6bfae0b9
4 changed files with 435 additions and 503 deletions
|
|
@ -2,17 +2,14 @@
|
|||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||
|
||||
import numpy as np
|
||||
import itertools
|
||||
from matplotlib import pyplot
|
||||
|
||||
from ..core.sparse_gp import SparseGP
|
||||
from .. import kern
|
||||
from ..likelihoods import Gaussian
|
||||
from ..inference.optimization import SCG
|
||||
from ..util import linalg
|
||||
from ..core.parameterization.variational import SpikeAndSlabPrior, SpikeAndSlabPosterior
|
||||
from ..inference.latent_function_inference.var_dtc_parallel import update_gradients, VarDTC_minibatch
|
||||
from ..inference.latent_function_inference.var_dtc_gpu import VarDTC_GPU
|
||||
from ..kern._src.psi_comp.ssrbf_psi_gpucomp import PSICOMP_SSRBF_GPU
|
||||
|
||||
class SSGPLVM(SparseGP):
|
||||
"""
|
||||
|
|
@ -62,6 +59,8 @@ class SSGPLVM(SparseGP):
|
|||
|
||||
if kernel is None:
|
||||
kernel = kern.RBF(input_dim, lengthscale=fracs, ARD=True) # + kern.white(input_dim)
|
||||
if kernel.useGPU:
|
||||
kernel.psicomp = PSICOMP_SSRBF_GPU()
|
||||
|
||||
if inference_method is None:
|
||||
inference_method = VarDTC_minibatch(mpi_comm=mpi_comm)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue