mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-11 15:15:15 +02:00
Priors added
This commit is contained in:
parent
922e72af94
commit
b1c98c2c3d
8 changed files with 206 additions and 356 deletions
|
|
@ -66,9 +66,6 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
self._added_names_ = set()
|
||||
del self._in_init_
|
||||
|
||||
def _has_fixes(self):
|
||||
return hasattr(self, "_fixes_") and self._fixes_ is not None
|
||||
|
||||
def add_parameter(self, param, index=None):
|
||||
"""
|
||||
:param parameters: the parameters to add
|
||||
|
|
@ -88,11 +85,14 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
# make sure the size is set
|
||||
if index is None:
|
||||
self.constraints.update(param.constraints, self.size)
|
||||
self.priors.update(param.priors, self.size)
|
||||
self._parameters_.append(param)
|
||||
else:
|
||||
start = sum(p.size for p in self._parameters_[:index])
|
||||
self.constraints.shift(start, param.size)
|
||||
self.priors.shift(start, param.size)
|
||||
self.constraints.update(param.constraints, start)
|
||||
self.priors.update(param.priors, start)
|
||||
self._parameters_.insert(index, param)
|
||||
self.size += param.size
|
||||
else:
|
||||
|
|
@ -210,6 +210,7 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
"""
|
||||
return [
|
||||
self._fixes_,
|
||||
self.priors,
|
||||
self.constraints,
|
||||
self._parameters_,
|
||||
self._name,
|
||||
|
|
@ -220,9 +221,10 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
self._added_names_ = state.pop()
|
||||
self._name = state.pop()
|
||||
self._parameters_ = state.pop()
|
||||
self._connect_parameters()
|
||||
self.constraints = state.pop()
|
||||
self.priors = state.pop()
|
||||
self._fixes_ = state.pop()
|
||||
self._connect_parameters()
|
||||
self.parameters_changed()
|
||||
#===========================================================================
|
||||
# Gradient control
|
||||
|
|
@ -248,16 +250,6 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
if self._has_fixes():
|
||||
return n[self._fixes_]
|
||||
return n
|
||||
def _get_params(self):
|
||||
# don't overwrite this anymore!
|
||||
if not self.size:
|
||||
return np.empty(shape=(0,), dtype=np.float64)
|
||||
return numpy.hstack([x._get_params() for x in self._parameters_ if x.size > 0])
|
||||
|
||||
def _set_params(self, params, update=True):
|
||||
# don't overwrite this anymore!
|
||||
[p._set_params(params[s], update=update) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
||||
self.parameters_changed()
|
||||
def _get_params_transformed(self):
|
||||
# transformed parameters (apply transformation rules)
|
||||
p = self._get_params()
|
||||
|
|
@ -348,36 +340,6 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
return self._direct_parent_.hirarchy_name() + adjust_name_for_printing(self.name) + "."
|
||||
return ''
|
||||
#===========================================================================
|
||||
# Constraint Handling:
|
||||
#===========================================================================
|
||||
#===========================================================================
|
||||
# def _add_constrain(self, param, transform, warning=True):
|
||||
# rav_i = self._raveled_index_for(param)
|
||||
# reconstrained = self._remove_constrain(param, index=rav_i) # remove constraints before
|
||||
# # if removing constraints before adding new is not wanted, just delete the above line!
|
||||
# self.constraints.add(transform, rav_i)
|
||||
# param = self._get_original(param)
|
||||
# if not (transform == __fixed__):
|
||||
# param._set_params(transform.initialize(param._get_params()), update=False)
|
||||
# if warning and any(reconstrained):
|
||||
# # if you want to print the whole params object, which was reconstrained use:
|
||||
# # m = str(param[self._backtranslate_index(param, reconstrained)])
|
||||
# print "Warning: re-constraining parameters:\n{}".format(param._short())
|
||||
# return rav_i
|
||||
# def _remove_constrain(self, param, *transforms, **kwargs):
|
||||
# if not transforms:
|
||||
# transforms = self.constraints.properties()
|
||||
# removed_indices = numpy.array([]).astype(int)
|
||||
# if "index" in kwargs: index = kwargs['index']
|
||||
# else: index = self._raveled_index_for(param)
|
||||
# for constr in transforms:
|
||||
# removed = self.constraints.remove(constr, index)
|
||||
# if constr is __fixed__:
|
||||
# self._set_unfixed(removed)
|
||||
# removed_indices = numpy.union1d(removed_indices, removed)
|
||||
# return removed_indices
|
||||
#===========================================================================
|
||||
#===========================================================================
|
||||
# Get/set parameters:
|
||||
#===========================================================================
|
||||
def grep_param_names(self, regexp):
|
||||
|
|
@ -434,9 +396,6 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
return self._direct_parent_.hirarchy_name() + adjust_name_for_printing(self.name)
|
||||
else:
|
||||
return adjust_name_for_printing(self.name)
|
||||
#parameter_names = property(parameter_names, doc="Names for all parameters handled by this parameterization object -- will add hirarchy name entries for printing")
|
||||
def _collect_gradient(self, target):
|
||||
[p._collect_gradient(target[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
||||
@property
|
||||
def flattened_parameters(self):
|
||||
return [xi for x in self._parameters_ for xi in x.flattened_parameters]
|
||||
|
|
@ -455,6 +414,9 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
def _constraints_str(self):
|
||||
return [cs for p in self._parameters_ for cs in p._constraints_str]
|
||||
@property
|
||||
def _priors_str(self):
|
||||
return [cs for p in self._parameters_ for cs in p._priors_str]
|
||||
@property
|
||||
def _description_str(self):
|
||||
return [xi for x in self._parameters_ for xi in x._description_str]
|
||||
@property
|
||||
|
|
@ -463,20 +425,23 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
def __str__(self, header=True):
|
||||
|
||||
name = adjust_name_for_printing(self.name) + "."
|
||||
constrs = self._constraints_str; ts = self._ties_str
|
||||
constrs = self._constraints_str;
|
||||
ts = self._ties_str
|
||||
prirs = self._priors_str
|
||||
desc = self._description_str; names = self.parameter_names()
|
||||
nl = max([len(str(x)) for x in names + [name]])
|
||||
sl = max([len(str(x)) for x in desc + ["Value"]])
|
||||
cl = max([len(str(x)) if x else 0 for x in constrs + ["Constraint"]])
|
||||
tl = max([len(str(x)) if x else 0 for x in ts + ["Tied to"]])
|
||||
format_spec = " \033[1m{{name:<{0}s}}\033[0;0m | {{desc:^{1}s}} | {{const:^{2}s}} | {{t:^{3}s}}".format(nl, sl, cl, tl)
|
||||
pl = max([len(str(x)) if x else 0 for x in prirs + ["Prior"]])
|
||||
format_spec = " \033[1m{{name:<{0}s}}\033[0;0m | {{desc:^{1}s}} | {{const:^{2}s}} | {{pri:^{3}s}} | {{t:^{4}s}}".format(nl, sl, cl, pl, tl)
|
||||
to_print = []
|
||||
for n, d, c, t in itertools.izip(names, desc, constrs, ts):
|
||||
to_print.append(format_spec.format(name=n, desc=d, const=c, t=t))
|
||||
for n, d, c, t, p in itertools.izip(names, desc, constrs, ts, prirs):
|
||||
to_print.append(format_spec.format(name=n, desc=d, const=c, t=t, pri=p))
|
||||
# to_print = [format_spec.format(p=p, const=c, t=t) if isinstance(p, Param) else p.__str__(header=False) for p, c, t in itertools.izip(self._parameters_, constrs, ts)]
|
||||
sep = '-' * (nl + sl + cl + tl + 8 * 2 + 3)
|
||||
sep = '-' * (nl + sl + cl + + pl + tl + 8 * 2 + 3)
|
||||
if header:
|
||||
header = " {{0:<{0}s}} | {{1:^{1}s}} | {{2:^{2}s}} | {{3:^{3}s}}".format(nl, sl, cl, tl).format(name, "Value", "Constraint", "Tied to")
|
||||
header = " {{0:<{0}s}} | {{1:^{1}s}} | {{2:^{2}s}} | {{3:^{3}s}} | {{4:^{4}s}}".format(nl, sl, cl, pl, tl).format(name, "Value", "Constraint", "Prior", "Tied to")
|
||||
# header += '\n' + sep
|
||||
to_print.insert(0, header)
|
||||
return '\n'.format(sep).join(to_print)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue