From 826d2f04ff1fa6fb4dff0fc7ddf05c32d0bbc0ab Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Mon, 10 Feb 2014 16:01:55 +0000 Subject: [PATCH] checkgrad now global and callable from any parameter --- GPy/core/model.py | 4 ++-- GPy/core/parameterization/param.py | 11 ++++++++--- GPy/core/parameterization/parameter_core.py | 15 +++++++++++++-- GPy/core/parameterization/parameterized.py | 11 ++++------- 4 files changed, 27 insertions(+), 14 deletions(-) diff --git a/GPy/core/model.py b/GPy/core/model.py index 35403ba7..db811801 100644 --- a/GPy/core/model.py +++ b/GPy/core/model.py @@ -380,7 +380,7 @@ class Model(Parameterized): sgd.run() self.optimization_runs.append(sgd) - def checkgrad(self, target_param=None, verbose=False, step=1e-6, tolerance=1e-3): + def _checkgrad(self, target_param=None, verbose=False, step=1e-6, tolerance=1e-3): """ Check the gradient of the ,odel by comparing to a numerical estimate. If the verbose flag is passed, invividual @@ -434,7 +434,7 @@ class Model(Parameterized): if target_param is None: param_list = range(len(x)) else: - param_list = self.grep_param_names(target_param, transformed=True, search=True) + param_list = self._raveled_index_for(target_param) if not np.any(param_list): print "No free parameters to check" return diff --git a/GPy/core/parameterization/param.py b/GPy/core/parameterization/param.py index 80661de0..4fc3aca0 100644 --- a/GPy/core/parameterization/param.py +++ b/GPy/core/parameterization/param.py @@ -3,8 +3,8 @@ import itertools import numpy -from parameter_core import Constrainable, adjust_name_for_printing -from array_core import ObservableArray +from parameter_core import Constrainable, Gradcheckable, adjust_name_for_printing +from array_core import ObservableArray, ParamList ###### printing __constraints_name__ = "Constraint" @@ -20,7 +20,7 @@ class Float(numpy.float64, Constrainable): self._base = base -class Param(ObservableArray, Constrainable): +class Param(ObservableArray, Constrainable, 
Gradcheckable): """ Parameter object for GPy models. @@ -547,6 +547,11 @@ class ParamConcatenation(object): def untie(self, *ties): [param.untie(*ties) for param in self.params] + + def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3): + return self.params[0]._highest_parent_._checkgrad(self, verbose, step, tolerance) + #checkgrad.__doc__ = Gradcheckable.checkgrad.__doc__ + __lt__ = lambda self, val: self._vals() < val __le__ = lambda self, val: self._vals() <= val __eq__ = lambda self, val: self._vals() == val diff --git a/GPy/core/parameterization/parameter_core.py b/GPy/core/parameterization/parameter_core.py index 51d9a110..b22e14f7 100644 --- a/GPy/core/parameterization/parameter_core.py +++ b/GPy/core/parameterization/parameter_core.py @@ -52,7 +52,6 @@ class Parentable(object): super(Parentable,self).__init__() self._direct_parent_ = direct_parent self._parent_index_ = parent_index - self._highest_parent_ = highest_parent def has_parent(self): return self._direct_parent_ is not None @@ -77,7 +76,19 @@ class Nameable(Parentable): from_name = self.name self._name = name if self.has_parent(): - self._direct_parent_._name_changed(self, from_name) + self._direct_parent_._name_changed(self, from_name) + +class Gradcheckable(Parentable): + #=========================================================================== + # Gradchecking + #=========================================================================== + def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3): + if self.has_parent(): + return self._highest_parent_._checkgrad(self, verbose=verbose, step=step, tolerance=tolerance) + return self._checkgrad(self[''], verbose=verbose, step=step, tolerance=tolerance) + def _checkgrad(self, param, verbose=0, step=1e-6, tolerance=1e-3): + raise NotImplementedError, "Need log likelihood to check gradient against" + class Constrainable(Nameable): def __init__(self, name): diff --git a/GPy/core/parameterization/parameterized.py b/GPy/core/parameterization/parameterized.py index 61097951..746163dc 
100644 --- a/GPy/core/parameterization/parameterized.py +++ b/GPy/core/parameterization/parameterized.py @@ -8,7 +8,7 @@ import cPickle import itertools from re import compile, _pattern_type from param import ParamConcatenation, Param -from parameter_core import Constrainable, Pickleable, Observable, adjust_name_for_printing +from parameter_core import Constrainable, Pickleable, Observable, adjust_name_for_printing, Gradcheckable from index_operations import ParameterIndexOperations,\ index_empty from array_core import ParamList @@ -24,7 +24,7 @@ FIXED = False UNFIXED = True #=============================================================================== -class Parameterized(Constrainable, Pickleable, Observable): +class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable): """ Parameterized class @@ -230,7 +230,7 @@ class Parameterized(Constrainable, Pickleable, Observable): elif not (pname in not_unique): self.__dict__[pname] = p self._added_names_.add(pname) - + #=========================================================================== # Pickling operations #=========================================================================== @@ -385,7 +385,7 @@ class Parameterized(Constrainable, Pickleable, Observable): """ return numpy.r_[:self.size] #=========================================================================== - # Handle ties: + # Fixing parameters: #=========================================================================== def _set_fixed(self, param_or_index): if not self._has_fixes(): self._fixes_ = numpy.ones(self.size, dtype=bool) @@ -410,9 +410,6 @@ class Parameterized(Constrainable, Pickleable, Observable): if self._has_fixes(): return self._fixes_[self._raveled_index_for(param)] return numpy.ones(self.size, dtype=bool)[self._raveled_index_for(param)] - #=========================================================================== - # Fixing parameters: - #=========================================================================== 
def _fix(self, param, warning=True): f = self._add_constrain(param, __fixed__, warning) self._set_fixed(f)