mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 19:42:39 +02:00
[parameterization] Parameter adding more robust and better error handling
This commit is contained in:
parent
051a8115a2
commit
5f863e8ca2
3 changed files with 19 additions and 5 deletions
|
|
@ -954,19 +954,30 @@ class Parameterizable(OptimizationHandlable):
|
||||||
if ignore_added_names:
|
if ignore_added_names:
|
||||||
self.__dict__[pname] = param
|
self.__dict__[pname] = param
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def warn_and_retry():
|
||||||
|
print """
|
||||||
|
WARNING: added a parameter with formatted name {},
|
||||||
|
which is already assigned to {}.
|
||||||
|
Trying to change the parameter name to
|
||||||
|
|
||||||
|
{}.{}
|
||||||
|
""".format(pname, self.hierarchy_name(), self.hierarchy_name(), param.name + "_")
|
||||||
|
param.name += "_"
|
||||||
|
self._add_parameter_name(param, ignore_added_names)
|
||||||
# and makes sure to not delete programmatically added parameters
|
# and makes sure to not delete programmatically added parameters
|
||||||
if pname in self.__dict__:
|
if pname in self.__dict__:
|
||||||
if not (param is self.__dict__[pname]):
|
if not (param is self.__dict__[pname]):
|
||||||
if pname in self._added_names_:
|
if pname in self._added_names_:
|
||||||
del self.__dict__[pname]
|
del self.__dict__[pname]
|
||||||
self._add_parameter_name(param)
|
self._add_parameter_name(param)
|
||||||
|
else:
|
||||||
|
warn_and_retry()
|
||||||
elif pname not in dir(self):
|
elif pname not in dir(self):
|
||||||
self.__dict__[pname] = param
|
self.__dict__[pname] = param
|
||||||
self._added_names_.add(pname)
|
self._added_names_.add(pname)
|
||||||
else:
|
else:
|
||||||
print "WARNING: added a parameter with formatted name {}, which is already a member of {} object. Trying to change the parameter name to\n {}".format(pname, self.__class__, param.name + "_")
|
warn_and_retry()
|
||||||
param.name += "_"
|
|
||||||
self._add_parameter_name(param, ignore_added_names)
|
|
||||||
|
|
||||||
def _remove_parameter_name(self, param=None, pname=None):
|
def _remove_parameter_name(self, param=None, pname=None):
|
||||||
assert param is None or pname is None, "can only delete either param by name, or the name of a param"
|
assert param is None or pname is None, "can only delete either param by name, or the name of a param"
|
||||||
|
|
|
||||||
|
|
@ -180,7 +180,10 @@ class Parameterized(Parameterizable):
|
||||||
:param param: param object to remove from being a parameter of this parameterized object.
|
:param param: param object to remove from being a parameter of this parameterized object.
|
||||||
"""
|
"""
|
||||||
if not param in self.parameters:
|
if not param in self.parameters:
|
||||||
raise RuntimeError, "Parameter {} does not belong to this object {}, remove parameters directly from their respective parents".format(param._short(), self.name)
|
try:
|
||||||
|
raise RuntimeError, "{} does not belong to this object {}, remove parameters directly from their respective parents".format(param._short(), self.name)
|
||||||
|
except AttributeError:
|
||||||
|
raise RuntimeError, "{} does not seem to be a parameter, remove parameters directly from their respective parents".format(str(param))
|
||||||
|
|
||||||
start = sum([p.size for p in self.parameters[:param._parent_index_]])
|
start = sum([p.size for p in self.parameters[:param._parent_index_]])
|
||||||
self._remove_parameter_name(param)
|
self._remove_parameter_name(param)
|
||||||
|
|
|
||||||
|
|
@ -111,7 +111,7 @@ class VariationalPosterior(Parameterized):
|
||||||
n.input_dim = n.mean.shape[1] if n.ndim != 1 else 1
|
n.input_dim = n.mean.shape[1] if n.ndim != 1 else 1
|
||||||
return n
|
return n
|
||||||
else:
|
else:
|
||||||
return super(VariationalPrior, self).__getitem__(s)
|
return super(VariationalPosterior, self).__getitem__(s)
|
||||||
|
|
||||||
class NormalPosterior(VariationalPosterior):
|
class NormalPosterior(VariationalPosterior):
|
||||||
'''
|
'''
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue