mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
implement splitTies
This commit is contained in:
parent
a506f4a605
commit
11bb1e4c4c
2 changed files with 44 additions and 0 deletions
|
|
@ -217,6 +217,13 @@ class Parameterized(Parameterizable):
|
||||||
self._highest_parent_._connect_fixes()
|
self._highest_parent_._connect_fixes()
|
||||||
self._highest_parent_._notify_parent_change()
|
self._highest_parent_._notify_parent_change()
|
||||||
|
|
||||||
|
if isinstance(param,Parameterized):
|
||||||
|
from ties_and_remappings import Tie
|
||||||
|
if not isinstance(param,Tie):
|
||||||
|
self._highest_parent_.ties.splitTies(param)
|
||||||
|
else:
|
||||||
|
self._highest_parent_.ties._update_label_buf()
|
||||||
|
|
||||||
def _connect_parameters(self, ignore_added_names=False):
|
def _connect_parameters(self, ignore_added_names=False):
|
||||||
# connect parameterlist to this parameterized object
|
# connect parameterlist to this parameterized object
|
||||||
# This just sets up the right connection for the params objects
|
# This just sets up the right connection for the params objects
|
||||||
|
|
|
||||||
|
|
@ -85,6 +85,24 @@ class Tie(Parameterized):
|
||||||
self._update_label_buf()
|
self._update_label_buf()
|
||||||
self.updates = True
|
self.updates = True
|
||||||
|
|
||||||
|
def splitTies(self, p):
|
||||||
|
"""Split the tie subtree with the tie tree"""
|
||||||
|
p.ties = Tie()
|
||||||
|
p.add_parameter(p.ties, -1)
|
||||||
|
p.add_observer(p.ties, p.ties._parameters_changed_notification, priority=-500)
|
||||||
|
if self.tied_param is not None:
|
||||||
|
self.updates = False
|
||||||
|
labels = self._get_labels([p])
|
||||||
|
labels = labels[labels>0]
|
||||||
|
if len(labels)>0:
|
||||||
|
p._expand_tie_param(len(labels))
|
||||||
|
idx = np.in1d(self.tied_param.tie,labels)
|
||||||
|
p.tied_param[:] = self.tied_param[idx]
|
||||||
|
p.tied_param.tie[:] = self.tied_param.tie[idx]
|
||||||
|
self._remove_unnecessary_ties()
|
||||||
|
self._update_label_buf()
|
||||||
|
self.updates = True
|
||||||
|
|
||||||
def _sync_val_group(self, plist):
|
def _sync_val_group(self, plist):
|
||||||
val = np.asarray([p.param_array for p in plist]).mean()
|
val = np.asarray([p.param_array for p in plist]).mean()
|
||||||
def _set_val(p):
|
def _set_val(p):
|
||||||
|
|
@ -170,6 +188,14 @@ class Tie(Parameterized):
|
||||||
self._remove_tie_param(labels[1:])
|
self._remove_tie_param(labels[1:])
|
||||||
self._replace_labels(self._highest_parent_, [(l,labels[0]) for l in labels[1:]])
|
self._replace_labels(self._highest_parent_, [(l,labels[0]) for l in labels[1:]])
|
||||||
|
|
||||||
|
def _remove_unnecessary_ties(self):
|
||||||
|
"""Remove the unnecessary ties"""
|
||||||
|
if self.tied_param is not None:
|
||||||
|
labels = [l for l in self.tied_param.tie if (self.label_buf==l).sum()<=2]
|
||||||
|
if len(labels)>0:
|
||||||
|
self._remove_tie_param(labels)
|
||||||
|
self._replace_labels(self._highest_parent_, zip(labels,[0]*len(labels)))
|
||||||
|
|
||||||
def _update_label_buf(self):
|
def _update_label_buf(self):
|
||||||
if self.tied_param is None:
|
if self.tied_param is None:
|
||||||
self.label_buf = None
|
self.label_buf = None
|
||||||
|
|
@ -181,6 +207,7 @@ class Tie(Parameterized):
|
||||||
self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param)
|
self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param)
|
||||||
self._untie_ = self.label_buf==0
|
self._untie_ = self.label_buf==0
|
||||||
self._untie_[self.buf_idx] = True
|
self._untie_[self.buf_idx] = True
|
||||||
|
assert(np.all(self.tied_param.tie>0))
|
||||||
|
|
||||||
def tie_together(self,plist):
|
def tie_together(self,plist):
|
||||||
"""tie a list of parameters"""
|
"""tie a list of parameters"""
|
||||||
|
|
@ -202,6 +229,16 @@ class Tie(Parameterized):
|
||||||
self.tied_param[self.tied_param.tie==tie_labels[0]] = val
|
self.tied_param[self.tied_param.tie==tie_labels[0]] = val
|
||||||
self.updates = True
|
self.updates = True
|
||||||
|
|
||||||
|
def untie(self,plist):
|
||||||
|
"""Untie a list of parameters"""
|
||||||
|
self.updates = False
|
||||||
|
self._set_labels(plist,[0])
|
||||||
|
self._update_label_buf()
|
||||||
|
self._remove_unnecessary_ties()
|
||||||
|
self._update_label_buf()
|
||||||
|
self.updates = True
|
||||||
|
|
||||||
|
|
||||||
def _check_change(self):
|
def _check_change(self):
|
||||||
changed = False
|
changed = False
|
||||||
if self.tied_param is not None:
|
if self.tied_param is not None:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue