mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
bug fix: tie framework
This commit is contained in:
parent
0cef82db29
commit
7d5b4f2769
1 changed files with 24 additions and 23 deletions
|
|
@ -477,30 +477,31 @@ class Tie(Parameterized):
|
|||
self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param)
|
||||
self._untie_ = self.label_buf==0
|
||||
self._untie_[self.buf_idx] = True
|
||||
self.tie_pairs = np.empty(((self.label_buf>0).sum()-self.tied_param.size,2),dtype=np.uint32)
|
||||
try:
|
||||
from scipy import weave
|
||||
self._label_to_idx = np.zeros((self.tied_param.tie.max()+1,),dtype=np.int32)
|
||||
self._label_to_idx[self.tied_param.tie] = range(self.tied_param.size)
|
||||
ltoi = self._label_to_idx
|
||||
t_start=int(self.buf_idx[0])
|
||||
t_end=int(self.buf_idx[-1])
|
||||
label_buf = self.label_buf
|
||||
buf_size = self.label_buf.size
|
||||
t_pairs = self.tie_pairs
|
||||
code = """
|
||||
int j=0;
|
||||
for(int i=0;i<buf_size;i++) {
|
||||
if(label_buf[i]>0 && !(i>=t_start && i<=t_end)) {
|
||||
t_pairs[j*2] = i;
|
||||
t_pairs[j*2+1] = ltoi[label_buf[i]];
|
||||
j++;
|
||||
self.tie_pairs = np.empty(((np.logical_not(self._untie_)).sum(),2),dtype=np.uint32)
|
||||
if self.tie_pairs.size>0:
|
||||
try:
|
||||
from scipy import weave
|
||||
self._label_to_idx = np.zeros((self.tied_param.tie.max()+1,),dtype=np.int32)
|
||||
self._label_to_idx[self.tied_param.tie] = range(self.tied_param.size)
|
||||
ltoi = self._label_to_idx
|
||||
t_start=int(self.buf_idx[0])
|
||||
t_end=int(self.buf_idx[-1])
|
||||
label_buf = self.label_buf
|
||||
buf_size = self.label_buf.size
|
||||
t_pairs = self.tie_pairs
|
||||
code = """
|
||||
int j=0;
|
||||
for(int i=0;i<buf_size;i++) {
|
||||
if(label_buf[i]>0 && !(i>=t_start && i<=t_end)) {
|
||||
t_pairs[j*2] = i;
|
||||
t_pairs[j*2+1] = ltoi[label_buf[i]];
|
||||
j++;
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
weave.inline(code, arg_names=['ltoi','t_start','t_end','label_buf','buf_size','t_pairs'])
|
||||
except:
|
||||
pass
|
||||
"""
|
||||
weave.inline(code, arg_names=['ltoi','t_start','t_end','label_buf','buf_size','t_pairs'])
|
||||
except:
|
||||
pass
|
||||
assert(np.all(self.tied_param.tie>0))
|
||||
|
||||
def _keepParamList(self,plist):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue