mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
some speedup for tie
This commit is contained in:
parent
c461e0cfd3
commit
7ca133977b
1 changed files with 54 additions and 20 deletions
|
|
@ -473,7 +473,7 @@ class Tie(Parameterized):
|
|||
self._untie_ = None
|
||||
else:
|
||||
self.label_buf = np.zeros((self._highest_parent_.param_array.size,),dtype=np.uint32)
|
||||
self._traverse_param(lambda x:np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie.flat), (self._highest_parent_,), [])
|
||||
self._traverse_param(lambda x:np.put(self.label_buf,xrange(self._highest_parent_._offset_for(x),x.size),x.tie.flat), (self._highest_parent_,), [])
|
||||
self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param)
|
||||
self._untie_ = self.label_buf==0
|
||||
self._untie_[self.buf_idx] = True
|
||||
|
|
@ -577,26 +577,60 @@ class Tie(Parameterized):
|
|||
self.update_model(True)
|
||||
|
||||
def _check_change(self):
|
||||
changed = False
|
||||
changed = [0]
|
||||
if self.tied_param is not None:
|
||||
for i in xrange(self.tied_param.size):
|
||||
b0 = self.label_buf==self.label_buf[self.buf_idx[i]]
|
||||
b = self._highest_parent_.param_array[b0]!=self.tied_param[i]
|
||||
if b.sum()==0:
|
||||
# All the tied parameters are the same
|
||||
continue
|
||||
elif b.sum()==1:
|
||||
# One of the tied parameter is different.
|
||||
# It must be recently changed one.
|
||||
# The rest will be set to its value.
|
||||
val = self._highest_parent_.param_array[b0][b][0]
|
||||
self._highest_parent_.param_array[b0] = val
|
||||
else:
|
||||
# It is most likely that the tie parameter is changed.
|
||||
# Set all the tied parameter to the value of tie parameter.
|
||||
self._highest_parent_.param_array[b0] = self.tied_param[i]
|
||||
changed = True
|
||||
return changed
|
||||
|
||||
from scipy import weave
|
||||
from ...util.misc import param_to_array
|
||||
param_array = self._highest_parent_.param_array
|
||||
tied_param = param_to_array(self.tied_param)
|
||||
tie_l = np.zeros_like(tied_param,dtype=np.int)
|
||||
t_pairs = self.tie_pairs
|
||||
t_pairs_size = self.tie_pairs.size
|
||||
code="""
|
||||
for(int i=0;i<t_pairs_size;i+=2) {
|
||||
int pidx = t_pairs[i];
|
||||
int tidx = t_pairs[i+1];
|
||||
if(param_array[pidx] != tied_param[tidx]) {
|
||||
if(tie_l[tidx]==0) {
|
||||
tie_l[tidx] = pidx;
|
||||
} else if(tie_l[tidx]>0) {
|
||||
tie_l[tidx] = -1;
|
||||
}
|
||||
changed[0] = 1;
|
||||
}
|
||||
}
|
||||
for(int i=0;i<t_pairs_size;i+=2) {
|
||||
int pidx = t_pairs[i];
|
||||
int tidx = t_pairs[i+1];
|
||||
if(tie_l[tidx]>0) {
|
||||
param_array[pidx] = param_array[tie_l[tidx]];
|
||||
tied_param[tidx] = param_array[pidx];
|
||||
} else if(tie_l[tidx]==-1) {
|
||||
param_array[pidx] = tied_param[tidx];
|
||||
}
|
||||
}
|
||||
"""
|
||||
weave.inline(code, arg_names=['param_array','t_pairs','t_pairs_size','tied_param','tie_l','changed'])
|
||||
|
||||
# for i in xrange(self.tied_param.size):
|
||||
# b0 = self.label_buf==self.label_buf[self.buf_idx[i]]
|
||||
# b = self._highest_parent_.param_array[b0]!=self.tied_param[i]
|
||||
# if b.sum()==0:
|
||||
# # All the tied parameters are the same
|
||||
# continue
|
||||
# elif b.sum()==1:
|
||||
# # One of the tied parameter is different.
|
||||
# # It must be recently changed one.
|
||||
# # The rest will be set to its value.
|
||||
# val = self._highest_parent_.param_array[b0][b][0]
|
||||
# self._highest_parent_.param_array[b0] = val
|
||||
# else:
|
||||
# # It is most likely that the tie parameter is changed.
|
||||
# # Set all the tied parameter to the value of tie parameter.
|
||||
# self._highest_parent_.param_array[b0] = self.tied_param[i]
|
||||
# changed = True
|
||||
return False if changed[0]==0 else True
|
||||
|
||||
def _parameters_changed_notification(self, me, which=None):
|
||||
if which is not self:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue