mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-03 16:52:39 +02:00
warning messages optional for re-constraining
This commit is contained in:
parent
e0d7de9b3b
commit
3a97eaf684
2 changed files with 20 additions and 18 deletions
|
|
@ -250,7 +250,7 @@ class Model(Parameterized):
|
||||||
else:
|
else:
|
||||||
self._set_params_transformed(initial_parameters)
|
self._set_params_transformed(initial_parameters)
|
||||||
|
|
||||||
def ensure_default_constraints(self):
|
def ensure_default_constraints(self, warning=True):
|
||||||
"""
|
"""
|
||||||
Ensure that any variables which should clearly be positive
|
Ensure that any variables which should clearly be positive
|
||||||
have been constrained somehow. The method performs a regular
|
have been constrained somehow. The method performs a regular
|
||||||
|
|
@ -268,7 +268,7 @@ class Model(Parameterized):
|
||||||
if not (i in currently_constrained):
|
if not (i in currently_constrained):
|
||||||
to_make_positive.append(i)
|
to_make_positive.append(i)
|
||||||
if len(to_make_positive):
|
if len(to_make_positive):
|
||||||
self.constrain_positive(np.asarray(to_make_positive))
|
self.constrain_positive(np.asarray(to_make_positive), warning=warning)
|
||||||
|
|
||||||
def objective_function(self, x):
|
def objective_function(self, x):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -193,17 +193,17 @@ class Parameterized(object):
|
||||||
else:
|
else:
|
||||||
self.fixed_indices, self.fixed_values = [], []
|
self.fixed_indices, self.fixed_values = [], []
|
||||||
|
|
||||||
def constrain_negative(self, regexp):
|
def constrain_negative(self, regexp, warning=True):
|
||||||
""" Set negative constraints. """
|
""" Set negative constraints. """
|
||||||
self.constrain(regexp, transformations.NegativeLogexp())
|
self.constrain(regexp, transformations.NegativeLogexp(), warning)
|
||||||
|
|
||||||
def constrain_positive(self, regexp):
|
def constrain_positive(self, regexp, warning=True):
|
||||||
""" Set positive constraints. """
|
""" Set positive constraints. """
|
||||||
self.constrain(regexp, transformations.Logexp())
|
self.constrain(regexp, transformations.Logexp(), warning)
|
||||||
|
|
||||||
def constrain_bounded(self, regexp, lower, upper):
|
def constrain_bounded(self, regexp, lower, upper, warning=True):
|
||||||
""" Set bounded constraints. """
|
""" Set bounded constraints. """
|
||||||
self.constrain(regexp, transformations.Logistic(lower, upper))
|
self.constrain(regexp, transformations.Logistic(lower, upper), warning)
|
||||||
|
|
||||||
def all_constrained_indices(self):
|
def all_constrained_indices(self):
|
||||||
if len(self.constrained_indices) or len(self.fixed_indices):
|
if len(self.constrained_indices) or len(self.fixed_indices):
|
||||||
|
|
@ -211,17 +211,18 @@ class Parameterized(object):
|
||||||
else:
|
else:
|
||||||
return np.empty(shape=(0,))
|
return np.empty(shape=(0,))
|
||||||
|
|
||||||
def constrain(self, regexp, transform):
|
def constrain(self, regexp, transform, warning=True):
|
||||||
assert isinstance(transform, transformations.Transformation)
|
assert isinstance(transform, transformations.Transformation)
|
||||||
|
|
||||||
matches = self.grep_param_names(regexp)
|
matches = self.grep_param_names(regexp)
|
||||||
overlap = set(matches).intersection(set(self.all_constrained_indices()))
|
overlap = set(matches).intersection(set(self.all_constrained_indices()))
|
||||||
if overlap:
|
if overlap:
|
||||||
self.unconstrain(np.asarray(list(overlap)))
|
self.unconstrain(np.asarray(list(overlap)))
|
||||||
print 'Warning: re-constraining these parameters'
|
if warning:
|
||||||
pn = self._get_param_names()
|
print 'Warning: re-constraining these parameters'
|
||||||
for i in overlap:
|
pn = self._get_param_names()
|
||||||
print pn[i]
|
for i in overlap:
|
||||||
|
print pn[i]
|
||||||
|
|
||||||
self.constrained_indices.append(matches)
|
self.constrained_indices.append(matches)
|
||||||
self.constraints.append(transform)
|
self.constraints.append(transform)
|
||||||
|
|
@ -229,7 +230,7 @@ class Parameterized(object):
|
||||||
x[matches] = transform.initialize(x[matches])
|
x[matches] = transform.initialize(x[matches])
|
||||||
self._set_params(x)
|
self._set_params(x)
|
||||||
|
|
||||||
def constrain_fixed(self, regexp, value=None):
|
def constrain_fixed(self, regexp, value=None, warning=True):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
:param regexp: which parameters need to be fixed.
|
:param regexp: which parameters need to be fixed.
|
||||||
|
|
@ -249,10 +250,11 @@ class Parameterized(object):
|
||||||
overlap = set(matches).intersection(set(self.all_constrained_indices()))
|
overlap = set(matches).intersection(set(self.all_constrained_indices()))
|
||||||
if overlap:
|
if overlap:
|
||||||
self.unconstrain(np.asarray(list(overlap)))
|
self.unconstrain(np.asarray(list(overlap)))
|
||||||
print 'Warning: re-constraining these parameters'
|
if warning:
|
||||||
pn = self._get_param_names()
|
print 'Warning: re-constraining these parameters'
|
||||||
for i in overlap:
|
pn = self._get_param_names()
|
||||||
print pn[i]
|
for i in overlap:
|
||||||
|
print pn[i]
|
||||||
|
|
||||||
self.fixed_indices.append(matches)
|
self.fixed_indices.append(matches)
|
||||||
if value != None:
|
if value != None:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue