mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
some performance tuning for tie
This commit is contained in:
parent
cde51766ad
commit
d8a76b89de
2 changed files with 19 additions and 12 deletions
|
|
@ -86,32 +86,31 @@ class Tie(Parameterized):
|
|||
"""
|
||||
assert self.tied_param is not None
|
||||
read = np.zeros((self.tied_param.size,),dtype=np.uint8)
|
||||
ltoi = np.zeros((self.tied_param.tie.max()+1,),dtype=np.int32)
|
||||
ltoi[self.tied_param.tie] = range(self.tied_param.size)
|
||||
try:
|
||||
from scipy import weave
|
||||
def _sync_val_p(p, tieparam, read):
|
||||
def _sync_val_p(p, tieparam, read, ltoi):
|
||||
if p.tie is not None:
|
||||
# ltoi = self._label_to_idx
|
||||
ltoi = np.zeros((self.tied_param.tie.max()+1,),dtype=np.int32)
|
||||
ltoi[self.tied_param.tie] = range(self.tied_param.size)
|
||||
totie = 1 if toTiedParam else 0
|
||||
p_idx = p._raveled_index()
|
||||
p_tie = p.tie.flatten()
|
||||
p_org = p._original_.flat
|
||||
p_flat = p.flat
|
||||
p_size = p.size
|
||||
code = """
|
||||
for(int i=0;i<p_size;i++) {
|
||||
if(p_tie[i]>0) {
|
||||
if(totie==1 && read[ltoi[p_tie[i]]]==0) {
|
||||
tieparam[ltoi[p_tie[i]]] = p_org[int(p_idx[i])];
|
||||
read[ltoi[p_tie[i]]] = 1;
|
||||
if(totie==1 && read[ltoi[int(p_tie[i])]]==0) {
|
||||
tieparam[ltoi[int(p_tie[i])]] = p_flat[i];
|
||||
read[ltoi[int(p_tie[i])]] = 1+1-1;
|
||||
} else {
|
||||
p_org[int(p_idx[i])] = tieparam[ltoi[p_tie[i]]];
|
||||
p_flat[i] = tieparam[ltoi[int(p_tie[i])]];
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
weave.inline(code, arg_names=['p_org','p_idx','tieparam','p_tie','p_size','ltoi','read','totie'])
|
||||
weave.inline(code, arg_names=['p_flat','tieparam','p_tie','p_size','ltoi','read','totie'])
|
||||
except:
|
||||
raise
|
||||
def _sync_val_p(p, tieparam, read):
|
||||
if p.tie is not None:
|
||||
labels = np.unique(p.tie)
|
||||
|
|
@ -125,7 +124,7 @@ class Tie(Parameterized):
|
|||
val = tieparam[tieparam.tie==l]
|
||||
p[p.tie==l] = val
|
||||
for p in plist:
|
||||
self._traverse_param(_sync_val_p, (p,self.tied_param,read), [])
|
||||
self._traverse_param(_sync_val_p, (p,self.tied_param,read,ltoi), [])
|
||||
|
||||
def _sync_constraints(self, plist, toTiedParam=True):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -85,6 +85,14 @@ class TieTests(unittest.TestCase):
|
|||
self.assertTrue(m.ties.checkConstraintConsistency())
|
||||
self.assertTrue(m.ties.checkTieVector([m.X[:10],m.X[10:20]]))
|
||||
self.assertTrue(m.checkgrad())
|
||||
|
||||
def test_tie_variational_posterior_columns(self):
|
||||
m = GPy.examples.dimensionality_reduction.bgplvm_oil_100(plot=False,optimize=False)
|
||||
m.X.variance[:,:3].tie_vector(m.X.variance[:,3:6])
|
||||
self.assertTrue(m.ties.checkValueConsistency())
|
||||
self.assertTrue(m.ties.checkConstraintConsistency())
|
||||
self.assertTrue(m.ties.checkTieVector([m.X.variance[:,:3],m.X.variance[:,3:6]]))
|
||||
self.assertTrue(m.checkgrad())
|
||||
|
||||
if __name__ == "__main__":
|
||||
print "Running unit tests, please be (very) patient..."
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue