diff --git a/GPy/core/parameterization/ties_and_remappings.py b/GPy/core/parameterization/ties_and_remappings.py index 8cec3251..26240c19 100644 --- a/GPy/core/parameterization/ties_and_remappings.py +++ b/GPy/core/parameterization/ties_and_remappings.py @@ -86,32 +86,31 @@ class Tie(Parameterized): """ assert self.tied_param is not None read = np.zeros((self.tied_param.size,),dtype=np.uint8) + ltoi = np.zeros((self.tied_param.tie.max()+1,),dtype=np.int32) + ltoi[self.tied_param.tie] = range(self.tied_param.size) try: from scipy import weave - def _sync_val_p(p, tieparam, read): + def _sync_val_p(p, tieparam, read, ltoi): if p.tie is not None: -# ltoi = self._label_to_idx - ltoi = np.zeros((self.tied_param.tie.max()+1,),dtype=np.int32) - ltoi[self.tied_param.tie] = range(self.tied_param.size) totie = 1 if toTiedParam else 0 - p_idx = p._raveled_index() p_tie = p.tie.flatten() - p_org = p._original_.flat + p_flat = p.flat p_size = p.size code = """ for(int i=0;i0) { - if(totie==1 && read[ltoi[p_tie[i]]]==0) { - tieparam[ltoi[p_tie[i]]] = p_org[int(p_idx[i])]; - read[ltoi[p_tie[i]]] = 1; + if(totie==1 && read[ltoi[int(p_tie[i])]]==0) { + tieparam[ltoi[int(p_tie[i])]] = p_flat[i]; + read[ltoi[int(p_tie[i])]] = 1+1-1; } else { - p_org[int(p_idx[i])] = tieparam[ltoi[p_tie[i]]]; + p_flat[i] = tieparam[ltoi[int(p_tie[i])]]; } } } """ - weave.inline(code, arg_names=['p_org','p_idx','tieparam','p_tie','p_size','ltoi','read','totie']) + weave.inline(code, arg_names=['p_flat','tieparam','p_tie','p_size','ltoi','read','totie']) except: + raise def _sync_val_p(p, tieparam, read): if p.tie is not None: labels = np.unique(p.tie) @@ -125,7 +124,7 @@ class Tie(Parameterized): val = tieparam[tieparam.tie==l] p[p.tie==l] = val for p in plist: - self._traverse_param(_sync_val_p, (p,self.tied_param,read), []) + self._traverse_param(_sync_val_p, (p,self.tied_param,read,ltoi), []) def _sync_constraints(self, plist, toTiedParam=True): """ diff --git a/GPy/testing/tie_tests.py b/GPy/testing/tie_tests.py index ad3ee746..6ac84b40 100644 --- a/GPy/testing/tie_tests.py +++ b/GPy/testing/tie_tests.py @@ -85,6 +85,14 @@ class TieTests(unittest.TestCase): self.assertTrue(m.ties.checkConstraintConsistency()) self.assertTrue(m.ties.checkTieVector([m.X[:10],m.X[10:20]])) self.assertTrue(m.checkgrad()) + + def test_tie_variational_posterior_columns(self): + m = GPy.examples.dimensionality_reduction.bgplvm_oil_100(plot=False,optimize=False) + m.X.variance[:,:3].tie_vector(m.X.variance[:,3:6]) + self.assertTrue(m.ties.checkValueConsistency()) + self.assertTrue(m.ties.checkConstraintConsistency()) + self.assertTrue(m.ties.checkTieVector([m.X.variance[:,:3],m.X.variance[:,3:6]])) + self.assertTrue(m.checkgrad()) if __name__ == "__main__": print "Running unit tests, please be (very) patient..."