mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 20:42:39 +02:00
fix the pickle problem for models with psi statistics
This commit is contained in:
parent
abcea70291
commit
f5f90b3a2c
2 changed files with 10 additions and 7 deletions
|
|
@ -9,7 +9,7 @@ import ssrbf_psi_comp
|
||||||
import sslinear_psi_comp
|
import sslinear_psi_comp
|
||||||
import linear_psi_comp
|
import linear_psi_comp
|
||||||
|
|
||||||
class PSICOMP_RBF(object):
|
class PSICOMP_RBF(Pickleable):
|
||||||
|
|
||||||
@Cache_this(limit=2, ignore_args=(0,))
|
@Cache_this(limit=2, ignore_args=(0,))
|
||||||
def psicomputations(self, variance, lengthscale, Z, variational_posterior):
|
def psicomputations(self, variance, lengthscale, Z, variational_posterior):
|
||||||
|
|
@ -29,7 +29,7 @@ class PSICOMP_RBF(object):
|
||||||
else:
|
else:
|
||||||
raise ValueError, "unknown distriubtion received for psi-statistics"
|
raise ValueError, "unknown distriubtion received for psi-statistics"
|
||||||
|
|
||||||
class PSICOMP_Linear(object):
|
class PSICOMP_Linear(Pickleable):
|
||||||
|
|
||||||
@Cache_this(limit=2, ignore_args=(0,))
|
@Cache_this(limit=2, ignore_args=(0,))
|
||||||
def psicomputations(self, variance, Z, variational_posterior):
|
def psicomputations(self, variance, Z, variational_posterior):
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ class SSGPLVM(SparseGP):
|
||||||
:type init: 'PCA'|'random'
|
:type init: 'PCA'|'random'
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, Y, input_dim, X=None, X_variance=None, init='PCA', num_inducing=10,
|
def __init__(self, Y, input_dim, X=None, X_variance=None, Gamma=None, init='PCA', num_inducing=10,
|
||||||
Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike_and_Slab GPLVM', group_spike=False, mpi_comm=None, pi=None, learnPi=True, **kwargs):
|
Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike_and_Slab GPLVM', group_spike=False, mpi_comm=None, pi=None, learnPi=True, **kwargs):
|
||||||
|
|
||||||
self.mpi_comm = mpi_comm
|
self.mpi_comm = mpi_comm
|
||||||
|
|
@ -41,10 +41,13 @@ class SSGPLVM(SparseGP):
|
||||||
if X_variance is None: # The variance of the variational approximation (S)
|
if X_variance is None: # The variance of the variational approximation (S)
|
||||||
X_variance = np.random.uniform(0,.1,X.shape)
|
X_variance = np.random.uniform(0,.1,X.shape)
|
||||||
|
|
||||||
gamma = np.empty_like(X) # The posterior probabilities of the binary variable in the variational approximation
|
if Gamma is None:
|
||||||
gamma[:] = 0.5 + 0.1 * np.random.randn(X.shape[0], input_dim)
|
gamma = np.empty_like(X) # The posterior probabilities of the binary variable in the variational approximation
|
||||||
gamma[gamma>1.-1e-9] = 1.-1e-9
|
gamma[:] = 0.5 + 0.1 * np.random.randn(X.shape[0], input_dim)
|
||||||
gamma[gamma<1e-9] = 1e-9
|
gamma[gamma>1.-1e-9] = 1.-1e-9
|
||||||
|
gamma[gamma<1e-9] = 1e-9
|
||||||
|
else:
|
||||||
|
gamma = Gamma.copy()
|
||||||
|
|
||||||
if Z is None:
|
if Z is None:
|
||||||
Z = np.random.permutation(X.copy())[:num_inducing]
|
Z = np.random.permutation(X.copy())[:num_inducing]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue