From db0a94e1a3c094374b611cd6f516f31a3c9205c0 Mon Sep 17 00:00:00 2001 From: Zhenwen Dai Date: Fri, 24 Oct 2014 09:11:03 +0100 Subject: [PATCH] bug fix --- GPy/inference/latent_function_inference/inference_X.py | 6 +++++- GPy/kern/_src/rbf.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/GPy/inference/latent_function_inference/inference_X.py b/GPy/inference/latent_function_inference/inference_X.py index 14b9b521..fa9b49ab 100644 --- a/GPy/inference/latent_function_inference/inference_X.py +++ b/GPy/inference/latent_function_inference/inference_X.py @@ -26,7 +26,11 @@ class Inference_X(Model): self.likelihood = model.likelihood.copy() self.kern = model.kern.copy() if model.kern.useGPU: - self.kern.GPU(True) + from ...models import SSGPLVM + if isinstance(model, SSGPLVM): + self.kern.GPU_SSRBF(True) + else: + self.kern.GPU(True) from copy import deepcopy self.posterior = deepcopy(model.posterior) self.variational_prior = model.variational_prior.copy() diff --git a/GPy/kern/_src/rbf.py b/GPy/kern/_src/rbf.py index 20976d7c..a0fd8312 100644 --- a/GPy/kern/_src/rbf.py +++ b/GPy/kern/_src/rbf.py @@ -40,6 +40,16 @@ class RBF(Stationary): else: if isinstance(self.psicomp,PSICOMP_RBF_GPU): self.psicomp = PSICOMP_RBF() + + def GPU_SSRBF(self,flag=True): + if flag: + if isinstance(self.psicomp,PSICOMP_RBF): + self.useGPU = True + from psi_comp.ssrbf_psi_gpucomp import PSICOMP_SSRBF_GPU + self.psicomp = PSICOMP_SSRBF_GPU() + else: + if isinstance(self.psicomp,PSICOMP_RBF_GPU): + self.psicomp = PSICOMP_RBF() def K_of_r(self, r): return self.variance * np.exp(-0.5 * r**2)