gradient check

This commit is contained in:
Max Zwiessele 2014-03-11 16:23:29 +00:00
parent 85a471e0f6
commit 74999a89ad
2 changed files with 22 additions and 21 deletions

View file

@ -446,8 +446,8 @@ class ParamConcatenation(object):
def untie(self, *ties): def untie(self, *ties):
[param.untie(*ties) for param in self.params] [param.untie(*ties) for param in self.params]
def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3): def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3, _debug=False):
return self.params[0]._highest_parent_._checkgrad(self, verbose, step, tolerance) return self.params[0]._highest_parent_._checkgrad(self, verbose, step, tolerance, _debug=_debug)
#checkgrad.__doc__ = Gradcheckable.checkgrad.__doc__ #checkgrad.__doc__ = Gradcheckable.checkgrad.__doc__
__lt__ = lambda self, val: self._vals() < val __lt__ = lambda self, val: self._vals() < val

View file

@ -16,7 +16,7 @@ Observable Pattern for patameterization
from transformations import Transformation, Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED from transformations import Transformation, Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED
import numpy as np import numpy as np
__updated__ = '2013-12-16' __updated__ = '2014-03-11'
class HierarchyError(Exception): class HierarchyError(Exception):
""" """
@ -34,7 +34,7 @@ def adjust_name_for_printing(name):
class Observable(object): class Observable(object):
""" """
Observable pattern for parameterization. Observable pattern for parameterization.
This Object allows for observers to register with self and a (bound!) function This Object allows for observers to register with self and a (bound!) function
as an observer. Every time the observable changes, it sends a notification with as an observer. Every time the observable changes, it sends a notification with
self as only argument to all its observers. self as only argument to all its observers.
@ -43,10 +43,10 @@ class Observable(object):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(Observable, self).__init__(*args, **kwargs) super(Observable, self).__init__(*args, **kwargs)
self._observer_callables_ = [] self._observer_callables_ = []
def add_observer(self, observer, callble, priority=0): def add_observer(self, observer, callble, priority=0):
self._insert_sorted(priority, observer, callble) self._insert_sorted(priority, observer, callble)
def remove_observer(self, observer, callble=None): def remove_observer(self, observer, callble=None):
to_remove = [] to_remove = []
for p, obs, clble in self._observer_callables_: for p, obs, clble in self._observer_callables_:
@ -58,15 +58,15 @@ class Observable(object):
to_remove.append((p, obs, clble)) to_remove.append((p, obs, clble))
for r in to_remove: for r in to_remove:
self._observer_callables_.remove(r) self._observer_callables_.remove(r)
def notify_observers(self, which=None, min_priority=None): def notify_observers(self, which=None, min_priority=None):
""" """
Notifies all observers. Which is the element, which kicked off this Notifies all observers. Which is the element, which kicked off this
notification loop. notification loop.
NOTE: notifies only observers with priority p > min_priority! NOTE: notifies only observers with priority p > min_priority!
^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^
:param which: object, which started this notification loop :param which: object, which started this notification loop
:param min_priority: only notify observers with priority > min_priority :param min_priority: only notify observers with priority > min_priority
if min_priority is None, notify all observers in order if min_priority is None, notify all observers in order
@ -88,11 +88,11 @@ class Observable(object):
break break
ins += 1 ins += 1
self._observer_callables_.insert(ins, (p, o, c)) self._observer_callables_.insert(ins, (p, o, c))
class Pickleable(object): class Pickleable(object):
""" """
Make an object pickleable (See python doc 'pickling'). Make an object pickleable (See python doc 'pickling').
This class allows for pickling support by Memento pattern. This class allows for pickling support by Memento pattern.
_getstate returns a memento of the class, which gets pickled. _getstate returns a memento of the class, which gets pickled.
_setstate(<memento>) (re-)sets the state of the class to the memento _setstate(<memento>) (re-)sets the state of the class to the memento
@ -153,7 +153,7 @@ class Pickleable(object):
class Parentable(object): class Parentable(object):
""" """
Enable an Object to have a parent. Enable an Object to have a parent.
Additionally this adds the parent_index, which is the index for the parent Additionally this adds the parent_index, which is the index for the parent
to look for in its parameter list. to look for in its parameter list.
""" """
@ -161,7 +161,7 @@ class Parentable(object):
_parent_index_ = None _parent_index_ = None
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(Parentable, self).__init__(*args, **kwargs) super(Parentable, self).__init__(*args, **kwargs)
def has_parent(self): def has_parent(self):
""" """
Return whether this parentable object currently has a parent. Return whether this parentable object currently has a parent.
@ -205,8 +205,8 @@ class Gradcheckable(Parentable):
""" """
def __init__(self, *a, **kw): def __init__(self, *a, **kw):
super(Gradcheckable, self).__init__(*a, **kw) super(Gradcheckable, self).__init__(*a, **kw)
def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3): def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3, _debug=False):
""" """
Check the gradient of this parameter with respect to the highest parent's Check the gradient of this parameter with respect to the highest parent's
objective function. objective function.
@ -214,20 +214,21 @@ class Gradcheckable(Parentable):
with a stepsize step. with a stepsize step.
The check passes if either the ratio or the difference between numerical and The check passes if either the ratio or the difference between numerical and
analytical gradient is smaller then tolerance. analytical gradient is smaller then tolerance.
:param bool verbose: whether each parameter shall be checked individually. :param bool verbose: whether each parameter shall be checked individually.
:param float step: the stepsize for the numerical three point gradient estimate. :param float step: the stepsize for the numerical three point gradient estimate.
:param flaot tolerance: the tolerance for the gradient ratio or difference. :param flaot tolerance: the tolerance for the gradient ratio or difference.
""" """
if self.has_parent(): if self.has_parent():
return self._highest_parent_._checkgrad(self, verbose=verbose, step=step, tolerance=tolerance) return self._highest_parent_._checkgrad(self, verbose=verbose, step=step, tolerance=tolerance, _debug=_debug)
return self._checkgrad(self[''], verbose=verbose, step=step, tolerance=tolerance) return self._checkgrad(self[''], verbose=verbose, step=step, tolerance=tolerance, _debug=_debug)
def _checkgrad(self, param):
def _checkgrad(self, param, verbose=0, step=1e-6, tolerance=1e-3, _debug=False):
""" """
Perform the checkgrad on the model. Perform the checkgrad on the model.
TODO: this can be done more efficiently, when doing it inside here TODO: this can be done more efficiently, when doing it inside here
""" """
raise NotImplementedError, "Need log likelihood to check gradient against" raise HierarchyError, "This parameter is not in a model with a likelihood, and, therefore, cannot be gradient checked!"
class Nameable(Gradcheckable): class Nameable(Gradcheckable):
@ -258,7 +259,7 @@ class Nameable(Gradcheckable):
def hierarchy_name(self, adjust_for_printing=True): def hierarchy_name(self, adjust_for_printing=True):
""" """
return the name for this object with the parents names attached by dots. return the name for this object with the parents names attached by dots.
:param bool adjust_for_printing: whether to call :func:`~adjust_for_printing()` :param bool adjust_for_printing: whether to call :func:`~adjust_for_printing()`
on the names, recursively on the names, recursively
""" """