This commit is contained in:
Zhenwen Dai 2014-10-24 09:11:03 +01:00
parent bc3dab4c4d
commit db0a94e1a3
2 changed files with 15 additions and 1 deletions

View file

@ -26,7 +26,11 @@ class Inference_X(Model):
self.likelihood = model.likelihood.copy() self.likelihood = model.likelihood.copy()
self.kern = model.kern.copy() self.kern = model.kern.copy()
if model.kern.useGPU: if model.kern.useGPU:
self.kern.GPU(True) from ...models import SSGPLVM
if isinstance(model, SSGPLVM):
self.kern.GPU_SSRBF(True)
else:
self.kern.GPU(True)
from copy import deepcopy from copy import deepcopy
self.posterior = deepcopy(model.posterior) self.posterior = deepcopy(model.posterior)
self.variational_prior = model.variational_prior.copy() self.variational_prior = model.variational_prior.copy()

View file

@ -40,6 +40,16 @@ class RBF(Stationary):
else: else:
if isinstance(self.psicomp,PSICOMP_RBF_GPU): if isinstance(self.psicomp,PSICOMP_RBF_GPU):
self.psicomp = PSICOMP_RBF() self.psicomp = PSICOMP_RBF()
def GPU_SSRBF(self,flag=True):
if flag:
if isinstance(self.psicomp,PSICOMP_RBF):
self.useGPU = True
from psi_comp.ssrbf_psi_gpucomp import PSICOMP_SSRBF_GPU
self.psicomp = PSICOMP_SSRBF_GPU()
else:
if isinstance(self.psicomp,PSICOMP_RBF_GPU):
self.psicomp = PSICOMP_RBF()
def K_of_r(self, r): def K_of_r(self, r):
return self.variance * np.exp(-0.5 * r**2) return self.variance * np.exp(-0.5 * r**2)