warning messages optional for re-constraining

This commit is contained in:
Max Zwiessele 2013-10-14 17:59:04 +01:00
parent e0d7de9b3b
commit 3a97eaf684
2 changed files with 20 additions and 18 deletions

View file

@ -250,7 +250,7 @@ class Model(Parameterized):
else: else:
self._set_params_transformed(initial_parameters) self._set_params_transformed(initial_parameters)
def ensure_default_constraints(self): def ensure_default_constraints(self, warning=True):
""" """
Ensure that any variables which should clearly be positive Ensure that any variables which should clearly be positive
have been constrained somehow. The method performs a regular have been constrained somehow. The method performs a regular
@ -268,7 +268,7 @@ class Model(Parameterized):
if not (i in currently_constrained): if not (i in currently_constrained):
to_make_positive.append(i) to_make_positive.append(i)
if len(to_make_positive): if len(to_make_positive):
self.constrain_positive(np.asarray(to_make_positive)) self.constrain_positive(np.asarray(to_make_positive), warning=warning)
def objective_function(self, x): def objective_function(self, x):
""" """

View file

@ -193,17 +193,17 @@ class Parameterized(object):
else: else:
self.fixed_indices, self.fixed_values = [], [] self.fixed_indices, self.fixed_values = [], []
def constrain_negative(self, regexp): def constrain_negative(self, regexp, warning=True):
""" Set negative constraints. """ """ Set negative constraints. """
self.constrain(regexp, transformations.NegativeLogexp()) self.constrain(regexp, transformations.NegativeLogexp(), warning)
def constrain_positive(self, regexp): def constrain_positive(self, regexp, warning=True):
""" Set positive constraints. """ """ Set positive constraints. """
self.constrain(regexp, transformations.Logexp()) self.constrain(regexp, transformations.Logexp(), warning)
def constrain_bounded(self, regexp, lower, upper): def constrain_bounded(self, regexp, lower, upper, warning=True):
""" Set bounded constraints. """ """ Set bounded constraints. """
self.constrain(regexp, transformations.Logistic(lower, upper)) self.constrain(regexp, transformations.Logistic(lower, upper), warning)
def all_constrained_indices(self): def all_constrained_indices(self):
if len(self.constrained_indices) or len(self.fixed_indices): if len(self.constrained_indices) or len(self.fixed_indices):
@ -211,13 +211,14 @@ class Parameterized(object):
else: else:
return np.empty(shape=(0,)) return np.empty(shape=(0,))
def constrain(self, regexp, transform): def constrain(self, regexp, transform, warning=True):
assert isinstance(transform, transformations.Transformation) assert isinstance(transform, transformations.Transformation)
matches = self.grep_param_names(regexp) matches = self.grep_param_names(regexp)
overlap = set(matches).intersection(set(self.all_constrained_indices())) overlap = set(matches).intersection(set(self.all_constrained_indices()))
if overlap: if overlap:
self.unconstrain(np.asarray(list(overlap))) self.unconstrain(np.asarray(list(overlap)))
if warning:
print 'Warning: re-constraining these parameters' print 'Warning: re-constraining these parameters'
pn = self._get_param_names() pn = self._get_param_names()
for i in overlap: for i in overlap:
@ -229,7 +230,7 @@ class Parameterized(object):
x[matches] = transform.initialize(x[matches]) x[matches] = transform.initialize(x[matches])
self._set_params(x) self._set_params(x)
def constrain_fixed(self, regexp, value=None): def constrain_fixed(self, regexp, value=None, warning=True):
""" """
:param regexp: which parameters need to be fixed. :param regexp: which parameters need to be fixed.
@ -249,6 +250,7 @@ class Parameterized(object):
overlap = set(matches).intersection(set(self.all_constrained_indices())) overlap = set(matches).intersection(set(self.all_constrained_indices()))
if overlap: if overlap:
self.unconstrain(np.asarray(list(overlap))) self.unconstrain(np.asarray(list(overlap)))
if warning:
print 'Warning: re-constraining these parameters' print 'Warning: re-constraining these parameters'
pn = self._get_param_names() pn = self._get_param_names()
for i in overlap: for i in overlap: