mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 03:52:39 +02:00
checkgrad now global and callable from any parameter
This commit is contained in:
parent
6a068775f5
commit
826d2f04ff
4 changed files with 27 additions and 14 deletions
|
|
@ -380,7 +380,7 @@ class Model(Parameterized):
|
||||||
sgd.run()
|
sgd.run()
|
||||||
self.optimization_runs.append(sgd)
|
self.optimization_runs.append(sgd)
|
||||||
|
|
||||||
def checkgrad(self, target_param=None, verbose=False, step=1e-6, tolerance=1e-3):
|
def _checkgrad(self, target_param=None, verbose=False, step=1e-6, tolerance=1e-3):
|
||||||
"""
|
"""
|
||||||
Check the gradient of the ,odel by comparing to a numerical
|
Check the gradient of the ,odel by comparing to a numerical
|
||||||
estimate. If the verbose flag is passed, invividual
|
estimate. If the verbose flag is passed, invividual
|
||||||
|
|
@ -434,7 +434,7 @@ class Model(Parameterized):
|
||||||
if target_param is None:
|
if target_param is None:
|
||||||
param_list = range(len(x))
|
param_list = range(len(x))
|
||||||
else:
|
else:
|
||||||
param_list = self.grep_param_names(target_param, transformed=True, search=True)
|
param_list = self._raveled_index_for(target_param)
|
||||||
if not np.any(param_list):
|
if not np.any(param_list):
|
||||||
print "No free parameters to check"
|
print "No free parameters to check"
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,8 @@
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import numpy
|
import numpy
|
||||||
from parameter_core import Constrainable, adjust_name_for_printing
|
from parameter_core import Constrainable, Gradcheckable, adjust_name_for_printing
|
||||||
from array_core import ObservableArray
|
from array_core import ObservableArray, ParamList
|
||||||
|
|
||||||
###### printing
|
###### printing
|
||||||
__constraints_name__ = "Constraint"
|
__constraints_name__ = "Constraint"
|
||||||
|
|
@ -20,7 +20,7 @@ class Float(numpy.float64, Constrainable):
|
||||||
self._base = base
|
self._base = base
|
||||||
|
|
||||||
|
|
||||||
class Param(ObservableArray, Constrainable):
|
class Param(ObservableArray, Constrainable, Gradcheckable):
|
||||||
"""
|
"""
|
||||||
Parameter object for GPy models.
|
Parameter object for GPy models.
|
||||||
|
|
||||||
|
|
@ -547,6 +547,11 @@ class ParamConcatenation(object):
|
||||||
|
|
||||||
def untie(self, *ties):
|
def untie(self, *ties):
|
||||||
[param.untie(*ties) for param in self.params]
|
[param.untie(*ties) for param in self.params]
|
||||||
|
|
||||||
|
def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3):
|
||||||
|
return self.params[0]._highest_parent_._checkgrad(self, verbose, step, tolerance)
|
||||||
|
#checkgrad.__doc__ = Gradcheckable.checkgrad.__doc__
|
||||||
|
|
||||||
__lt__ = lambda self, val: self._vals() < val
|
__lt__ = lambda self, val: self._vals() < val
|
||||||
__le__ = lambda self, val: self._vals() <= val
|
__le__ = lambda self, val: self._vals() <= val
|
||||||
__eq__ = lambda self, val: self._vals() == val
|
__eq__ = lambda self, val: self._vals() == val
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,6 @@ class Parentable(object):
|
||||||
super(Parentable,self).__init__()
|
super(Parentable,self).__init__()
|
||||||
self._direct_parent_ = direct_parent
|
self._direct_parent_ = direct_parent
|
||||||
self._parent_index_ = parent_index
|
self._parent_index_ = parent_index
|
||||||
self._highest_parent_ = highest_parent
|
|
||||||
|
|
||||||
def has_parent(self):
|
def has_parent(self):
|
||||||
return self._direct_parent_ is not None
|
return self._direct_parent_ is not None
|
||||||
|
|
@ -79,6 +78,18 @@ class Nameable(Parentable):
|
||||||
if self.has_parent():
|
if self.has_parent():
|
||||||
self._direct_parent_._name_changed(self, from_name)
|
self._direct_parent_._name_changed(self, from_name)
|
||||||
|
|
||||||
|
class Gradcheckable(Parentable):
|
||||||
|
#===========================================================================
|
||||||
|
# Gradchecking
|
||||||
|
#===========================================================================
|
||||||
|
def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3):
|
||||||
|
if self.has_parent():
|
||||||
|
return self._highest_parent_._checkgrad(self, verbose=verbose, step=step, tolerance=tolerance)
|
||||||
|
return self._checkgrad(self[''], verbose=verbose, step=step, tolerance=tolerance)
|
||||||
|
def _checkgrad(self, param):
|
||||||
|
raise NotImplementedError, "Need log likelihood to check gradient against"
|
||||||
|
|
||||||
|
|
||||||
class Constrainable(Nameable):
|
class Constrainable(Nameable):
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
super(Constrainable,self).__init__(name)
|
super(Constrainable,self).__init__(name)
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ import cPickle
|
||||||
import itertools
|
import itertools
|
||||||
from re import compile, _pattern_type
|
from re import compile, _pattern_type
|
||||||
from param import ParamConcatenation, Param
|
from param import ParamConcatenation, Param
|
||||||
from parameter_core import Constrainable, Pickleable, Observable, adjust_name_for_printing
|
from parameter_core import Constrainable, Pickleable, Observable, adjust_name_for_printing, Gradcheckable
|
||||||
from index_operations import ParameterIndexOperations,\
|
from index_operations import ParameterIndexOperations,\
|
||||||
index_empty
|
index_empty
|
||||||
from array_core import ParamList
|
from array_core import ParamList
|
||||||
|
|
@ -24,7 +24,7 @@ FIXED = False
|
||||||
UNFIXED = True
|
UNFIXED = True
|
||||||
#===============================================================================
|
#===============================================================================
|
||||||
|
|
||||||
class Parameterized(Constrainable, Pickleable, Observable):
|
class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
||||||
"""
|
"""
|
||||||
Parameterized class
|
Parameterized class
|
||||||
|
|
||||||
|
|
@ -385,7 +385,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
||||||
"""
|
"""
|
||||||
return numpy.r_[:self.size]
|
return numpy.r_[:self.size]
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# Handle ties:
|
# Fixing parameters:
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
def _set_fixed(self, param_or_index):
|
def _set_fixed(self, param_or_index):
|
||||||
if not self._has_fixes(): self._fixes_ = numpy.ones(self.size, dtype=bool)
|
if not self._has_fixes(): self._fixes_ = numpy.ones(self.size, dtype=bool)
|
||||||
|
|
@ -410,9 +410,6 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
||||||
if self._has_fixes():
|
if self._has_fixes():
|
||||||
return self._fixes_[self._raveled_index_for(param)]
|
return self._fixes_[self._raveled_index_for(param)]
|
||||||
return numpy.ones(self.size, dtype=bool)[self._raveled_index_for(param)]
|
return numpy.ones(self.size, dtype=bool)[self._raveled_index_for(param)]
|
||||||
#===========================================================================
|
|
||||||
# Fixing parameters:
|
|
||||||
#===========================================================================
|
|
||||||
def _fix(self, param, warning=True):
|
def _fix(self, param, warning=True):
|
||||||
f = self._add_constrain(param, __fixed__, warning)
|
f = self._add_constrain(param, __fixed__, warning)
|
||||||
self._set_fixed(f)
|
self._set_fixed(f)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue