mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 20:12:38 +02:00
[updates] starting to extract out standalone modules
This commit is contained in:
parent
1e6ed9f873
commit
07a5e290de
6 changed files with 149 additions and 124 deletions
|
|
@ -96,7 +96,7 @@ class GP(Model):
|
||||||
def set_XY(self, X=None, Y=None):
|
def set_XY(self, X=None, Y=None):
|
||||||
"""
|
"""
|
||||||
Set the input / output of the model
|
Set the input / output of the model
|
||||||
|
|
||||||
:param X: input observations
|
:param X: input observations
|
||||||
:param Y: output observations
|
:param Y: output observations
|
||||||
"""
|
"""
|
||||||
|
|
@ -384,16 +384,16 @@ class GP(Model):
|
||||||
print "KeyboardInterrupt caught, calling on_optimization_end() to round things up"
|
print "KeyboardInterrupt caught, calling on_optimization_end() to round things up"
|
||||||
self.inference_method.on_optimization_end()
|
self.inference_method.on_optimization_end()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def infer_newX(self, Y_new, optimize=True, ):
|
def infer_newX(self, Y_new, optimize=True, ):
|
||||||
"""
|
"""
|
||||||
Infer the distribution of X for the new observed data *Y_new*.
|
Infer the distribution of X for the new observed data *Y_new*.
|
||||||
|
|
||||||
:param Y_new: the new observed data for inference
|
:param Y_new: the new observed data for inference
|
||||||
:type Y_new: numpy.ndarray
|
:type Y_new: numpy.ndarray
|
||||||
:param optimize: whether to optimize the location of new X (True by default)
|
:param optimize: whether to optimize the location of new X (True by default)
|
||||||
:type optimize: boolean
|
:type optimize: boolean
|
||||||
:return: a tuple containing the posterior estimation of X and the model that optimize X
|
:return: a tuple containing the posterior estimation of X and the model that optimize X
|
||||||
:rtype: (GPy.core.parameterization.variational.VariationalPosterior or numpy.ndarray, GPy.core.Model)
|
:rtype: (GPy.core.parameterization.variational.VariationalPosterior or numpy.ndarray, GPy.core.Model)
|
||||||
"""
|
"""
|
||||||
from ..inference.latent_function_inference.inferenceX import infer_newX
|
from ..inference.latent_function_inference.inferenceX import infer_newX
|
||||||
|
|
|
||||||
69
GPy/core/parameterization/observable.py
Normal file
69
GPy/core/parameterization/observable.py
Normal file
|
|
@ -0,0 +1,69 @@
|
||||||
|
'''
|
||||||
|
Created on 30 Oct 2014
|
||||||
|
|
||||||
|
@author: maxz
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
class Observable(object):
|
||||||
|
"""
|
||||||
|
Observable pattern for parameterization.
|
||||||
|
|
||||||
|
This Object allows for observers to register with self and a (bound!) function
|
||||||
|
as an observer. Every time the observable changes, it sends a notification with
|
||||||
|
self as only argument to all its observers.
|
||||||
|
"""
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(Observable, self).__init__()
|
||||||
|
from lists_and_dicts import ObserverList
|
||||||
|
self.observers = ObserverList()
|
||||||
|
|
||||||
|
def add_observer(self, observer, callble, priority=0):
|
||||||
|
"""
|
||||||
|
Add an observer `observer` with the callback `callble`
|
||||||
|
and priority `priority` to this observers list.
|
||||||
|
"""
|
||||||
|
self.observers.add(priority, observer, callble)
|
||||||
|
|
||||||
|
def remove_observer(self, observer, callble=None):
|
||||||
|
"""
|
||||||
|
Either (if callble is None) remove all callables,
|
||||||
|
which were added alongside observer,
|
||||||
|
or remove callable `callble` which was added alongside
|
||||||
|
the observer `observer`.
|
||||||
|
"""
|
||||||
|
to_remove = []
|
||||||
|
for poc in self.observers:
|
||||||
|
_, obs, clble = poc
|
||||||
|
if callble is not None:
|
||||||
|
if (obs is observer) and (callble == clble):
|
||||||
|
to_remove.append(poc)
|
||||||
|
else:
|
||||||
|
if obs is observer:
|
||||||
|
to_remove.append(poc)
|
||||||
|
for r in to_remove:
|
||||||
|
self.observers.remove(*r)
|
||||||
|
|
||||||
|
def notify_observers(self, which=None, min_priority=None):
|
||||||
|
"""
|
||||||
|
Notifies all observers. Which is the element, which kicked off this
|
||||||
|
notification loop. The first argument will be self, the second `which`.
|
||||||
|
|
||||||
|
NOTE: notifies only observers with priority p > min_priority!
|
||||||
|
^^^^^^^^^^^^^^^^
|
||||||
|
:param min_priority: only notify observers with priority > min_priority
|
||||||
|
if min_priority is None, notify all observers in order
|
||||||
|
"""
|
||||||
|
if which is None:
|
||||||
|
which = self
|
||||||
|
if min_priority is None:
|
||||||
|
[callble(self, which=which) for _, _, callble in self.observers]
|
||||||
|
else:
|
||||||
|
for p, _, callble in self.observers:
|
||||||
|
if p <= min_priority:
|
||||||
|
break
|
||||||
|
callble(self, which=which)
|
||||||
|
|
||||||
|
def change_priority(self, observer, callble, priority):
|
||||||
|
self.remove_observer(observer, callble)
|
||||||
|
self.add_observer(observer, callble, priority)
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
|
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
|
||||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
|
|
||||||
__updated__ = '2014-10-29'
|
__updated__ = '2014-11-11'
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from parameter_core import Observable, Pickleable
|
from parameter_core import Pickleable
|
||||||
|
from observable import Observable
|
||||||
|
|
||||||
class ObsAr(np.ndarray, Pickleable, Observable):
|
class ObsAr(np.ndarray, Pickleable, Observable):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ from transformations import Transformation,Logexp, NegativeLogexp, Logistic, __f
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
|
from updateable import Updateable
|
||||||
|
|
||||||
class HierarchyError(Exception):
|
class HierarchyError(Exception):
|
||||||
"""
|
"""
|
||||||
|
|
@ -40,115 +41,6 @@ def adjust_name_for_printing(name):
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
|
||||||
class Observable(object):
|
|
||||||
"""
|
|
||||||
Observable pattern for parameterization.
|
|
||||||
|
|
||||||
This Object allows for observers to register with self and a (bound!) function
|
|
||||||
as an observer. Every time the observable changes, it sends a notification with
|
|
||||||
self as only argument to all its observers.
|
|
||||||
"""
|
|
||||||
_updates = True
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super(Observable, self).__init__()
|
|
||||||
from lists_and_dicts import ObserverList
|
|
||||||
self.observers = ObserverList()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def updates(self):
|
|
||||||
raise DeprecationWarning("updates is now a function, see update(True|False|None)")
|
|
||||||
|
|
||||||
@updates.setter
|
|
||||||
def updates(self, ups):
|
|
||||||
raise DeprecationWarning("updates is now a function, see update(True|False|None)")
|
|
||||||
|
|
||||||
def update_model(self, updates=None):
|
|
||||||
"""
|
|
||||||
Get or set, whether automatic updates are performed. When updates are
|
|
||||||
off, the model might be in a non-working state. To make the model work
|
|
||||||
turn updates on again.
|
|
||||||
|
|
||||||
:param bool|None updates:
|
|
||||||
|
|
||||||
bool: whether to do updates
|
|
||||||
None: get the current update state
|
|
||||||
"""
|
|
||||||
if updates is None:
|
|
||||||
p = getattr(self, '_highest_parent_', None)
|
|
||||||
if p is not None:
|
|
||||||
self._updates = p._updates
|
|
||||||
return self._updates
|
|
||||||
assert isinstance(updates, bool), "updates are either on (True) or off (False)"
|
|
||||||
p = getattr(self, '_highest_parent_', None)
|
|
||||||
if p is not None:
|
|
||||||
p._updates = updates
|
|
||||||
else:
|
|
||||||
self._updates = updates
|
|
||||||
self.trigger_update()
|
|
||||||
|
|
||||||
def toggle_update(self):
|
|
||||||
self.update_model(not self.update())
|
|
||||||
|
|
||||||
def trigger_update(self, trigger_parent=True):
|
|
||||||
"""
|
|
||||||
Update the model from the current state.
|
|
||||||
Make sure that updates are on, otherwise this
|
|
||||||
method will do nothing
|
|
||||||
|
|
||||||
:param bool trigger_parent: Whether to trigger the parent, after self has updated
|
|
||||||
"""
|
|
||||||
if not self.update_model() or (hasattr(self, "_in_init_") and self._in_init_):
|
|
||||||
#print "Warning: updates are off, updating the model will do nothing"
|
|
||||||
return
|
|
||||||
self._trigger_params_changed(trigger_parent)
|
|
||||||
|
|
||||||
def add_observer(self, observer, callble, priority=0):
|
|
||||||
"""
|
|
||||||
Add an observer `observer` with the callback `callble`
|
|
||||||
and priority `priority` to this observers list.
|
|
||||||
"""
|
|
||||||
self.observers.add(priority, observer, callble)
|
|
||||||
|
|
||||||
def remove_observer(self, observer, callble=None):
|
|
||||||
"""
|
|
||||||
Either (if callble is None) remove all callables,
|
|
||||||
which were added alongside observer,
|
|
||||||
or remove callable `callble` which was added alongside
|
|
||||||
the observer `observer`.
|
|
||||||
"""
|
|
||||||
to_remove = []
|
|
||||||
for poc in self.observers:
|
|
||||||
_, obs, clble = poc
|
|
||||||
if callble is not None:
|
|
||||||
if (obs is observer) and (callble == clble):
|
|
||||||
to_remove.append(poc)
|
|
||||||
else:
|
|
||||||
if obs is observer:
|
|
||||||
to_remove.append(poc)
|
|
||||||
for r in to_remove:
|
|
||||||
self.observers.remove(*r)
|
|
||||||
|
|
||||||
def notify_observers(self, which=None, min_priority=None):
|
|
||||||
"""
|
|
||||||
Notifies all observers. Which is the element, which kicked off this
|
|
||||||
notification loop. The first argument will be self, the second `which`.
|
|
||||||
|
|
||||||
NOTE: notifies only observers with priority p > min_priority!
|
|
||||||
^^^^^^^^^^^^^^^^
|
|
||||||
:param min_priority: only notify observers with priority > min_priority
|
|
||||||
if min_priority is None, notify all observers in order
|
|
||||||
"""
|
|
||||||
if not self.update_model():
|
|
||||||
return
|
|
||||||
if which is None:
|
|
||||||
which = self
|
|
||||||
if min_priority is None:
|
|
||||||
[callble(self, which=which) for _, _, callble in self.observers]
|
|
||||||
else:
|
|
||||||
for p, _, callble in self.observers:
|
|
||||||
if p <= min_priority:
|
|
||||||
break
|
|
||||||
callble(self, which=which)
|
|
||||||
|
|
||||||
class Parentable(object):
|
class Parentable(object):
|
||||||
"""
|
"""
|
||||||
|
|
@ -307,11 +199,11 @@ class Gradcheckable(Pickleable, Parentable):
|
||||||
:param float step: the stepsize for the numerical three point gradient estimate.
|
:param float step: the stepsize for the numerical three point gradient estimate.
|
||||||
:param float tolerance: the tolerance for the gradient ratio or difference.
|
:param float tolerance: the tolerance for the gradient ratio or difference.
|
||||||
:param float df_tolerance: the tolerance for df_tolerance
|
:param float df_tolerance: the tolerance for df_tolerance
|
||||||
|
|
||||||
Note:-
|
Note:-
|
||||||
The *dF_ratio* indicates the limit of accuracy of numerical gradients.
|
The *dF_ratio* indicates the limit of accuracy of numerical gradients.
|
||||||
If it is too small, e.g., smaller than 1e-12, the numerical gradients
|
If it is too small, e.g., smaller than 1e-12, the numerical gradients
|
||||||
are usually not accurate enough for the tests (shown with blue).
|
are usually not accurate enough for the tests (shown with blue).
|
||||||
"""
|
"""
|
||||||
if self.has_parent():
|
if self.has_parent():
|
||||||
return self._highest_parent_._checkgrad(self, verbose=verbose, step=step, tolerance=tolerance, df_tolerance=df_tolerance)
|
return self._highest_parent_._checkgrad(self, verbose=verbose, step=step, tolerance=tolerance, df_tolerance=df_tolerance)
|
||||||
|
|
@ -363,7 +255,7 @@ class Nameable(Gradcheckable):
|
||||||
return adjust(self.name)
|
return adjust(self.name)
|
||||||
|
|
||||||
|
|
||||||
class Indexable(Nameable, Observable):
|
class Indexable(Nameable, Updateable):
|
||||||
"""
|
"""
|
||||||
Make an object constrainable with Priors and Transformations.
|
Make an object constrainable with Priors and Transformations.
|
||||||
TODO: Mappings!!
|
TODO: Mappings!!
|
||||||
|
|
@ -726,7 +618,7 @@ class OptimizationHandlable(Indexable):
|
||||||
fixes[self.constraints[__fixed__]] = FIXED
|
fixes[self.constraints[__fixed__]] = FIXED
|
||||||
return self._optimizer_copy_[np.logical_and(fixes, self._highest_parent_.tie.getTieFlag(self))]
|
return self._optimizer_copy_[np.logical_and(fixes, self._highest_parent_.tie.getTieFlag(self))]
|
||||||
elif self._has_fixes():
|
elif self._has_fixes():
|
||||||
return self._optimizer_copy_[self._fixes_]
|
return self._optimizer_copy_[self._fixes_]
|
||||||
|
|
||||||
self._optimizer_copy_transformed = True
|
self._optimizer_copy_transformed = True
|
||||||
|
|
||||||
|
|
@ -757,7 +649,7 @@ class OptimizationHandlable(Indexable):
|
||||||
#self._highest_parent_.tie.propagate_val()
|
#self._highest_parent_.tie.propagate_val()
|
||||||
|
|
||||||
self._optimizer_copy_transformed = False
|
self._optimizer_copy_transformed = False
|
||||||
self._trigger_params_changed()
|
self.trigger_update()
|
||||||
|
|
||||||
def _get_params_transformed(self):
|
def _get_params_transformed(self):
|
||||||
raise DeprecationWarning, "_get|set_params{_optimizer_copy_transformed} is deprecated, use self.optimizer array insetad!"
|
raise DeprecationWarning, "_get|set_params{_optimizer_copy_transformed} is deprecated, use self.optimizer array insetad!"
|
||||||
|
|
|
||||||
63
GPy/core/parameterization/updateable.py
Normal file
63
GPy/core/parameterization/updateable.py
Normal file
|
|
@ -0,0 +1,63 @@
|
||||||
|
'''
|
||||||
|
Created on 11 Nov 2014
|
||||||
|
|
||||||
|
@author: maxz
|
||||||
|
'''
|
||||||
|
from observable import Observable
|
||||||
|
|
||||||
|
|
||||||
|
class Updateable(Observable):
|
||||||
|
"""
|
||||||
|
A model can be updated or not.
|
||||||
|
Make sure updates can be switched on and off.
|
||||||
|
"""
|
||||||
|
_updates = True
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(Updateable, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def updates(self):
|
||||||
|
raise DeprecationWarning("updates is now a function, see update(True|False|None)")
|
||||||
|
|
||||||
|
@updates.setter
|
||||||
|
def updates(self, ups):
|
||||||
|
raise DeprecationWarning("updates is now a function, see update(True|False|None)")
|
||||||
|
|
||||||
|
def update_model(self, updates=None):
|
||||||
|
"""
|
||||||
|
Get or set, whether automatic updates are performed. When updates are
|
||||||
|
off, the model might be in a non-working state. To make the model work
|
||||||
|
turn updates on again.
|
||||||
|
|
||||||
|
:param bool|None updates:
|
||||||
|
|
||||||
|
bool: whether to do updates
|
||||||
|
None: get the current update state
|
||||||
|
"""
|
||||||
|
if updates is None:
|
||||||
|
p = getattr(self, '_highest_parent_', None)
|
||||||
|
if p is not None:
|
||||||
|
self._updates = p._updates
|
||||||
|
return self._updates
|
||||||
|
assert isinstance(updates, bool), "updates are either on (True) or off (False)"
|
||||||
|
p = getattr(self, '_highest_parent_', None)
|
||||||
|
if p is not None:
|
||||||
|
p._updates = updates
|
||||||
|
self._updates = updates
|
||||||
|
self.trigger_update()
|
||||||
|
|
||||||
|
def toggle_update(self):
|
||||||
|
self.update_model(not self.update_model())
|
||||||
|
|
||||||
|
def trigger_update(self, trigger_parent=True):
|
||||||
|
"""
|
||||||
|
Update the model from the current state.
|
||||||
|
Make sure that updates are on, otherwise this
|
||||||
|
method will do nothing
|
||||||
|
|
||||||
|
:param bool trigger_parent: Whether to trigger the parent, after self has updated
|
||||||
|
"""
|
||||||
|
if not self.update_model() or (hasattr(self, "_in_init_") and self._in_init_):
|
||||||
|
#print "Warning: updates are off, updating the model will do nothing"
|
||||||
|
return
|
||||||
|
self._trigger_params_changed(trigger_parent)
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from ..core.parameterization.parameter_core import Observable
|
from ..core.parameterization.observable import Observable
|
||||||
import collections, weakref, logging
|
import collections, weakref
|
||||||
|
|
||||||
class Cacher(object):
|
class Cacher(object):
|
||||||
def __init__(self, operation, limit=5, ignore_args=(), force_kwargs=()):
|
def __init__(self, operation, limit=5, ignore_args=(), force_kwargs=()):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue