mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
some performance tuning for tie
This commit is contained in:
parent
cde51766ad
commit
d8a76b89de
2 changed files with 19 additions and 12 deletions
|
|
@ -86,32 +86,31 @@ class Tie(Parameterized):
|
||||||
"""
|
"""
|
||||||
assert self.tied_param is not None
|
assert self.tied_param is not None
|
||||||
read = np.zeros((self.tied_param.size,),dtype=np.uint8)
|
read = np.zeros((self.tied_param.size,),dtype=np.uint8)
|
||||||
|
ltoi = np.zeros((self.tied_param.tie.max()+1,),dtype=np.int32)
|
||||||
|
ltoi[self.tied_param.tie] = range(self.tied_param.size)
|
||||||
try:
|
try:
|
||||||
from scipy import weave
|
from scipy import weave
|
||||||
def _sync_val_p(p, tieparam, read):
|
def _sync_val_p(p, tieparam, read, ltoi):
|
||||||
if p.tie is not None:
|
if p.tie is not None:
|
||||||
# ltoi = self._label_to_idx
|
|
||||||
ltoi = np.zeros((self.tied_param.tie.max()+1,),dtype=np.int32)
|
|
||||||
ltoi[self.tied_param.tie] = range(self.tied_param.size)
|
|
||||||
totie = 1 if toTiedParam else 0
|
totie = 1 if toTiedParam else 0
|
||||||
p_idx = p._raveled_index()
|
|
||||||
p_tie = p.tie.flatten()
|
p_tie = p.tie.flatten()
|
||||||
p_org = p._original_.flat
|
p_flat = p.flat
|
||||||
p_size = p.size
|
p_size = p.size
|
||||||
code = """
|
code = """
|
||||||
for(int i=0;i<p_size;i++) {
|
for(int i=0;i<p_size;i++) {
|
||||||
if(p_tie[i]>0) {
|
if(p_tie[i]>0) {
|
||||||
if(totie==1 && read[ltoi[p_tie[i]]]==0) {
|
if(totie==1 && read[ltoi[int(p_tie[i])]]==0) {
|
||||||
tieparam[ltoi[p_tie[i]]] = p_org[int(p_idx[i])];
|
tieparam[ltoi[int(p_tie[i])]] = p_flat[i];
|
||||||
read[ltoi[p_tie[i]]] = 1;
|
read[ltoi[int(p_tie[i])]] = 1+1-1;
|
||||||
} else {
|
} else {
|
||||||
p_org[int(p_idx[i])] = tieparam[ltoi[p_tie[i]]];
|
p_flat[i] = tieparam[ltoi[int(p_tie[i])]];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
weave.inline(code, arg_names=['p_org','p_idx','tieparam','p_tie','p_size','ltoi','read','totie'])
|
weave.inline(code, arg_names=['p_flat','tieparam','p_tie','p_size','ltoi','read','totie'])
|
||||||
except:
|
except:
|
||||||
|
raise
|
||||||
def _sync_val_p(p, tieparam, read):
|
def _sync_val_p(p, tieparam, read):
|
||||||
if p.tie is not None:
|
if p.tie is not None:
|
||||||
labels = np.unique(p.tie)
|
labels = np.unique(p.tie)
|
||||||
|
|
@ -125,7 +124,7 @@ class Tie(Parameterized):
|
||||||
val = tieparam[tieparam.tie==l]
|
val = tieparam[tieparam.tie==l]
|
||||||
p[p.tie==l] = val
|
p[p.tie==l] = val
|
||||||
for p in plist:
|
for p in plist:
|
||||||
self._traverse_param(_sync_val_p, (p,self.tied_param,read), [])
|
self._traverse_param(_sync_val_p, (p,self.tied_param,read,ltoi), [])
|
||||||
|
|
||||||
def _sync_constraints(self, plist, toTiedParam=True):
|
def _sync_constraints(self, plist, toTiedParam=True):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -85,6 +85,14 @@ class TieTests(unittest.TestCase):
|
||||||
self.assertTrue(m.ties.checkConstraintConsistency())
|
self.assertTrue(m.ties.checkConstraintConsistency())
|
||||||
self.assertTrue(m.ties.checkTieVector([m.X[:10],m.X[10:20]]))
|
self.assertTrue(m.ties.checkTieVector([m.X[:10],m.X[10:20]]))
|
||||||
self.assertTrue(m.checkgrad())
|
self.assertTrue(m.checkgrad())
|
||||||
|
|
||||||
|
def test_tie_variational_posterior_columns(self):
|
||||||
|
m = GPy.examples.dimensionality_reduction.bgplvm_oil_100(plot=False,optimize=False)
|
||||||
|
m.X.variance[:,:3].tie_vector(m.X.variance[:,3:6])
|
||||||
|
self.assertTrue(m.ties.checkValueConsistency())
|
||||||
|
self.assertTrue(m.ties.checkConstraintConsistency())
|
||||||
|
self.assertTrue(m.ties.checkTieVector([m.X.variance[:,:3],m.X.variance[:,3:6]]))
|
||||||
|
self.assertTrue(m.checkgrad())
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print "Running unit tests, please be (very) patient..."
|
print "Running unit tests, please be (very) patient..."
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue