mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 12:02:38 +02:00
gradient check
This commit is contained in:
parent
85a471e0f6
commit
74999a89ad
2 changed files with 22 additions and 21 deletions
|
|
@ -446,8 +446,8 @@ class ParamConcatenation(object):
|
||||||
def untie(self, *ties):
|
def untie(self, *ties):
|
||||||
[param.untie(*ties) for param in self.params]
|
[param.untie(*ties) for param in self.params]
|
||||||
|
|
||||||
def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3):
|
def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3, _debug=False):
|
||||||
return self.params[0]._highest_parent_._checkgrad(self, verbose, step, tolerance)
|
return self.params[0]._highest_parent_._checkgrad(self, verbose, step, tolerance, _debug=_debug)
|
||||||
#checkgrad.__doc__ = Gradcheckable.checkgrad.__doc__
|
#checkgrad.__doc__ = Gradcheckable.checkgrad.__doc__
|
||||||
|
|
||||||
__lt__ = lambda self, val: self._vals() < val
|
__lt__ = lambda self, val: self._vals() < val
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ Observable Pattern for patameterization
|
||||||
from transformations import Transformation, Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED
|
from transformations import Transformation, Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
__updated__ = '2013-12-16'
|
__updated__ = '2014-03-11'
|
||||||
|
|
||||||
class HierarchyError(Exception):
|
class HierarchyError(Exception):
|
||||||
"""
|
"""
|
||||||
|
|
@ -34,7 +34,7 @@ def adjust_name_for_printing(name):
|
||||||
class Observable(object):
|
class Observable(object):
|
||||||
"""
|
"""
|
||||||
Observable pattern for parameterization.
|
Observable pattern for parameterization.
|
||||||
|
|
||||||
This Object allows for observers to register with self and a (bound!) function
|
This Object allows for observers to register with self and a (bound!) function
|
||||||
as an observer. Every time the observable changes, it sends a notification with
|
as an observer. Every time the observable changes, it sends a notification with
|
||||||
self as only argument to all its observers.
|
self as only argument to all its observers.
|
||||||
|
|
@ -43,10 +43,10 @@ class Observable(object):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Observable, self).__init__(*args, **kwargs)
|
super(Observable, self).__init__(*args, **kwargs)
|
||||||
self._observer_callables_ = []
|
self._observer_callables_ = []
|
||||||
|
|
||||||
def add_observer(self, observer, callble, priority=0):
|
def add_observer(self, observer, callble, priority=0):
|
||||||
self._insert_sorted(priority, observer, callble)
|
self._insert_sorted(priority, observer, callble)
|
||||||
|
|
||||||
def remove_observer(self, observer, callble=None):
|
def remove_observer(self, observer, callble=None):
|
||||||
to_remove = []
|
to_remove = []
|
||||||
for p, obs, clble in self._observer_callables_:
|
for p, obs, clble in self._observer_callables_:
|
||||||
|
|
@ -58,15 +58,15 @@ class Observable(object):
|
||||||
to_remove.append((p, obs, clble))
|
to_remove.append((p, obs, clble))
|
||||||
for r in to_remove:
|
for r in to_remove:
|
||||||
self._observer_callables_.remove(r)
|
self._observer_callables_.remove(r)
|
||||||
|
|
||||||
def notify_observers(self, which=None, min_priority=None):
|
def notify_observers(self, which=None, min_priority=None):
|
||||||
"""
|
"""
|
||||||
Notifies all observers. Which is the element, which kicked off this
|
Notifies all observers. Which is the element, which kicked off this
|
||||||
notification loop.
|
notification loop.
|
||||||
|
|
||||||
NOTE: notifies only observers with priority p > min_priority!
|
NOTE: notifies only observers with priority p > min_priority!
|
||||||
^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
:param which: object, which started this notification loop
|
:param which: object, which started this notification loop
|
||||||
:param min_priority: only notify observers with priority > min_priority
|
:param min_priority: only notify observers with priority > min_priority
|
||||||
if min_priority is None, notify all observers in order
|
if min_priority is None, notify all observers in order
|
||||||
|
|
@ -88,11 +88,11 @@ class Observable(object):
|
||||||
break
|
break
|
||||||
ins += 1
|
ins += 1
|
||||||
self._observer_callables_.insert(ins, (p, o, c))
|
self._observer_callables_.insert(ins, (p, o, c))
|
||||||
|
|
||||||
class Pickleable(object):
|
class Pickleable(object):
|
||||||
"""
|
"""
|
||||||
Make an object pickleable (See python doc 'pickling').
|
Make an object pickleable (See python doc 'pickling').
|
||||||
|
|
||||||
This class allows for pickling support by Memento pattern.
|
This class allows for pickling support by Memento pattern.
|
||||||
_getstate returns a memento of the class, which gets pickled.
|
_getstate returns a memento of the class, which gets pickled.
|
||||||
_setstate(<memento>) (re-)sets the state of the class to the memento
|
_setstate(<memento>) (re-)sets the state of the class to the memento
|
||||||
|
|
@ -153,7 +153,7 @@ class Pickleable(object):
|
||||||
class Parentable(object):
|
class Parentable(object):
|
||||||
"""
|
"""
|
||||||
Enable an Object to have a parent.
|
Enable an Object to have a parent.
|
||||||
|
|
||||||
Additionally this adds the parent_index, which is the index for the parent
|
Additionally this adds the parent_index, which is the index for the parent
|
||||||
to look for in its parameter list.
|
to look for in its parameter list.
|
||||||
"""
|
"""
|
||||||
|
|
@ -161,7 +161,7 @@ class Parentable(object):
|
||||||
_parent_index_ = None
|
_parent_index_ = None
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Parentable, self).__init__(*args, **kwargs)
|
super(Parentable, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def has_parent(self):
|
def has_parent(self):
|
||||||
"""
|
"""
|
||||||
Return whether this parentable object currently has a parent.
|
Return whether this parentable object currently has a parent.
|
||||||
|
|
@ -205,8 +205,8 @@ class Gradcheckable(Parentable):
|
||||||
"""
|
"""
|
||||||
def __init__(self, *a, **kw):
|
def __init__(self, *a, **kw):
|
||||||
super(Gradcheckable, self).__init__(*a, **kw)
|
super(Gradcheckable, self).__init__(*a, **kw)
|
||||||
|
|
||||||
def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3):
|
def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3, _debug=False):
|
||||||
"""
|
"""
|
||||||
Check the gradient of this parameter with respect to the highest parent's
|
Check the gradient of this parameter with respect to the highest parent's
|
||||||
objective function.
|
objective function.
|
||||||
|
|
@ -214,20 +214,21 @@ class Gradcheckable(Parentable):
|
||||||
with a stepsize step.
|
with a stepsize step.
|
||||||
The check passes if either the ratio or the difference between numerical and
|
The check passes if either the ratio or the difference between numerical and
|
||||||
analytical gradient is smaller then tolerance.
|
analytical gradient is smaller then tolerance.
|
||||||
|
|
||||||
:param bool verbose: whether each parameter shall be checked individually.
|
:param bool verbose: whether each parameter shall be checked individually.
|
||||||
:param float step: the stepsize for the numerical three point gradient estimate.
|
:param float step: the stepsize for the numerical three point gradient estimate.
|
||||||
:param flaot tolerance: the tolerance for the gradient ratio or difference.
|
:param flaot tolerance: the tolerance for the gradient ratio or difference.
|
||||||
"""
|
"""
|
||||||
if self.has_parent():
|
if self.has_parent():
|
||||||
return self._highest_parent_._checkgrad(self, verbose=verbose, step=step, tolerance=tolerance)
|
return self._highest_parent_._checkgrad(self, verbose=verbose, step=step, tolerance=tolerance, _debug=_debug)
|
||||||
return self._checkgrad(self[''], verbose=verbose, step=step, tolerance=tolerance)
|
return self._checkgrad(self[''], verbose=verbose, step=step, tolerance=tolerance, _debug=_debug)
|
||||||
def _checkgrad(self, param):
|
|
||||||
|
def _checkgrad(self, param, verbose=0, step=1e-6, tolerance=1e-3, _debug=False):
|
||||||
"""
|
"""
|
||||||
Perform the checkgrad on the model.
|
Perform the checkgrad on the model.
|
||||||
TODO: this can be done more efficiently, when doing it inside here
|
TODO: this can be done more efficiently, when doing it inside here
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError, "Need log likelihood to check gradient against"
|
raise HierarchyError, "This parameter is not in a model with a likelihood, and, therefore, cannot be gradient checked!"
|
||||||
|
|
||||||
|
|
||||||
class Nameable(Gradcheckable):
|
class Nameable(Gradcheckable):
|
||||||
|
|
@ -258,7 +259,7 @@ class Nameable(Gradcheckable):
|
||||||
def hierarchy_name(self, adjust_for_printing=True):
|
def hierarchy_name(self, adjust_for_printing=True):
|
||||||
"""
|
"""
|
||||||
return the name for this object with the parents names attached by dots.
|
return the name for this object with the parents names attached by dots.
|
||||||
|
|
||||||
:param bool adjust_for_printing: whether to call :func:`~adjust_for_printing()`
|
:param bool adjust_for_printing: whether to call :func:`~adjust_for_printing()`
|
||||||
on the names, recursively
|
on the names, recursively
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue