some performance tuning for tie

This commit is contained in:
Zhenwen Dai 2014-10-30 23:47:41 +00:00
parent cde51766ad
commit d8a76b89de
2 changed files with 19 additions and 12 deletions

View file

@ -86,32 +86,31 @@ class Tie(Parameterized):
""" """
assert self.tied_param is not None assert self.tied_param is not None
read = np.zeros((self.tied_param.size,),dtype=np.uint8) read = np.zeros((self.tied_param.size,),dtype=np.uint8)
ltoi = np.zeros((self.tied_param.tie.max()+1,),dtype=np.int32)
ltoi[self.tied_param.tie] = range(self.tied_param.size)
try: try:
from scipy import weave from scipy import weave
def _sync_val_p(p, tieparam, read): def _sync_val_p(p, tieparam, read, ltoi):
if p.tie is not None: if p.tie is not None:
# ltoi = self._label_to_idx
ltoi = np.zeros((self.tied_param.tie.max()+1,),dtype=np.int32)
ltoi[self.tied_param.tie] = range(self.tied_param.size)
totie = 1 if toTiedParam else 0 totie = 1 if toTiedParam else 0
p_idx = p._raveled_index()
p_tie = p.tie.flatten() p_tie = p.tie.flatten()
p_org = p._original_.flat p_flat = p.flat
p_size = p.size p_size = p.size
code = """ code = """
for(int i=0;i<p_size;i++) { for(int i=0;i<p_size;i++) {
if(p_tie[i]>0) { if(p_tie[i]>0) {
if(totie==1 && read[ltoi[p_tie[i]]]==0) { if(totie==1 && read[ltoi[int(p_tie[i])]]==0) {
tieparam[ltoi[p_tie[i]]] = p_org[int(p_idx[i])]; tieparam[ltoi[int(p_tie[i])]] = p_flat[i];
read[ltoi[p_tie[i]]] = 1; read[ltoi[int(p_tie[i])]] = 1+1-1;
} else { } else {
p_org[int(p_idx[i])] = tieparam[ltoi[p_tie[i]]]; p_flat[i] = tieparam[ltoi[int(p_tie[i])]];
} }
} }
} }
""" """
weave.inline(code, arg_names=['p_org','p_idx','tieparam','p_tie','p_size','ltoi','read','totie']) weave.inline(code, arg_names=['p_flat','tieparam','p_tie','p_size','ltoi','read','totie'])
except: except:
raise
def _sync_val_p(p, tieparam, read): def _sync_val_p(p, tieparam, read):
if p.tie is not None: if p.tie is not None:
labels = np.unique(p.tie) labels = np.unique(p.tie)
@ -125,7 +124,7 @@ class Tie(Parameterized):
val = tieparam[tieparam.tie==l] val = tieparam[tieparam.tie==l]
p[p.tie==l] = val p[p.tie==l] = val
for p in plist: for p in plist:
self._traverse_param(_sync_val_p, (p,self.tied_param,read), []) self._traverse_param(_sync_val_p, (p,self.tied_param,read,ltoi), [])
def _sync_constraints(self, plist, toTiedParam=True): def _sync_constraints(self, plist, toTiedParam=True):
""" """

View file

@ -86,6 +86,14 @@ class TieTests(unittest.TestCase):
self.assertTrue(m.ties.checkTieVector([m.X[:10],m.X[10:20]])) self.assertTrue(m.ties.checkTieVector([m.X[:10],m.X[10:20]]))
self.assertTrue(m.checkgrad()) self.assertTrue(m.checkgrad())
def test_tie_variational_posterior_columns(self):
m = GPy.examples.dimensionality_reduction.bgplvm_oil_100(plot=False,optimize=False)
m.X.variance[:,:3].tie_vector(m.X.variance[:,3:6])
self.assertTrue(m.ties.checkValueConsistency())
self.assertTrue(m.ties.checkConstraintConsistency())
self.assertTrue(m.ties.checkTieVector([m.X.variance[:,:3],m.X.variance[:,3:6]]))
self.assertTrue(m.checkgrad())
if __name__ == "__main__": if __name__ == "__main__":
print "Running unit tests, please be (very) patient..." print "Running unit tests, please be (very) patient..."
unittest.main() unittest.main()