attempt for mpi support for ss_mrd

This commit is contained in:
Zhenwen Dai 2014-10-17 13:09:59 +01:00
parent 43f3bfc385
commit e9d33ddc7e

View file

@ -12,8 +12,8 @@ from ..kern import RBF
class SSMRD(Model): class SSMRD(Model):
def __init__(self, Ylist, input_dim, X=None, X_variance=None, Gammas=None, initx = 'PCA_concat', initz = 'permute', def __init__(self, Ylist, input_dim, X=None, X_variance=None, Gammas=None, initx = 'PCA_concat', initz = 'permute',
num_inducing=10, Zs=None, kernel=None, inference_method=None, likelihoods=None, num_inducing=10, Zs=None, kernel=None, inference_methods=None, likelihoods=None,
pi=0.5, name='ss_mrd', Ynames=None): pi=0.5, name='ss_mrd', Ynames=None, mpi_comm=None):
super(SSMRD, self).__init__(name) super(SSMRD, self).__init__(name)
# initialize X for individual models # initialize X for individual models
@ -25,11 +25,13 @@ class SSMRD(Model):
Zs = [None]* len(Ylist) Zs = [None]* len(Ylist)
if likelihoods is None: if likelihoods is None:
likelihoods = [None]* len(Ylist) likelihoods = [None]* len(Ylist)
if inference_methods is None:
inference_methods = [None]* len(Ylist)
self.var_priors = [VarPrior_SSMRD(nModels=len(Ylist),pi=pi,learnPi=False, group_spike=True) for i in xrange(len(Ylist))] self.var_priors = [VarPrior_SSMRD(nModels=len(Ylist),pi=pi,learnPi=False, group_spike=True) for i in xrange(len(Ylist))]
self.models = [SSGPLVM(y, input_dim, X=X, X_variance=X_variance, Gamma=Gammas[i], num_inducing=num_inducing,Z=Zs[i], learnPi=False, group_spike=True, self.models = [SSGPLVM(y, input_dim, X=X, X_variance=X_variance, Gamma=Gammas[i], num_inducing=num_inducing,Z=Zs[i], learnPi=False, group_spike=True,
kernel=kernel.copy(),inference_method=inference_method,likelihood=likelihoods[i], variational_prior=self.var_priors[i], kernel=kernel.copy(),inference_method=inference_methods[i],likelihood=likelihoods[i], variational_prior=self.var_priors[i],
name='model_'+str(i)) for i,y in enumerate(Ylist)] name='model_'+str(i), mpi_comm=mpi_comm) for i,y in enumerate(Ylist)]
self.link_parameters(*(self.models)) self.link_parameters(*(self.models))
self.models[0].X.mean.tie_vector(*[m.X.mean for m in self.models[1:]]) self.models[0].X.mean.tie_vector(*[m.X.mean for m in self.models[1:]])