checkgrad now global and callable from any parameter

This commit is contained in:
Max Zwiessele 2014-02-10 16:01:55 +00:00
parent 6a068775f5
commit 826d2f04ff
4 changed files with 27 additions and 14 deletions

View file

@ -380,7 +380,7 @@ class Model(Parameterized):
sgd.run()
self.optimization_runs.append(sgd)
def checkgrad(self, target_param=None, verbose=False, step=1e-6, tolerance=1e-3):
def _checkgrad(self, target_param=None, verbose=False, step=1e-6, tolerance=1e-3):
"""
Check the gradient of the model by comparing to a numerical
estimate. If the verbose flag is passed, individual
@ -434,7 +434,7 @@ class Model(Parameterized):
if target_param is None:
param_list = range(len(x))
else:
param_list = self.grep_param_names(target_param, transformed=True, search=True)
param_list = self._raveled_index_for(target_param)
if not np.any(param_list):
print "No free parameters to check"
return

View file

@ -3,8 +3,8 @@
import itertools
import numpy
from parameter_core import Constrainable, adjust_name_for_printing
from array_core import ObservableArray
from parameter_core import Constrainable, Gradcheckable, adjust_name_for_printing
from array_core import ObservableArray, ParamList
###### printing
__constraints_name__ = "Constraint"
@ -20,7 +20,7 @@ class Float(numpy.float64, Constrainable):
self._base = base
class Param(ObservableArray, Constrainable):
class Param(ObservableArray, Constrainable, Gradcheckable):
"""
Parameter object for GPy models.
@ -547,6 +547,11 @@ class ParamConcatenation(object):
def untie(self, *ties):
[param.untie(*ties) for param in self.params]
def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3):
return self.params[0]._highest_parent_._checkgrad(self, verbose, step, tolerance)
#checkgrad.__doc__ = Gradcheckable.checkgrad.__doc__
__lt__ = lambda self, val: self._vals() < val
__le__ = lambda self, val: self._vals() <= val
__eq__ = lambda self, val: self._vals() == val

View file

@ -52,7 +52,6 @@ class Parentable(object):
super(Parentable,self).__init__()
self._direct_parent_ = direct_parent
self._parent_index_ = parent_index
self._highest_parent_ = highest_parent
def has_parent(self):
return self._direct_parent_ is not None
@ -77,7 +76,19 @@ class Nameable(Parentable):
from_name = self.name
self._name = name
if self.has_parent():
self._direct_parent_._name_changed(self, from_name)
self._direct_parent_._name_changed(self, from_name)
class Gradcheckable(Parentable):
#===========================================================================
# Gradchecking
#===========================================================================
def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3):
if self.has_parent():
return self._highest_parent_._checkgrad(self, verbose=verbose, step=step, tolerance=tolerance)
return self._checkgrad(self[''], verbose=verbose, step=step, tolerance=tolerance)
def _checkgrad(self, param):
raise NotImplementedError, "Need log likelihood to check gradient against"
class Constrainable(Nameable):
def __init__(self, name):

View file

@ -8,7 +8,7 @@ import cPickle
import itertools
from re import compile, _pattern_type
from param import ParamConcatenation, Param
from parameter_core import Constrainable, Pickleable, Observable, adjust_name_for_printing
from parameter_core import Constrainable, Pickleable, Observable, adjust_name_for_printing, Gradcheckable
from index_operations import ParameterIndexOperations,\
index_empty
from array_core import ParamList
@ -24,7 +24,7 @@ FIXED = False
UNFIXED = True
#===============================================================================
class Parameterized(Constrainable, Pickleable, Observable):
class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
"""
Parameterized class
@ -230,7 +230,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
elif not (pname in not_unique):
self.__dict__[pname] = p
self._added_names_.add(pname)
#===========================================================================
# Pickling operations
#===========================================================================
@ -385,7 +385,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
"""
return numpy.r_[:self.size]
#===========================================================================
# Handle ties:
# Fixing parameters:
#===========================================================================
def _set_fixed(self, param_or_index):
if not self._has_fixes(): self._fixes_ = numpy.ones(self.size, dtype=bool)
@ -410,9 +410,6 @@ class Parameterized(Constrainable, Pickleable, Observable):
if self._has_fixes():
return self._fixes_[self._raveled_index_for(param)]
return numpy.ones(self.size, dtype=bool)[self._raveled_index_for(param)]
#===========================================================================
# Fixing parameters:
#===========================================================================
def _fix(self, param, warning=True):
f = self._add_constrain(param, __fixed__, warning)
self._set_fixed(f)