mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 20:12:38 +02:00
checkgrad now global and callable from any parameter
This commit is contained in:
parent
6a068775f5
commit
826d2f04ff
4 changed files with 27 additions and 14 deletions
|
|
@ -380,7 +380,7 @@ class Model(Parameterized):
|
|||
sgd.run()
|
||||
self.optimization_runs.append(sgd)
|
||||
|
||||
def checkgrad(self, target_param=None, verbose=False, step=1e-6, tolerance=1e-3):
|
||||
def _checkgrad(self, target_param=None, verbose=False, step=1e-6, tolerance=1e-3):
|
||||
"""
|
||||
Check the gradient of the ,odel by comparing to a numerical
|
||||
estimate. If the verbose flag is passed, invividual
|
||||
|
|
@ -434,7 +434,7 @@ class Model(Parameterized):
|
|||
if target_param is None:
|
||||
param_list = range(len(x))
|
||||
else:
|
||||
param_list = self.grep_param_names(target_param, transformed=True, search=True)
|
||||
param_list = self._raveled_index_for(target_param)
|
||||
if not np.any(param_list):
|
||||
print "No free parameters to check"
|
||||
return
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@
|
|||
|
||||
import itertools
|
||||
import numpy
|
||||
from parameter_core import Constrainable, adjust_name_for_printing
|
||||
from array_core import ObservableArray
|
||||
from parameter_core import Constrainable, Gradcheckable, adjust_name_for_printing
|
||||
from array_core import ObservableArray, ParamList
|
||||
|
||||
###### printing
|
||||
__constraints_name__ = "Constraint"
|
||||
|
|
@ -20,7 +20,7 @@ class Float(numpy.float64, Constrainable):
|
|||
self._base = base
|
||||
|
||||
|
||||
class Param(ObservableArray, Constrainable):
|
||||
class Param(ObservableArray, Constrainable, Gradcheckable):
|
||||
"""
|
||||
Parameter object for GPy models.
|
||||
|
||||
|
|
@ -547,6 +547,11 @@ class ParamConcatenation(object):
|
|||
|
||||
def untie(self, *ties):
|
||||
[param.untie(*ties) for param in self.params]
|
||||
|
||||
def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3):
|
||||
return self.params[0]._highest_parent_._checkgrad(self, verbose, step, tolerance)
|
||||
#checkgrad.__doc__ = Gradcheckable.checkgrad.__doc__
|
||||
|
||||
__lt__ = lambda self, val: self._vals() < val
|
||||
__le__ = lambda self, val: self._vals() <= val
|
||||
__eq__ = lambda self, val: self._vals() == val
|
||||
|
|
|
|||
|
|
@ -52,7 +52,6 @@ class Parentable(object):
|
|||
super(Parentable,self).__init__()
|
||||
self._direct_parent_ = direct_parent
|
||||
self._parent_index_ = parent_index
|
||||
self._highest_parent_ = highest_parent
|
||||
|
||||
def has_parent(self):
|
||||
return self._direct_parent_ is not None
|
||||
|
|
@ -79,6 +78,18 @@ class Nameable(Parentable):
|
|||
if self.has_parent():
|
||||
self._direct_parent_._name_changed(self, from_name)
|
||||
|
||||
class Gradcheckable(Parentable):
|
||||
#===========================================================================
|
||||
# Gradchecking
|
||||
#===========================================================================
|
||||
def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3):
|
||||
if self.has_parent():
|
||||
return self._highest_parent_._checkgrad(self, verbose=verbose, step=step, tolerance=tolerance)
|
||||
return self._checkgrad(self[''], verbose=verbose, step=step, tolerance=tolerance)
|
||||
def _checkgrad(self, param):
|
||||
raise NotImplementedError, "Need log likelihood to check gradient against"
|
||||
|
||||
|
||||
class Constrainable(Nameable):
|
||||
def __init__(self, name):
|
||||
super(Constrainable,self).__init__(name)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import cPickle
|
|||
import itertools
|
||||
from re import compile, _pattern_type
|
||||
from param import ParamConcatenation, Param
|
||||
from parameter_core import Constrainable, Pickleable, Observable, adjust_name_for_printing
|
||||
from parameter_core import Constrainable, Pickleable, Observable, adjust_name_for_printing, Gradcheckable
|
||||
from index_operations import ParameterIndexOperations,\
|
||||
index_empty
|
||||
from array_core import ParamList
|
||||
|
|
@ -24,7 +24,7 @@ FIXED = False
|
|||
UNFIXED = True
|
||||
#===============================================================================
|
||||
|
||||
class Parameterized(Constrainable, Pickleable, Observable):
|
||||
class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
||||
"""
|
||||
Parameterized class
|
||||
|
||||
|
|
@ -385,7 +385,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
|||
"""
|
||||
return numpy.r_[:self.size]
|
||||
#===========================================================================
|
||||
# Handle ties:
|
||||
# Fixing parameters:
|
||||
#===========================================================================
|
||||
def _set_fixed(self, param_or_index):
|
||||
if not self._has_fixes(): self._fixes_ = numpy.ones(self.size, dtype=bool)
|
||||
|
|
@ -410,9 +410,6 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
|||
if self._has_fixes():
|
||||
return self._fixes_[self._raveled_index_for(param)]
|
||||
return numpy.ones(self.size, dtype=bool)[self._raveled_index_for(param)]
|
||||
#===========================================================================
|
||||
# Fixing parameters:
|
||||
#===========================================================================
|
||||
def _fix(self, param, warning=True):
|
||||
f = self._add_constrain(param, __fixed__, warning)
|
||||
self._set_fixed(f)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue