bug fix for tie framework. SSMRD ready!

This commit is contained in:
Zhenwen Dai 2014-10-16 19:13:32 +01:00
parent ec64f653b5
commit 43f3bfc385
3 changed files with 14 additions and 7 deletions

View file

@ -593,8 +593,7 @@ class Indexable(Nameable, Observable):
def tie_vector(self, *plist):
"""Tie a vector of parameters to other vectors of parameters"""
for p in plist:
self._highest_parent_.ties.tie_vector(self, p)
self._highest_parent_.ties.tie_vector((self,)+plist)
self._highest_parent_._trigger_params_changed()

View file

@ -218,10 +218,12 @@ class Tie(Parameterized):
def _get_labels_vector(self, p1,p2):
label1 = []
self._traverse_param(lambda x: x.tie.flat, (p1,), label1)
for p in p1:
self._traverse_param(lambda x: x.tie.flat, (p,), label1)
label1 = np.hstack(label1)
label2 = []
self._traverse_param(lambda x: x.tie.flat, (p2,), label2)
for p in p2:
self._traverse_param(lambda x: x.tie.flat, (p,), label2)
label2 = np.hstack(label2)
expandlist = np.where(label1+label2==0)[0]
labellist =label1.copy()
@ -371,14 +373,20 @@ class Tie(Parameterized):
self.update_model(True)
def tie_vector(self, plist):
assert len(plist)>=2
p_splits = [self._keepParamList([p]) for p in plist]
for p_split2 in p_splits[1:]:
p_split1 = p_splits[0]
p1 = self._updateParamList(p_split1)
p2 = self._updateParamList(p_split2)
self._tie_vector(p1, p2)
def _tie_vector(self, p1, p2):
"""tie a pair of vectors"""
self.update_model(False)
expandlist,removelist,labellist = self._get_labels_vector(p1, p2)
p_split1 = self._keepParamList([p1])
p_split2 = self._keepParamList([p2])
p_split1 = self._keepParamList(p1)
p_split2 = self._keepParamList(p2)
if len(expandlist)>0:
tie_labels,idxlist = self._expand_tie_param(len(expandlist))
labellist[expandlist] = tie_labels

View file

@ -27,7 +27,7 @@ class SSMRD(Model):
likelihoods = [None]* len(Ylist)
self.var_priors = [VarPrior_SSMRD(nModels=len(Ylist),pi=pi,learnPi=False, group_spike=True) for i in xrange(len(Ylist))]
self.models = [SSGPLVM(y, input_dim, X=X, X_variance=X_variance, Gamma=Gammas[i], num_inducing=num_inducing,Z=Zs[i],learnPi=False, group_spike=True,
self.models = [SSGPLVM(y, input_dim, X=X, X_variance=X_variance, Gamma=Gammas[i], num_inducing=num_inducing,Z=Zs[i], learnPi=False, group_spike=True,
kernel=kernel.copy(),inference_method=inference_method,likelihood=likelihoods[i], variational_prior=self.var_priors[i],
name='model_'+str(i)) for i,y in enumerate(Ylist)]
self.link_parameters(*(self.models))