prior domain check

This commit is contained in:
Max Zwiessele 2014-03-14 12:32:08 +00:00
parent a800f4b7ed
commit 3e93579e3d
2 changed files with 13 additions and 4 deletions

View file

@ -226,7 +226,7 @@ class Param(OptimizationHandlable, ObservableArray):
# Constrainable # Constrainable
#=========================================================================== #===========================================================================
def _ensure_fixes(self): def _ensure_fixes(self):
if not self._has_fixes(): self._fixes_ = numpy.ones(self._realsize_, dtype=bool) self._fixes_ = numpy.ones(self._realsize_, dtype=bool)
#=========================================================================== #===========================================================================
# Convenience # Convenience

View file

@ -411,6 +411,15 @@ class Constrainable(Nameable, Indexable):
repriorized = self.unset_priors() repriorized = self.unset_priors()
self._add_to_index_operations(self.priors, repriorized, prior, warning) self._add_to_index_operations(self.priors, repriorized, prior, warning)
from domains import _REAL, _POSITIVE, _NEGATIVE
if prior.domain is _POSITIVE:
self.constrain_positive(warning)
elif prior.domain is _NEGATIVE:
self.constrain_negative(warning)
elif prior.domain is _REAL:
rav_i = self._raveled_index()
assert all(all(c.domain is _REAL for c in con) for con in self.constraints.properties_for(rav_i))
def unset_priors(self, *priors): def unset_priors(self, *priors):
""" """
Un-set all priors given from this parameter handle. Un-set all priors given from this parameter handle.
@ -421,14 +430,14 @@ class Constrainable(Nameable, Indexable):
def log_prior(self): def log_prior(self):
"""evaluate the prior""" """evaluate the prior"""
if self.priors.size > 0: if self.priors.size > 0:
x = self._get_params() x = self._param_array_
return reduce(lambda a, b: a + b, [p.lnpdf(x[ind]).sum() for p, ind in self.priors.iteritems()], 0) return reduce(lambda a, b: a + b, (p.lnpdf(x[ind]).sum() for p, ind in self.priors.iteritems()), 0)
return 0. return 0.
def _log_prior_gradients(self): def _log_prior_gradients(self):
"""evaluate the gradients of the priors""" """evaluate the gradients of the priors"""
if self.priors.size > 0: if self.priors.size > 0:
x = self._get_params() x = self._param_array_
ret = np.zeros(x.size) ret = np.zeros(x.size)
[np.put(ret, ind, p.lnpdf_grad(x[ind])) for p, ind in self.priors.iteritems()] [np.put(ret, ind, p.lnpdf_grad(x[ind])) for p, ind in self.priors.iteritems()]
return ret return ret