mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
restructure tie implementation and add some tests
This commit is contained in:
parent
6c226a129d
commit
077232c26f
3 changed files with 187 additions and 14 deletions
|
|
@ -49,7 +49,7 @@ class Param(Parameterizable, ObsAr):
|
||||||
obj._realshape_ = obj.shape
|
obj._realshape_ = obj.shape
|
||||||
obj._realsize_ = obj.size
|
obj._realsize_ = obj.size
|
||||||
obj._realndim_ = obj.ndim
|
obj._realndim_ = obj.ndim
|
||||||
obj._original_ = True
|
obj._original_ = obj
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
def __init__(self, name, input_array, default_constraint=None, *a, **kw):
|
def __init__(self, name, input_array, default_constraint=None, *a, **kw):
|
||||||
|
|
@ -139,7 +139,7 @@ class Param(Parameterizable, ObsAr):
|
||||||
new_arr._current_slice_ = s
|
new_arr._current_slice_ = s
|
||||||
new_arr._gradient_array_ = self.gradient[s]
|
new_arr._gradient_array_ = self.gradient[s]
|
||||||
new_arr._tie_ = self.tie[s]
|
new_arr._tie_ = self.tie[s]
|
||||||
new_arr._original_ = self.base is new_arr.base
|
new_arr._original_ = self #self.base is new_arr.base
|
||||||
except AttributeError: pass # returning 0d array or float, double etc
|
except AttributeError: pass # returning 0d array or float, double etc
|
||||||
return new_arr
|
return new_arr
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -70,6 +70,76 @@ class Tie(Parameterized):
|
||||||
self.buf_idx = None
|
self.buf_idx = None
|
||||||
self._untie_ = None
|
self._untie_ = None
|
||||||
|
|
||||||
|
def getTiedParamList(self):
|
||||||
|
if self.tied_param is None:
|
||||||
|
return []
|
||||||
|
labels = np.unique(self.label_buf)
|
||||||
|
labels = labels[labels>0]
|
||||||
|
return [np.where(self.label_buf==l)[0] for l in labels]
|
||||||
|
|
||||||
|
def _sync_val(self, plist, toTiedParam=True):
|
||||||
|
"""
|
||||||
|
Ensure the consistency of the values of tied parameters.
|
||||||
|
if toTieParam is true, the values of tied parameters will be synchronized among themselves and to the *TiedParam*,
|
||||||
|
otherwise all the tied parameters will be synchronized according to the values in *TiedParam*.
|
||||||
|
"""
|
||||||
|
assert self.tied_param is not None
|
||||||
|
read = np.zeros((self.tied_param.size,),dtype=np.uint8)
|
||||||
|
def _sync_val_p(p, tieparam, read):
|
||||||
|
if p.tie is not None:
|
||||||
|
labels = np.unique(p.tie)
|
||||||
|
labels = labels[labels>0]
|
||||||
|
for l in labels:
|
||||||
|
if toTiedParam and read[tieparam.tie==l] == 0:
|
||||||
|
val = p[p.tie==l][0]
|
||||||
|
tieparam[tieparam.tie==l] = val
|
||||||
|
read[tieparam.tie==l] = 1
|
||||||
|
else:
|
||||||
|
val = tieparam[tieparam.tie==l]
|
||||||
|
p[p.tie==l] = val
|
||||||
|
|
||||||
|
for p in plist:
|
||||||
|
self._traverse_param(_sync_val_p, (p,self.tied_param,read), [])
|
||||||
|
|
||||||
|
def _sync_constraints(self, plist, toTiedParam=True):
|
||||||
|
"""
|
||||||
|
Ensure the consistency of the constraints of tied parameters.
|
||||||
|
if toTieParam is true, the constraints of tied parameters will be synchronized among themselves and to the *TiedParam*,
|
||||||
|
otherwise all the tied parameters will be synchronized according to the constraints in *TiedParam*.
|
||||||
|
"""
|
||||||
|
assert self.tied_param is not None
|
||||||
|
read = np.zeros((self.tied_param.size,),dtype=np.uint8)
|
||||||
|
cons = [c[0] if len(c)>0 else None for c in self.tied_param.constraints.properties_for(range(self.tied_param.size))]
|
||||||
|
def _sync_constraints_p(p, tieparam, read, cons):
|
||||||
|
if p.tie is not None:
|
||||||
|
labels = np.unique(p.tie)
|
||||||
|
labels = labels[labels>0]
|
||||||
|
for l in labels:
|
||||||
|
idx = np.where(tieparam.tie==l)[0][0]
|
||||||
|
conslist = p.constraints.properties_to_index_dict(np.where(p.tie.flat==l)[0])
|
||||||
|
if toTiedParam and read[tieparam.tie==l] == 0:
|
||||||
|
if len(conslist)==0:
|
||||||
|
if cons[idx] is not None:
|
||||||
|
tieparam[idx:idx+1].unconstrain()
|
||||||
|
else:
|
||||||
|
c = conslist.keys()[0]
|
||||||
|
if len(conslist)>1 or len(conslist[c])!= (p.tie==l).sum():
|
||||||
|
p[p.tie==l].constrain(c)
|
||||||
|
if c != cons[idx]:
|
||||||
|
tieparam[idx:idx+1].constrain(c)
|
||||||
|
read[tieparam.tie==l] = 1
|
||||||
|
else:
|
||||||
|
if cons[idx] is None:
|
||||||
|
if len(conslist)>0:
|
||||||
|
p[p.tie==l].unconstrain()
|
||||||
|
else:
|
||||||
|
if len(conslist)!=1 or conslist.keys()[0]!=cons[idx] or len(conslist[cons[idx]])!= (p.tie==l).sum():
|
||||||
|
print cons[idx]
|
||||||
|
p[p.tie==l].constrain(cons[idx])
|
||||||
|
for p in plist:
|
||||||
|
self._traverse_param(_sync_constraints_p, (p,self.tied_param,read, cons), [])
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def recoverTies(p):
|
def recoverTies(p):
|
||||||
"""Recover the Tie object from the param objects"""
|
"""Recover the Tie object from the param objects"""
|
||||||
|
|
@ -317,18 +387,23 @@ class Tie(Parameterized):
|
||||||
self._untie_[self.buf_idx] = True
|
self._untie_[self.buf_idx] = True
|
||||||
assert(np.all(self.tied_param.tie>0))
|
assert(np.all(self.tied_param.tie>0))
|
||||||
|
|
||||||
|
def _keepParamList(self,plist):
|
||||||
|
return [(p._original_, p._current_slice_) for p in plist]
|
||||||
|
|
||||||
|
def _updateParamList(self, p_split):
|
||||||
|
return [p_org[p_slice] for p_org,p_slice in p_split]
|
||||||
|
|
||||||
def tie_together(self,plist):
|
def tie_together(self,plist):
|
||||||
"""tie a list of parameters"""
|
"""tie a list of parameters"""
|
||||||
self.update_model(False)
|
self.update_model(False)
|
||||||
labels = self._get_labels(plist)
|
labels = self._get_labels(plist)
|
||||||
val = self._sync_val_group(plist)
|
|
||||||
if labels[0]==0 and labels.size==1:
|
if labels[0]==0 and labels.size==1:
|
||||||
# None of parameters in plist has been tied before.
|
# None of parameters in plist has been tied before.
|
||||||
|
p_split = self._keepParamList(plist)
|
||||||
tie_labels,_ = self._expand_tie_param(1)
|
tie_labels,_ = self._expand_tie_param(1)
|
||||||
|
plist = self._updateParamList(p_split)
|
||||||
self._set_labels(plist, tie_labels)
|
self._set_labels(plist, tie_labels)
|
||||||
tie_con = self._sync_constraint_group(plist)
|
toTiedParam = True
|
||||||
if tie_con is not None:
|
|
||||||
self.tied_param[self.tied_param.tie==tie_labels[0]].constrain(tie_con)
|
|
||||||
else:
|
else:
|
||||||
# Some of parameters has been tied already.
|
# Some of parameters has been tied already.
|
||||||
# Merge the tie param
|
# Merge the tie param
|
||||||
|
|
@ -336,13 +411,38 @@ class Tie(Parameterized):
|
||||||
if tie_labels.size>1:
|
if tie_labels.size>1:
|
||||||
self._merge_tie_labels(tie_labels)
|
self._merge_tie_labels(tie_labels)
|
||||||
self._set_labels(plist, [tie_labels[0]])
|
self._set_labels(plist, [tie_labels[0]])
|
||||||
tie_p = self.tied_param[self.tied_param.tie==tie_labels[0]]
|
toTiedParam = False
|
||||||
tie_con = tie_p.constraints.properties()[0] if tie_p.constraints.size>0 else None
|
self._sync_val(plist,toTiedParam)
|
||||||
self._sync_constraint_group(plist, True, tie_con)
|
self._sync_constraints(plist, toTiedParam)
|
||||||
self._update_label_buf()
|
self._update_label_buf()
|
||||||
self.tied_param[self.tied_param.tie==tie_labels[0]] = val
|
|
||||||
self.update_model(True)
|
self.update_model(True)
|
||||||
|
|
||||||
|
# def tie_together(self,plist):
|
||||||
|
# """tie a list of parameters"""
|
||||||
|
# self.update_model(False)
|
||||||
|
# labels = self._get_labels(plist)
|
||||||
|
# val = self._sync_val_group(plist)
|
||||||
|
# if labels[0]==0 and labels.size==1:
|
||||||
|
# # None of parameters in plist has been tied before.
|
||||||
|
# tie_labels,_ = self._expand_tie_param(1)
|
||||||
|
# self._set_labels(plist, tie_labels)
|
||||||
|
# tie_con = self._sync_constraint_group(plist)
|
||||||
|
# if tie_con is not None:
|
||||||
|
# self.tied_param[self.tied_param.tie==tie_labels[0]].constrain(tie_con)
|
||||||
|
# else:
|
||||||
|
# # Some of parameters has been tied already.
|
||||||
|
# # Merge the tie param
|
||||||
|
# tie_labels = labels[labels>0]
|
||||||
|
# if tie_labels.size>1:
|
||||||
|
# self._merge_tie_labels(tie_labels)
|
||||||
|
# self._set_labels(plist, [tie_labels[0]])
|
||||||
|
# tie_p = self.tied_param[self.tied_param.tie==tie_labels[0]]
|
||||||
|
# tie_con = tie_p.constraints.properties()[0] if tie_p.constraints.size>0 else None
|
||||||
|
# self._sync_constraint_group(plist, True, tie_con)
|
||||||
|
# self._update_label_buf()
|
||||||
|
# self.tied_param[self.tied_param.tie==tie_labels[0]] = val
|
||||||
|
# self.update_model(True)
|
||||||
|
|
||||||
def tie_vector(self, p1, p2):
|
def tie_vector(self, p1, p2):
|
||||||
"""tie a pair of vectors"""
|
"""tie a pair of vectors"""
|
||||||
self.update_model(False)
|
self.update_model(False)
|
||||||
|
|
@ -416,5 +516,37 @@ class Tie(Parameterized):
|
||||||
self._highest_parent_.param_array[self.label_buf==self.label_buf[self.buf_idx[i]]] = self.tied_param[i]
|
self._highest_parent_.param_array[self.label_buf==self.label_buf[self.buf_idx[i]]] = self.tied_param[i]
|
||||||
self._PROPAGATE_VAL_ = True
|
self._PROPAGATE_VAL_ = True
|
||||||
|
|
||||||
|
def checkValueConsistence(self):
|
||||||
|
return not self._check_change()
|
||||||
|
|
||||||
|
def checkConstraintConsistence(self):
|
||||||
|
if self.tied_param is not None:
|
||||||
|
tlist = self.getTiedParamList()
|
||||||
|
for l in tlist:
|
||||||
|
for _,ind in self._highest_parent_.constraints.iteritems():
|
||||||
|
f = np.in1d(l,ind)
|
||||||
|
if not np.all(f) and np.any(f):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def checkTieTogether(self, plist):
|
||||||
|
idx = []
|
||||||
|
for p in plist:
|
||||||
|
idx.extend(self._highest_parent_._raveled_index_for(p))
|
||||||
|
labels = np.unique(self.label_buf[idx])
|
||||||
|
if len(labels)==1 and labels[0]>0:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def checkTieVector(self, plist):
|
||||||
|
p1 = plist[0]
|
||||||
|
idx1 = self._highest_parent_._raveled_index_for(p1)
|
||||||
|
if np.any(self.label_buf[idx1]==0):
|
||||||
|
return False
|
||||||
|
for p2 in plist[1:]:
|
||||||
|
idx2 = self._highest_parent_._raveled_index_for(p2)
|
||||||
|
if np.any(self.label_buf[idx2]==0) or np.any(self.label_buf[idx1]!=self.label_buf[idx2]):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
|
||||||
41
GPy/testing/tie_tests.py
Normal file
41
GPy/testing/tie_tests.py
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
|
||||||
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
import GPy
|
||||||
|
|
||||||
|
class TieTests(unittest.TestCase):
|
||||||
|
def test_tie_together(self):
|
||||||
|
m = GPy.examples.regression.sparse_GP_regression_1D(optimize=False, plot=False, checkgrad=False)
|
||||||
|
m.Z.constrain_positive(warning=False)
|
||||||
|
m.Z.tie_together()
|
||||||
|
self.assertTrue(m.ties.checkValueConsistence())
|
||||||
|
self.assertTrue(m.ties.checkValueConsistence())
|
||||||
|
self.assertTrue(m.ties.checkTieTogether([m.Z]))
|
||||||
|
self.assertTrue(m.checkgrad())
|
||||||
|
|
||||||
|
def test_tie_together_two(self):
|
||||||
|
m = GPy.examples.regression.sparse_GP_regression_1D(optimize=False, plot=False, checkgrad=False)
|
||||||
|
m.Z.constrain_positive(warning=False)
|
||||||
|
m.Z[:2].tie_together()
|
||||||
|
m.Z[2:4].tie_together()
|
||||||
|
self.assertTrue(m.ties.checkTieTogether([m.Z[:2]]))
|
||||||
|
self.assertTrue(m.ties.checkTieTogether([m.Z[2:4]]))
|
||||||
|
self.assertTrue(m.ties.checkValueConsistence())
|
||||||
|
self.assertTrue(m.ties.checkValueConsistence())
|
||||||
|
self.assertTrue(m.checkgrad())
|
||||||
|
|
||||||
|
def test_tie_together_merge(self):
|
||||||
|
m = GPy.examples.regression.sparse_GP_regression_1D(optimize=False, plot=False, checkgrad=False)
|
||||||
|
m.Z.constrain_positive(warning=False)
|
||||||
|
m.Z[:2].tie_together()
|
||||||
|
m.Z[1:3].tie_together()
|
||||||
|
self.assertTrue(m.ties.checkTieTogether([m.Z[:3]]))
|
||||||
|
self.assertTrue(m.ties.checkValueConsistence())
|
||||||
|
self.assertTrue(m.ties.checkValueConsistence())
|
||||||
|
self.assertTrue(m.checkgrad())
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print "Running unit tests, please be (very) patient..."
|
||||||
|
unittest.main()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue