mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-05 14:55:15 +02:00
parameters now work efficiently, tieing is iwth observer pattern
This commit is contained in:
parent
6be05de791
commit
d5afb3b797
2 changed files with 49 additions and 41 deletions
|
|
@ -12,6 +12,7 @@ __constraints_name__ = "Constraint"
|
|||
__index_name__ = "Index"
|
||||
__tie_name__ = "Tied to"
|
||||
__precision__ = numpy.get_printoptions()['precision'] # numpy printing precision used, sublassing numpy ndarray after all
|
||||
__print_threshold__ = 5
|
||||
######
|
||||
|
||||
class Param(numpy.ndarray):
|
||||
|
|
@ -228,7 +229,7 @@ class Param(numpy.ndarray):
|
|||
|
||||
self._parent_._get_original(self)._tied_to_ += [param]
|
||||
param._add_tie_listener(self)
|
||||
self._parent_._set_fixed(param)
|
||||
self._parent_._set_fixed(self)
|
||||
# self._parent_._add_tie(self, param)
|
||||
|
||||
def untie(self, *ties):
|
||||
|
|
@ -244,18 +245,22 @@ class Param(numpy.ndarray):
|
|||
# self._parent_._remove_tie(self, *params)
|
||||
def _fire_changed(self):
|
||||
for tied, ind in self._tied_to_me_.iteritems():
|
||||
tied._on_change(self[list(ind)])
|
||||
tied._on_change(self.base, list(ind))
|
||||
def _add_tie_listener(self, tied_to_me):
|
||||
self._tied_to_me_[tied_to_me] |= set(self._raveled_index())
|
||||
def _remove_tie_listener(self, to_remove):
|
||||
for t in self._tied_to_me_.keys():
|
||||
if t._parent_index_ == self._parent_index_:
|
||||
self._tied_to_me_[t] &= set(t._raveled_index())
|
||||
def _on_change(self, val):
|
||||
if self._original_: # this happens when indexing created a copy of the array
|
||||
self[:] = val
|
||||
else:
|
||||
self._parent_._get_original(self)[self._current_slice_] = val
|
||||
if len(self._tied_to_me_[t]) == 0:
|
||||
del self._tied_to_me_[t]
|
||||
def _on_change(self, val, ind):
|
||||
if not numpy.all(self==val[ind]):
|
||||
if self._original_:
|
||||
self[:] = val[ind]
|
||||
else: # this happens when indexing created a copy of the array
|
||||
self._parent_._get_original(self)[self._current_slice_] = val[ind]
|
||||
self._fire_changed()
|
||||
#===========================================================================
|
||||
# Prior Operations
|
||||
#===========================================================================
|
||||
|
|
@ -281,7 +286,7 @@ class Param(numpy.ndarray):
|
|||
def __getitem__(self, s, *args, **kwargs):
|
||||
if not isinstance(s, tuple):
|
||||
s = (s,)
|
||||
if not reduce(lambda a,b: a or numpy.any(b is Ellipsis), s, False):
|
||||
if not reduce(lambda a,b: a or numpy.any(b is Ellipsis), s, False) and len(s) <= self.ndim:
|
||||
s += (Ellipsis,)
|
||||
new_arr = numpy.ndarray.__getitem__(self, s, *args, **kwargs)
|
||||
try: new_arr._current_slice_ = s; new_arr._original_ = self.base is new_arr.base
|
||||
|
|
@ -337,7 +342,7 @@ class Param(numpy.ndarray):
|
|||
a = b+a
|
||||
return numpy.r_[a]
|
||||
return numpy.r_[:b]
|
||||
return itertools.imap(f, itertools.izip_longest(slice_index[:self._realndim_], self._realshape_, fillvalue=slice(None)))
|
||||
return itertools.imap(f, itertools.izip_longest(slice_index[:self._realndim_], self._realshape_, fillvalue=slice(self.size)))
|
||||
#===========================================================================
|
||||
# Printing -> done
|
||||
#===========================================================================
|
||||
|
|
@ -356,12 +361,15 @@ class Param(numpy.ndarray):
|
|||
def __repr__(self, *args, **kwargs):
|
||||
return "\033[1m{x:s}\033[0;0m:\n".format(x=self.name)+super(Param, self).__repr__(*args,**kwargs)
|
||||
def _ties_for(self, rav_index):
|
||||
ties = [[]] * numpy.size(rav_index)
|
||||
for tied_to in self._tied_to_:
|
||||
ties = numpy.empty(shape=(len(self._tied_to_), numpy.size(rav_index)), dtype=Param)
|
||||
for i, tied_to in enumerate(self._tied_to_):
|
||||
for t in tied_to._tied_to_me_.iterkeys():
|
||||
if t._parent_index_ == self._parent_index_:
|
||||
[ties.__setitem__(i, ties[i] + [tied_to]) for i in t._raveled_index()]
|
||||
return ties
|
||||
matches = numpy.where(rav_index[:,None] == t._raveled_index()[None, :])
|
||||
tt_rav_index = tied_to._raveled_index()
|
||||
ties[i, matches[0]] = numpy.take(tt_rav_index, matches[1], mode='wrap')
|
||||
#[ties.__setitem__(i, ties[i] + [tied_to]) for i in t._raveled_index()]
|
||||
return map(lambda a: sum(a,[]), zip(*[[[tie.flatten()] if tx!=None else [] for tx in t] for t,tie in zip(ties,self._tied_to_)]))
|
||||
def _constraints_for(self, rav_index):
|
||||
return self._parent_._constraints_for(self, rav_index)
|
||||
def _indices(self, slice_index=None):
|
||||
|
|
@ -372,10 +380,11 @@ class Param(numpy.ndarray):
|
|||
clean_curr_slice = [s for s in slice_index if numpy.any(s != Ellipsis)]
|
||||
if (all(isinstance(n, (numpy.ndarray, list, tuple)) for n in clean_curr_slice)
|
||||
and len(set(map(len,clean_curr_slice))) <= 1):
|
||||
return numpy.fromiter(itertools.izip(*self._expand_index()),
|
||||
dtype=[('',int)]*self._realndim_,count=self.size).view((int, self._realndim_))
|
||||
return numpy.fromiter(itertools.product(*self._expand_index(slice_index)),
|
||||
dtype=[('',int)]*self._realndim_,count=self.size).view((int, self._realndim_))
|
||||
return numpy.fromiter(itertools.izip(*clean_curr_slice),
|
||||
dtype=[('',int)]*self._realndim_,count=len(clean_curr_slice[0])).view((int, self._realndim_))
|
||||
expanded_index = list(self._expand_index(slice_index))
|
||||
return numpy.fromiter(itertools.product(*expanded_index),
|
||||
dtype=[('',int)]*self._realndim_,count=reduce(lambda a,b: a*b.size,expanded_index,1)).view((int, self._realndim_))
|
||||
def _max_len_names(self, gen, header):
|
||||
return reduce(lambda a, b:max(a, len(b)), gen, len(header))
|
||||
def _max_len_values(self):
|
||||
|
|
@ -391,16 +400,20 @@ class Param(numpy.ndarray):
|
|||
else: indstr = ','.join(map(str,ind))
|
||||
return self.name+'['+indstr+']'
|
||||
def __str__(self, constr_matrix=None, indices=None, ties=None, lc=None, lx=None, li=None, lt=None):
|
||||
if indices is None: indices = self._indices()
|
||||
ravi = self._raveled_index()
|
||||
filter_ = self._current_slice_
|
||||
vals = self.flat
|
||||
if indices is None: indices = self._indices(filter_)
|
||||
ravi = self._raveled_index(filter_)
|
||||
if constr_matrix is None: constr_matrix = self._constraints_for(ravi)
|
||||
if ties is None: ties = self._ties_for(ravi)
|
||||
ties = [' '.join(map(lambda x: x._short(), t)) for t in ties]
|
||||
if lc is None: lc = self._max_len_names(constr_matrix, __constraints_name__)
|
||||
if lx is None: lx = self._max_len_values()
|
||||
if li is None: li = self._max_len_index(self._indices())
|
||||
if lt is None: lt = self._max_len_names([t._short() for ti in ties for t in ti], __tie_name__)
|
||||
if li is None: li = self._max_len_index(indices)
|
||||
if lt is None: lt = self._max_len_names(ties, __tie_name__)
|
||||
header = " {i:^{2}s} | \033[1m{x:^{1}s}\033[0;0m | {c:^{0}s} | {t:^{3}s}".format(lc,lx,li,lt, x=self.name, c=__constraints_name__, i=__index_name__, t=__tie_name__) # nice header for printing
|
||||
return "\n".join([header]+[" {i!s:^{3}s} | {x: >{1}.{2}G} | {c:^{0}s} | {t:^{4}} ".format(lc,lx,__precision__,li,lt, x=x, c=" ".join(map(str,c)), t=" ".join([tie._short() for tie in t]), i=i) for i,x,c,t in itertools.izip(indices,self.flat,constr_matrix,ties)]) # return all the constraints with right indices
|
||||
if not ties: ties = itertools.cycle([''])
|
||||
return "\n".join([header]+[" {i!s:^{3}s} | {x: >{1}.{2}G} | {c:^{0}s} | {t:^{4}s} ".format(lc,lx,__precision__,li,lt, x=x, c=" ".join(map(str,c)), t=(t or ''), i=i) for i,x,c,t in itertools.izip(indices,vals,constr_matrix,ties)]) # return all the constraints with right indices
|
||||
#except: return super(Param, self).__str__()
|
||||
|
||||
class ParamConcatenation(object):
|
||||
|
|
@ -489,7 +502,7 @@ class ParamConcatenation(object):
|
|||
if __name__ == '__main__':
|
||||
from GPy.core.parameterized import Parameterized
|
||||
#X = numpy.random.randn(2,3,1,5,2,4,3)
|
||||
X = numpy.random.randn(2,3)
|
||||
X = numpy.random.randn(100,20)
|
||||
print "random done"
|
||||
p = Param("q_mean", X, None)
|
||||
p1 = Param("q_variance", numpy.random.rand(*p.shape), None)
|
||||
|
|
@ -503,8 +516,10 @@ if __name__ == '__main__':
|
|||
print "constraining variance"
|
||||
m[".*variance"].constrain_positive()
|
||||
print "constraining rbf"
|
||||
m.rbf.constrain_positive()
|
||||
m.q_variance[:,:2].tie_to(m.rbf_l)
|
||||
m.rbf_l.constrain_positive()
|
||||
m.q_variance[1,[0,5,11,19,2]].tie_to(m.rbf_v)
|
||||
m.rbf_v.tie_to(m.rbf_l[0])
|
||||
m.rbf_l[0].tie_to(m.rbf_l[1])
|
||||
#m.q_v.tie_to(m.rbf_v)
|
||||
# m.rbf_l.tie_to(m.rbf_va)
|
||||
# pt = numpy.array(params._get_params_transformed())
|
||||
|
|
|
|||
|
|
@ -19,19 +19,8 @@ __fixed__ = "fixed"
|
|||
|
||||
#===============================================================================
|
||||
# constants
|
||||
class _F_(object):
|
||||
def __bool__(self):
|
||||
return False
|
||||
def __str__(self):
|
||||
return "FIXED"
|
||||
FIXED = _F_()
|
||||
class _U_(object):
|
||||
def __bool__(self):
|
||||
return True
|
||||
def __str__(self):
|
||||
return "UNFIXED"
|
||||
UNFIXED = _U_()
|
||||
del _U_, _F_
|
||||
FIXED = False
|
||||
UNFIXED = True
|
||||
#===============================================================================
|
||||
|
||||
class Parameterized(object):
|
||||
|
|
@ -268,13 +257,17 @@ class Parameterized(object):
|
|||
try:
|
||||
param_or_index = self._raveled_index_for(param_or_index)
|
||||
except AttributeError:
|
||||
self._ties_fixes_[param_or_index] = FIXED
|
||||
pass
|
||||
self._ties_fixes_[param_or_index] = FIXED
|
||||
if numpy.all(self._ties_fixes_): self._ties_fixes_ = None # ==UNFIXED
|
||||
def _set_unfixed(self, param_or_index):
|
||||
if self._ties_fixes_ is None: self._ties_fixes_ = numpy.ones(self._parameter_size_, dtype=bool)
|
||||
try:
|
||||
param_or_index = self._raveled_index_for(param_or_index)
|
||||
except AttributeError:
|
||||
self._ties_fixes_[param_or_index] = UNFIXED
|
||||
pass
|
||||
self._ties_fixes_[param_or_index] = UNFIXED
|
||||
if numpy.all(self._ties_fixes_): self._ties_fixes_ = None # ==UNFIXED
|
||||
# def _add_tie(self, param, tied_to):
|
||||
# # tie param to tie_to, if the values match (with broadcasting)
|
||||
# self._remove_tie(param) # delete if multiple ties should be allowed
|
||||
|
|
@ -425,7 +418,7 @@ class Parameterized(object):
|
|||
return [x._desc for x in self._parameters_]
|
||||
@property
|
||||
def _ts(self):
|
||||
return [' '.join([t._short() for t in self._ties_iter(x)]) for x in self._parameters_]
|
||||
return [' '.join([t._short() for t in x._tied_to_]) for x in self._parameters_]
|
||||
def __str__(self, header=True):
|
||||
nl = max([len(str(x)) for x in self.parameter_names + ["Name"]])
|
||||
sl = max([len(str(x)) for x in self._descs + ["Value"]])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue