restructure tie implementation and add some tests

This commit is contained in:
Zhenwen Dai 2014-09-19 18:49:00 +01:00
parent 6c226a129d
commit 077232c26f
3 changed files with 187 additions and 14 deletions

View file

@ -49,7 +49,7 @@ class Param(Parameterizable, ObsAr):
obj._realshape_ = obj.shape
obj._realsize_ = obj.size
obj._realndim_ = obj.ndim
obj._original_ = True
obj._original_ = obj
return obj
def __init__(self, name, input_array, default_constraint=None, *a, **kw):
@ -139,7 +139,7 @@ class Param(Parameterizable, ObsAr):
new_arr._current_slice_ = s
new_arr._gradient_array_ = self.gradient[s]
new_arr._tie_ = self.tie[s]
new_arr._original_ = self.base is new_arr.base
new_arr._original_ = self #self.base is new_arr.base
except AttributeError: pass # returning 0d array or float, double etc
return new_arr

View file

@ -70,6 +70,76 @@ class Tie(Parameterized):
self.buf_idx = None
self._untie_ = None
def getTiedParamList(self):
if self.tied_param is None:
return []
labels = np.unique(self.label_buf)
labels = labels[labels>0]
return [np.where(self.label_buf==l)[0] for l in labels]
def _sync_val(self, plist, toTiedParam=True):
"""
Ensure the consistency of the values of tied parameters.
if toTieParam is true, the values of tied parameters will be synchronized among themselves and to the *TiedParam*,
otherwise all the tied parameters will be synchronized according to the values in *TiedParam*.
"""
assert self.tied_param is not None
read = np.zeros((self.tied_param.size,),dtype=np.uint8)
def _sync_val_p(p, tieparam, read):
if p.tie is not None:
labels = np.unique(p.tie)
labels = labels[labels>0]
for l in labels:
if toTiedParam and read[tieparam.tie==l] == 0:
val = p[p.tie==l][0]
tieparam[tieparam.tie==l] = val
read[tieparam.tie==l] = 1
else:
val = tieparam[tieparam.tie==l]
p[p.tie==l] = val
for p in plist:
self._traverse_param(_sync_val_p, (p,self.tied_param,read), [])
def _sync_constraints(self, plist, toTiedParam=True):
"""
Ensure the consistency of the constraints of tied parameters.
if toTieParam is true, the constraints of tied parameters will be synchronized among themselves and to the *TiedParam*,
otherwise all the tied parameters will be synchronized according to the constraints in *TiedParam*.
"""
assert self.tied_param is not None
read = np.zeros((self.tied_param.size,),dtype=np.uint8)
cons = [c[0] if len(c)>0 else None for c in self.tied_param.constraints.properties_for(range(self.tied_param.size))]
def _sync_constraints_p(p, tieparam, read, cons):
if p.tie is not None:
labels = np.unique(p.tie)
labels = labels[labels>0]
for l in labels:
idx = np.where(tieparam.tie==l)[0][0]
conslist = p.constraints.properties_to_index_dict(np.where(p.tie.flat==l)[0])
if toTiedParam and read[tieparam.tie==l] == 0:
if len(conslist)==0:
if cons[idx] is not None:
tieparam[idx:idx+1].unconstrain()
else:
c = conslist.keys()[0]
if len(conslist)>1 or len(conslist[c])!= (p.tie==l).sum():
p[p.tie==l].constrain(c)
if c != cons[idx]:
tieparam[idx:idx+1].constrain(c)
read[tieparam.tie==l] = 1
else:
if cons[idx] is None:
if len(conslist)>0:
p[p.tie==l].unconstrain()
else:
if len(conslist)!=1 or conslist.keys()[0]!=cons[idx] or len(conslist[cons[idx]])!= (p.tie==l).sum():
print cons[idx]
p[p.tie==l].constrain(cons[idx])
for p in plist:
self._traverse_param(_sync_constraints_p, (p,self.tied_param,read, cons), [])
@staticmethod
def recoverTies(p):
"""Recover the Tie object from the param objects"""
@ -316,19 +386,24 @@ class Tie(Parameterized):
self._untie_ = self.label_buf==0
self._untie_[self.buf_idx] = True
assert(np.all(self.tied_param.tie>0))
def _keepParamList(self,plist):
return [(p._original_, p._current_slice_) for p in plist]
def _updateParamList(self, p_split):
return [p_org[p_slice] for p_org,p_slice in p_split]
def tie_together(self,plist):
"""tie a list of parameters"""
"""tie a list of parameters"""
self.update_model(False)
labels = self._get_labels(plist)
val = self._sync_val_group(plist)
if labels[0]==0 and labels.size==1:
# None of parameters in plist has been tied before.
p_split = self._keepParamList(plist)
tie_labels,_ = self._expand_tie_param(1)
plist = self._updateParamList(p_split)
self._set_labels(plist, tie_labels)
tie_con = self._sync_constraint_group(plist)
if tie_con is not None:
self.tied_param[self.tied_param.tie==tie_labels[0]].constrain(tie_con)
toTiedParam = True
else:
# Some of parameters has been tied already.
# Merge the tie param
@ -336,13 +411,38 @@ class Tie(Parameterized):
if tie_labels.size>1:
self._merge_tie_labels(tie_labels)
self._set_labels(plist, [tie_labels[0]])
tie_p = self.tied_param[self.tied_param.tie==tie_labels[0]]
tie_con = tie_p.constraints.properties()[0] if tie_p.constraints.size>0 else None
self._sync_constraint_group(plist, True, tie_con)
toTiedParam = False
self._sync_val(plist,toTiedParam)
self._sync_constraints(plist, toTiedParam)
self._update_label_buf()
self.tied_param[self.tied_param.tie==tie_labels[0]] = val
self.update_model(True)
# def tie_together(self,plist):
# """tie a list of parameters"""
# self.update_model(False)
# labels = self._get_labels(plist)
# val = self._sync_val_group(plist)
# if labels[0]==0 and labels.size==1:
# # None of parameters in plist has been tied before.
# tie_labels,_ = self._expand_tie_param(1)
# self._set_labels(plist, tie_labels)
# tie_con = self._sync_constraint_group(plist)
# if tie_con is not None:
# self.tied_param[self.tied_param.tie==tie_labels[0]].constrain(tie_con)
# else:
# # Some of parameters has been tied already.
# # Merge the tie param
# tie_labels = labels[labels>0]
# if tie_labels.size>1:
# self._merge_tie_labels(tie_labels)
# self._set_labels(plist, [tie_labels[0]])
# tie_p = self.tied_param[self.tied_param.tie==tie_labels[0]]
# tie_con = tie_p.constraints.properties()[0] if tie_p.constraints.size>0 else None
# self._sync_constraint_group(plist, True, tie_con)
# self._update_label_buf()
# self.tied_param[self.tied_param.tie==tie_labels[0]] = val
# self.update_model(True)
def tie_vector(self, p1, p2):
"""tie a pair of vectors"""
self.update_model(False)
@ -415,6 +515,38 @@ class Tie(Parameterized):
for i in xrange(self.tied_param.size):
self._highest_parent_.param_array[self.label_buf==self.label_buf[self.buf_idx[i]]] = self.tied_param[i]
self._PROPAGATE_VAL_ = True
def checkValueConsistence(self):
return not self._check_change()
def checkConstraintConsistence(self):
if self.tied_param is not None:
tlist = self.getTiedParamList()
for l in tlist:
for _,ind in self._highest_parent_.constraints.iteritems():
f = np.in1d(l,ind)
if not np.all(f) and np.any(f):
return False
return True
def checkTieTogether(self, plist):
idx = []
for p in plist:
idx.extend(self._highest_parent_._raveled_index_for(p))
labels = np.unique(self.label_buf[idx])
if len(labels)==1 and labels[0]>0:
return True
else:
return False
def checkTieVector(self, plist):
p1 = plist[0]
idx1 = self._highest_parent_._raveled_index_for(p1)
if np.any(self.label_buf[idx1]==0):
return False
for p2 in plist[1:]:
idx2 = self._highest_parent_._raveled_index_for(p2)
if np.any(self.label_buf[idx2]==0) or np.any(self.label_buf[idx1]!=self.label_buf[idx2]):
return False
return True

41
GPy/testing/tie_tests.py Normal file
View file

@ -0,0 +1,41 @@
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt)
import unittest
import numpy as np
import GPy
class TieTests(unittest.TestCase):
def test_tie_together(self):
m = GPy.examples.regression.sparse_GP_regression_1D(optimize=False, plot=False, checkgrad=False)
m.Z.constrain_positive(warning=False)
m.Z.tie_together()
self.assertTrue(m.ties.checkValueConsistence())
self.assertTrue(m.ties.checkValueConsistence())
self.assertTrue(m.ties.checkTieTogether([m.Z]))
self.assertTrue(m.checkgrad())
def test_tie_together_two(self):
m = GPy.examples.regression.sparse_GP_regression_1D(optimize=False, plot=False, checkgrad=False)
m.Z.constrain_positive(warning=False)
m.Z[:2].tie_together()
m.Z[2:4].tie_together()
self.assertTrue(m.ties.checkTieTogether([m.Z[:2]]))
self.assertTrue(m.ties.checkTieTogether([m.Z[2:4]]))
self.assertTrue(m.ties.checkValueConsistence())
self.assertTrue(m.ties.checkValueConsistence())
self.assertTrue(m.checkgrad())
def test_tie_together_merge(self):
m = GPy.examples.regression.sparse_GP_regression_1D(optimize=False, plot=False, checkgrad=False)
m.Z.constrain_positive(warning=False)
m.Z[:2].tie_together()
m.Z[1:3].tie_together()
self.assertTrue(m.ties.checkTieTogether([m.Z[:3]]))
self.assertTrue(m.ties.checkValueConsistence())
self.assertTrue(m.ties.checkValueConsistence())
self.assertTrue(m.checkgrad())
if __name__ == "__main__":
print "Running unit tests, please be (very) patient..."
unittest.main()