gradient check

This commit is contained in:
Max Zwiessele 2014-03-11 16:23:29 +00:00
parent 85a471e0f6
commit 74999a89ad
2 changed files with 22 additions and 21 deletions

View file

@@ -446,8 +446,8 @@ class ParamConcatenation(object):
def untie(self, *ties): def untie(self, *ties):
[param.untie(*ties) for param in self.params] [param.untie(*ties) for param in self.params]
def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3): def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3, _debug=False):
return self.params[0]._highest_parent_._checkgrad(self, verbose, step, tolerance) return self.params[0]._highest_parent_._checkgrad(self, verbose, step, tolerance, _debug=_debug)
#checkgrad.__doc__ = Gradcheckable.checkgrad.__doc__ #checkgrad.__doc__ = Gradcheckable.checkgrad.__doc__
__lt__ = lambda self, val: self._vals() < val __lt__ = lambda self, val: self._vals() < val

View file

@@ -16,7 +16,7 @@ Observable Pattern for parameterization
from transformations import Transformation, Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED from transformations import Transformation, Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED
import numpy as np import numpy as np
__updated__ = '2013-12-16' __updated__ = '2014-03-11'
class HierarchyError(Exception): class HierarchyError(Exception):
""" """
@@ -206,7 +206,7 @@ class Gradcheckable(Parentable):
def __init__(self, *a, **kw): def __init__(self, *a, **kw):
super(Gradcheckable, self).__init__(*a, **kw) super(Gradcheckable, self).__init__(*a, **kw)
def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3): def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3, _debug=False):
""" """
Check the gradient of this parameter with respect to the highest parent's Check the gradient of this parameter with respect to the highest parent's
objective function. objective function.
@@ -220,14 +220,15 @@ class Gradcheckable(Parentable):
:param flaot tolerance: the tolerance for the gradient ratio or difference. :param flaot tolerance: the tolerance for the gradient ratio or difference.
""" """
if self.has_parent(): if self.has_parent():
return self._highest_parent_._checkgrad(self, verbose=verbose, step=step, tolerance=tolerance) return self._highest_parent_._checkgrad(self, verbose=verbose, step=step, tolerance=tolerance, _debug=_debug)
return self._checkgrad(self[''], verbose=verbose, step=step, tolerance=tolerance) return self._checkgrad(self[''], verbose=verbose, step=step, tolerance=tolerance, _debug=_debug)
def _checkgrad(self, param):
def _checkgrad(self, param, verbose=0, step=1e-6, tolerance=1e-3, _debug=False):
""" """
Perform the checkgrad on the model. Perform the checkgrad on the model.
TODO: this can be done more efficiently, when doing it inside here TODO: this can be done more efficiently, when doing it inside here
""" """
raise NotImplementedError, "Need log likelihood to check gradient against" raise HierarchyError, "This parameter is not in a model with a likelihood, and, therefore, cannot be gradient checked!"
class Nameable(Gradcheckable): class Nameable(Gradcheckable):