resolve all the test failures!

This commit is contained in:
Zhenwen Dai 2014-09-19 15:46:11 +01:00
parent 3653892d19
commit 6c226a129d
5 changed files with 57 additions and 25 deletions

View file

@ -224,7 +224,7 @@ class Param(Parameterizable, ObsAr):
def _ties_str(self):
return ['']
def _ties_for(self, ravi):
return [['N/A' if self.tie[i]==0 else str(self.tie[i])] for i in xrange(ravi.size)]
return [['N/A' if self.tie.flat[i]==0 else str(self.tie[i])] for i in xrange(ravi.size)]
def __repr__(self, *args, **kwargs):
name = "\033[1m{x:s}\033[0;0m:\n".format(
x=self.hierarchy_name())

View file

@ -565,7 +565,8 @@ class Indexable(Nameable, Observable):
if self.has_parent():
fixes = np.ones(self.size).astype(bool)
fixes[self.constraints[__fixed__]] = FIXED
fixes[np.logical_not(self._highest_parent_.ties._untie_[self._highest_parent_._raveled_index_for(self)])] = FIXED
if self._has_ties():
fixes[np.logical_not(self._highest_parent_.ties._untie_[self._highest_parent_._raveled_index_for(self)])] = FIXED
return fixes
else:
hf = self._has_fixes()
@ -775,10 +776,7 @@ class OptimizationHandlable(Indexable):
self.param_array.flat[f] = p
[np.put(self.param_array, ind[f[ind]], c.f(self.param_array.flat[ind[f[ind]]]))
for c, ind in self.constraints.iteritems() if c != __fixed__]
<<<<<<< HEAD
self._highest_parent_.ties.propagate_val()
=======
>>>>>>> 48fb60489160de6fb0e84f6559b85b07dd16e274
self._optimizer_copy_transformed = False
self._trigger_params_changed()

View file

@ -80,18 +80,13 @@ class Parameterized(Parameterizable):
if not self._has_fixes():
self._fixes_ = None
self._param_slices_ = []
#self._connect_parameters()
<<<<<<< HEAD
self.add_parameters(*parameters)
self.link_parameters(*parameters)
from ties_and_remappings import Tie
if not isinstance(self,Tie):
self.ties = Tie()
self.add_parameter(self.ties, -1)
self.link_parameter(self.ties, -1)
self.add_observer(self.ties, self.ties._parameters_changed_notification, priority=-500)
=======
self.link_parameters(*parameters)
>>>>>>> 48fb60489160de6fb0e84f6559b85b07dd16e274
def build_pydot(self, G=None):
import pydot # @UnresolvedImport
@ -230,7 +225,7 @@ class Parameterized(Parameterizable):
def add_parameter(self, *args, **kwargs):
raise DeprecationWarning, "add_parameter was renamed to link_parameter to avoid confusion of setting variables"
def remove_parameter(self, *args, **kwargs):
raise DeprecationWarning, "remove_parameter was renamed to link_parameter to avoid confusion of setting variables"
raise DeprecationWarning, "remove_parameter was renamed to unlink_parameter to avoid confusion of setting variables"
def _connect_parameters(self, ignore_added_names=False):
# connect parameterlist to this parameterized object
@ -245,6 +240,10 @@ class Parameterized(Parameterizable):
self._param_array_ = np.empty(self.size, dtype=np.float64)
if self.gradient.size != self.size:
self._gradient_array_ = np.empty(self.size, dtype=np.float64)
if not self.has_parent() and not hasattr(self, 'ties'):
from ties_and_remappings import Tie
Tie.recoverTies(self)
old_size = 0
self._param_slices_ = []

View file

@ -69,26 +69,45 @@ class Tie(Parameterized):
self.label_buf = None
self.buf_idx = None
self._untie_ = None
@staticmethod
def recoverTies(p):
"""Recover the Tie object from the param objects"""
if not p.has_parent():
p.ties = Tie()
p.link_parameter(p.ties, -1)
p.add_observer(p.ties, p.ties._parameters_changed_notification, priority=-500)
p.update_model(False)
labels = p.ties._get_labels([p])
labels = labels[labels>0]
if len(labels)>0:
p._expand_tie_param(len(labels))
vals = p.ties._get_sync_val(p, labels)
p.tied_param[:] = vals
p.tied_param.tie[:] = labels
p._update_label_buf()
p.update_model(True)
def mergeTies(self, p):
"""Merge the tie tree with another tie tree"""
assert hasattr(p,'ties') and isinstance(p.ties,Tie), str(type(p))
self.update_model(False)
#self.update_model(False)
if p.ties.tied_param is not None:
tie_labels,_ = self._expand_tie_param(p.ties.tied_param.size)
self.tied_param[-p.ties.tied_param.size:] = p.ties.tied_param
pairs = zip(self.tied_param.tie,tie_labels)
self._replace_labels(p, pairs)
p.remove_observer(p.ties)
p.remove_parameter(p.ties)
p.unlink_parameter(p.ties)
del p.ties
self._update_label_buf()
self.update_model(True)
#self.update_model(True)
def splitTies(self, p):
"""Split the tie subtree with the tie tree"""
"""Split the tie subtree from the original tie tree"""
p.ties = Tie()
p.add_parameter(p.ties, -1)
p.link_parameter(p.ties, -1)
p.add_observer(p.ties, p.ties._parameters_changed_notification, priority=-500)
if self.tied_param is not None:
self.update_model(False)
@ -101,7 +120,23 @@ class Tie(Parameterized):
p.tied_param.tie[:] = self.tied_param.tie[idx]
self._remove_unnecessary_ties()
self._update_label_buf()
p._update_label_buf()
self.update_model(True)
def _get_sync_val(self, p, labels):
vals = np.empty((labels.size,))
read = np.zeros((labels.size,),dtype=np.uint8)
def _get_sync_v(p, labels, vals, read):
for i in xrange(labels.size):
if read[i]==1:
p[p.tie==labels[i]] = vals[i]
elif np.any(p.tie==labels[i]):
vals[i] = p[p.tie==labels[i]][0]
p[p.tie==labels[i]][0] = vals[i]
read[i] = 1
self._traverse_param(_get_sync_v, (p,labels,vals,read), [])
return vals
def _sync_val_group(self, plist):
val = np.hstack([p.param_array.flat for p in plist]).mean()
@ -227,27 +262,27 @@ class Tie(Parameterized):
old_size = self.tied_param.size
labellist = np.array(range(start_label,start_label+num),dtype=np.int)
idxlist = np.array(range(old_size,old_size+num),dtype=np.int)
self.remove_parameter(self.tied_param)
self.unlink_parameter(self.tied_param)
self.tied_param = Param('tied',new_buf)
self.tied_param.tie[:old_size] = old_tie_
self.tied_param.tie[old_size:] = labellist
self.add_parameter(self.tied_param)
self.link_parameter(self.tied_param)
return labellist, idxlist
def _remove_tie_param(self, labels):
"""Remove the tie param corresponding to *labels*"""
if len(labels) == self.tied_param.size:
self.remove_parameter(self.tied_param)
self.unlink_parameter(self.tied_param)
self.tied_param = None
else:
new_buf = np.empty((self.tied_param.size-len(labels),))
idx = np.logical_not(np.in1d(self.tied_param.tie,labels))
new_buf[:] = self.tied_param[idx]
old_tie_ = self.tied_param.tie.copy()
self.remove_parameter(self.tied_param)
self.unlink_parameter(self.tied_param)
self.tied_param = Param('tied',new_buf)
self.tied_param.tie[:] = old_tie_[idx]
self.add_parameter(self.tied_param)
self.link_parameter(self.tied_param)
def _merge_tie_labels(self, labels):
"""Merge all the labels in the list to the first one"""

View file

@ -225,7 +225,7 @@ class CombinationKernel(Kern):
@property
def parts(self):
return self.parameters
return [p for p in self.parameters if isinstance(p,Kern)]
def get_input_dim_active_dims(self, kernels, extra_dims = None):
#active_dims = reduce(np.union1d, (np.r_[x.active_dims] for x in kernels), np.array([], dtype=int))