Priors added

This commit is contained in:
Max Zwiessele 2014-02-14 13:18:48 +00:00
parent 922e72af94
commit b1c98c2c3d
8 changed files with 206 additions and 356 deletions

View file

@ -66,9 +66,6 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
self._added_names_ = set()
del self._in_init_
def _has_fixes(self):
return hasattr(self, "_fixes_") and self._fixes_ is not None
def add_parameter(self, param, index=None):
"""
:param parameters: the parameters to add
@ -88,11 +85,14 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
# make sure the size is set
if index is None:
self.constraints.update(param.constraints, self.size)
self.priors.update(param.priors, self.size)
self._parameters_.append(param)
else:
start = sum(p.size for p in self._parameters_[:index])
self.constraints.shift(start, param.size)
self.priors.shift(start, param.size)
self.constraints.update(param.constraints, start)
self.priors.update(param.priors, start)
self._parameters_.insert(index, param)
self.size += param.size
else:
@ -210,6 +210,7 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
"""
return [
self._fixes_,
self.priors,
self.constraints,
self._parameters_,
self._name,
@ -220,9 +221,10 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
self._added_names_ = state.pop()
self._name = state.pop()
self._parameters_ = state.pop()
self._connect_parameters()
self.constraints = state.pop()
self.priors = state.pop()
self._fixes_ = state.pop()
self._connect_parameters()
self.parameters_changed()
#===========================================================================
# Gradient control
@ -248,16 +250,6 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
if self._has_fixes():
return n[self._fixes_]
return n
def _get_params(self):
# don't overwrite this anymore!
if not self.size:
return np.empty(shape=(0,), dtype=np.float64)
return numpy.hstack([x._get_params() for x in self._parameters_ if x.size > 0])
def _set_params(self, params, update=True):
# don't overwrite this anymore!
[p._set_params(params[s], update=update) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
self.parameters_changed()
def _get_params_transformed(self):
# transformed parameters (apply transformation rules)
p = self._get_params()
@ -348,36 +340,6 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
return self._direct_parent_.hirarchy_name() + adjust_name_for_printing(self.name) + "."
return ''
#===========================================================================
# Constraint Handling:
#===========================================================================
#===========================================================================
# def _add_constrain(self, param, transform, warning=True):
# rav_i = self._raveled_index_for(param)
# reconstrained = self._remove_constrain(param, index=rav_i) # remove constraints before
# # if removing constraints before adding new is not wanted, just delete the above line!
# self.constraints.add(transform, rav_i)
# param = self._get_original(param)
# if not (transform == __fixed__):
# param._set_params(transform.initialize(param._get_params()), update=False)
# if warning and any(reconstrained):
# # if you want to print the whole params object, which was reconstrained use:
# # m = str(param[self._backtranslate_index(param, reconstrained)])
# print "Warning: re-constraining parameters:\n{}".format(param._short())
# return rav_i
# def _remove_constrain(self, param, *transforms, **kwargs):
# if not transforms:
# transforms = self.constraints.properties()
# removed_indices = numpy.array([]).astype(int)
# if "index" in kwargs: index = kwargs['index']
# else: index = self._raveled_index_for(param)
# for constr in transforms:
# removed = self.constraints.remove(constr, index)
# if constr is __fixed__:
# self._set_unfixed(removed)
# removed_indices = numpy.union1d(removed_indices, removed)
# return removed_indices
#===========================================================================
#===========================================================================
# Get/set parameters:
#===========================================================================
def grep_param_names(self, regexp):
@ -434,9 +396,6 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
return self._direct_parent_.hirarchy_name() + adjust_name_for_printing(self.name)
else:
return adjust_name_for_printing(self.name)
#parameter_names = property(parameter_names, doc="Names for all parameters handled by this parameterization object -- will add hirarchy name entries for printing")
def _collect_gradient(self, target):
[p._collect_gradient(target[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
@property
def flattened_parameters(self):
return [xi for x in self._parameters_ for xi in x.flattened_parameters]
@ -455,6 +414,9 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
def _constraints_str(self):
return [cs for p in self._parameters_ for cs in p._constraints_str]
@property
def _priors_str(self):
return [cs for p in self._parameters_ for cs in p._priors_str]
@property
def _description_str(self):
return [xi for x in self._parameters_ for xi in x._description_str]
@property
@ -463,20 +425,23 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
def __str__(self, header=True):
name = adjust_name_for_printing(self.name) + "."
constrs = self._constraints_str; ts = self._ties_str
constrs = self._constraints_str;
ts = self._ties_str
prirs = self._priors_str
desc = self._description_str; names = self.parameter_names()
nl = max([len(str(x)) for x in names + [name]])
sl = max([len(str(x)) for x in desc + ["Value"]])
cl = max([len(str(x)) if x else 0 for x in constrs + ["Constraint"]])
tl = max([len(str(x)) if x else 0 for x in ts + ["Tied to"]])
format_spec = " \033[1m{{name:<{0}s}}\033[0;0m | {{desc:^{1}s}} | {{const:^{2}s}} | {{t:^{3}s}}".format(nl, sl, cl, tl)
pl = max([len(str(x)) if x else 0 for x in prirs + ["Prior"]])
format_spec = " \033[1m{{name:<{0}s}}\033[0;0m | {{desc:^{1}s}} | {{const:^{2}s}} | {{pri:^{3}s}} | {{t:^{4}s}}".format(nl, sl, cl, pl, tl)
to_print = []
for n, d, c, t in itertools.izip(names, desc, constrs, ts):
to_print.append(format_spec.format(name=n, desc=d, const=c, t=t))
for n, d, c, t, p in itertools.izip(names, desc, constrs, ts, prirs):
to_print.append(format_spec.format(name=n, desc=d, const=c, t=t, pri=p))
# to_print = [format_spec.format(p=p, const=c, t=t) if isinstance(p, Param) else p.__str__(header=False) for p, c, t in itertools.izip(self._parameters_, constrs, ts)]
sep = '-' * (nl + sl + cl + tl + 8 * 2 + 3)
sep = '-' * (nl + sl + cl + + pl + tl + 8 * 2 + 3)
if header:
header = " {{0:<{0}s}} | {{1:^{1}s}} | {{2:^{2}s}} | {{3:^{3}s}}".format(nl, sl, cl, tl).format(name, "Value", "Constraint", "Tied to")
header = " {{0:<{0}s}} | {{1:^{1}s}} | {{2:^{2}s}} | {{3:^{3}s}} | {{4:^{4}s}}".format(nl, sl, cl, pl, tl).format(name, "Value", "Constraint", "Prior", "Tied to")
# header += '\n' + sep
to_print.insert(0, header)
return '\n'.format(sep).join(to_print)