This commit is contained in:
Zhenwen Dai 2014-10-24 09:04:19 +01:00
parent 8856144890
commit bc3dab4c4d
2 changed files with 9 additions and 2 deletions

View file

@ -57,7 +57,8 @@ class Inference_X(Model):
if isinstance(model, SSGPLVM):
X = variational.SpikeAndSlabPosterior(param_to_array(model.X.mean[idx]), param_to_array(model.X.variance[idx]), param_to_array(model.X.gamma[idx]))
if model.group_spike:
[X.gamma[:,i].tie_together() for i in xrange(X.gamma.shape[1])] # Tie columns together
#[X.gamma[:,i].tie_together() for i in xrange(X.gamma.shape[1])] # Tie columns together
X.gamma.fix()
else:
X = variational.NormalPosterior(param_to_array(model.X.mean[idx]), param_to_array(model.X.variance[idx]))

View file

@ -30,7 +30,13 @@ class RBF(Stationary):
if flag:
if isinstance(self.psicomp,PSICOMP_RBF):
self.useGPU = True
self.psicomp = PSICOMP_RBF_GPU()
from ...models import SSGPLVM,SSMRD
if isinstance(self._highest_parent_, SSGPLVM) or isinstance(self._highest_parent_, SSGPLVM):
from psi_comp.ssrbf_psi_gpucomp import PSICOMP_SSRBF_GPU
self.psicomp = PSICOMP_SSRBF_GPU()
else:
self.psicomp = PSICOMP_RBF_GPU()
else:
if isinstance(self.psicomp,PSICOMP_RBF_GPU):
self.psicomp = PSICOMP_RBF()