improved implementation of tie

This commit is contained in:
Zhenwen Dai 2014-09-01 16:13:11 +01:00
parent 09589fb50f
commit a9a86ac323

View file

@ -57,6 +57,9 @@ class Tie(Parameterized):
"""
def __init__(self, name='tie'):
# whether it has just propagated tied parameter values during optimization
# If ture, it does not need to check consistency
self._PROPAGATE_VAL_ = False
super(Tie, self).__init__(name)
self.tied_param = None
# The buffer keeps track of tie status
@ -189,23 +192,33 @@ class Tie(Parameterized):
b0 = self.label_buf==self.label_buf[self.buf_idx[i]]
b = self._highest_parent_.param_array[b0]!=self.tied_param[i]
if b.sum()==0:
print 'XXX'
# All the tied parameters are the same
continue
elif b.sum()==1:
print '!!!'
# One of the tied parameter is different.
# It must be recently changed one.
# The rest will be set to its value.
val = self._highest_parent_.param_array[b0][b][0]
self._highest_parent_.param_array[b0] = val
else:
print '@@@'
# It is most likely that the tie parameter is changed.
# Set all the tied parameter to the value of tie parameter.
self._highest_parent_.param_array[b0] = self.tied_param[i]
changed = True
return changed
def _parameters_changed_notification(self, me, which=None):
if which is not self:
self._optimizer_copy_transformed = False # tells the optimizer array to update on next request
self.parameters_changed()
def parameters_changed(self):
#ensure all out parameters have the correct value, as specified by our mapping
changed = self._check_change()
if changed:
self._highest_parent_._trigger_params_changed()
if self._PROPAGATE_VAL_:
self._PROPAGATE_VAL_ = False
else:
if self._check_change():
self._highest_parent_._trigger_params_changed()
self.collate_gradient()
def collate_gradient(self):
@ -218,6 +231,7 @@ class Tie(Parameterized):
if self.tied_param is not None:
for i in xrange(self.tied_param.size):
self._highest_parent_.param_array[self.label_buf==self.label_buf[self.buf_idx[i]]] = self.tied_param[i]
self._PROPAGATE_VAL_ = True