mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
improved implementation of tie
This commit is contained in:
parent
09589fb50f
commit
a9a86ac323
1 changed files with 20 additions and 6 deletions
|
|
@ -57,6 +57,9 @@ class Tie(Parameterized):
|
|||
|
||||
"""
|
||||
def __init__(self, name='tie'):
|
||||
# whether it has just propagated tied parameter values during optimization
|
||||
# If ture, it does not need to check consistency
|
||||
self._PROPAGATE_VAL_ = False
|
||||
super(Tie, self).__init__(name)
|
||||
self.tied_param = None
|
||||
# The buffer keeps track of tie status
|
||||
|
|
@ -189,23 +192,33 @@ class Tie(Parameterized):
|
|||
b0 = self.label_buf==self.label_buf[self.buf_idx[i]]
|
||||
b = self._highest_parent_.param_array[b0]!=self.tied_param[i]
|
||||
if b.sum()==0:
|
||||
print 'XXX'
|
||||
# All the tied parameters are the same
|
||||
continue
|
||||
elif b.sum()==1:
|
||||
print '!!!'
|
||||
# One of the tied parameter is different.
|
||||
# It must be recently changed one.
|
||||
# The rest will be set to its value.
|
||||
val = self._highest_parent_.param_array[b0][b][0]
|
||||
self._highest_parent_.param_array[b0] = val
|
||||
else:
|
||||
print '@@@'
|
||||
# It is most likely that the tie parameter is changed.
|
||||
# Set all the tied parameter to the value of tie parameter.
|
||||
self._highest_parent_.param_array[b0] = self.tied_param[i]
|
||||
changed = True
|
||||
return changed
|
||||
|
||||
def _parameters_changed_notification(self, me, which=None):
|
||||
if which is not self:
|
||||
self._optimizer_copy_transformed = False # tells the optimizer array to update on next request
|
||||
self.parameters_changed()
|
||||
|
||||
def parameters_changed(self):
|
||||
#ensure all out parameters have the correct value, as specified by our mapping
|
||||
changed = self._check_change()
|
||||
if changed:
|
||||
self._highest_parent_._trigger_params_changed()
|
||||
if self._PROPAGATE_VAL_:
|
||||
self._PROPAGATE_VAL_ = False
|
||||
else:
|
||||
if self._check_change():
|
||||
self._highest_parent_._trigger_params_changed()
|
||||
self.collate_gradient()
|
||||
|
||||
def collate_gradient(self):
|
||||
|
|
@ -218,6 +231,7 @@ class Tie(Parameterized):
|
|||
if self.tied_param is not None:
|
||||
for i in xrange(self.tied_param.size):
|
||||
self._highest_parent_.param_array[self.label_buf==self.label_buf[self.buf_idx[i]]] = self.tied_param[i]
|
||||
self._PROPAGATE_VAL_ = True
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue