mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 20:12:38 +02:00
gradient operations and cachong
This commit is contained in:
parent
0c92fca31a
commit
b19f9b9f33
8 changed files with 151 additions and 138 deletions
|
|
@ -30,12 +30,16 @@ class ObservableArray(np.ndarray, Observable):
|
||||||
def __new__(cls, input_array):
|
def __new__(cls, input_array):
|
||||||
obj = np.atleast_1d(input_array).view(cls)
|
obj = np.atleast_1d(input_array).view(cls)
|
||||||
cls.__name__ = "ObservableArray\n "
|
cls.__name__ = "ObservableArray\n "
|
||||||
obj._observer_callables_ = {}
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
def __init__(self, *a, **kw):
|
||||||
|
super(ObservableArray, self).__init__(*a, **kw)
|
||||||
|
|
||||||
def __array_finalize__(self, obj):
|
def __array_finalize__(self, obj):
|
||||||
# see InfoArray.__array_finalize__ for comments
|
# see InfoArray.__array_finalize__ for comments
|
||||||
if obj is None: return
|
if obj is None: return
|
||||||
self._observer_callables_ = getattr(obj, '_observer_callables_', None)
|
self._observer_callables_ = getattr(obj, '_observer_callables_', None)
|
||||||
|
|
||||||
def __array_wrap__(self, out_arr, context=None):
|
def __array_wrap__(self, out_arr, context=None):
|
||||||
return out_arr.view(np.ndarray)
|
return out_arr.view(np.ndarray)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ __precision__ = numpy.get_printoptions()['precision'] # numpy printing precision
|
||||||
__print_threshold__ = 5
|
__print_threshold__ = 5
|
||||||
######
|
######
|
||||||
|
|
||||||
class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parentable):
|
class Param(Constrainable, ObservableArray, Gradcheckable, Indexable):
|
||||||
"""
|
"""
|
||||||
Parameter object for GPy models.
|
Parameter object for GPy models.
|
||||||
|
|
||||||
|
|
@ -57,8 +57,8 @@ class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parentable
|
||||||
obj._gradient_ = None
|
obj._gradient_ = None
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
def __init__(self, name, input_array, default_constraint=None):
|
def __init__(self, name, input_array, default_constraint=None, *a, **kw):
|
||||||
super(Param, self).__init__(name=name, default_constraint=default_constraint)
|
super(Param, self).__init__(name=name, default_constraint=default_constraint, *a, **kw)
|
||||||
|
|
||||||
def __array_finalize__(self, obj):
|
def __array_finalize__(self, obj):
|
||||||
# see InfoArray.__array_finalize__ for comments
|
# see InfoArray.__array_finalize__ for comments
|
||||||
|
|
@ -144,7 +144,10 @@ class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parentable
|
||||||
return self.flat
|
return self.flat
|
||||||
|
|
||||||
def _collect_gradient(self, target):
|
def _collect_gradient(self, target):
|
||||||
target[:] = self.gradient.flat
|
target += self.gradient.flat
|
||||||
|
|
||||||
|
def _set_gradient(self, g):
|
||||||
|
self.gradient = g
|
||||||
|
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# Array operations -> done
|
# Array operations -> done
|
||||||
|
|
|
||||||
|
|
@ -7,18 +7,24 @@ __updated__ = '2013-12-16'
|
||||||
|
|
||||||
def adjust_name_for_printing(name):
|
def adjust_name_for_printing(name):
|
||||||
if name is not None:
|
if name is not None:
|
||||||
return name.replace(" ", "_").replace(".", "_").replace("-","").replace("+","").replace("!","").replace("*","").replace("/","")
|
return name.replace(" ", "_").replace(".", "_").replace("-", "").replace("+", "").replace("!", "").replace("*", "").replace("/", "")
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
class Observable(object):
|
class Observable(object):
|
||||||
_observer_callables_ = {}
|
def __init__(self, *args, **kwargs):
|
||||||
def add_observer(self, callble):
|
from collections import defaultdict
|
||||||
self._observer_callables_.append(callble)
|
self._observer_callables_ = defaultdict(list)
|
||||||
#callble(self)
|
|
||||||
def remove_observer(self, callble):
|
def add_observer(self, observer, callble):
|
||||||
del self._observer_callables_[callble]
|
self._observer_callables_[observer].append(callble)
|
||||||
|
# callble(self)
|
||||||
|
|
||||||
|
def remove_observer(self, observer, callble):
|
||||||
|
del self._observer_callables_[observer][callble]
|
||||||
|
|
||||||
def _notify_observers(self):
|
def _notify_observers(self):
|
||||||
[callble(self) for callble in self._observer_callables_]
|
[[callble(self) for callble in callables]
|
||||||
|
for callables in self._observer_callables_.itervalues()]
|
||||||
|
|
||||||
class Pickleable(object):
|
class Pickleable(object):
|
||||||
def _getstate(self):
|
def _getstate(self):
|
||||||
|
|
@ -47,10 +53,8 @@ class Pickleable(object):
|
||||||
#===============================================================================
|
#===============================================================================
|
||||||
|
|
||||||
class Parentable(object):
|
class Parentable(object):
|
||||||
def __init__(self, direct_parent=None, parent_index=None):
|
_direct_parent_ = None
|
||||||
super(Parentable,self).__init__()
|
_parent_index_ = None
|
||||||
self._direct_parent_ = direct_parent
|
|
||||||
self._parent_index_ = parent_index
|
|
||||||
|
|
||||||
def has_parent(self):
|
def has_parent(self):
|
||||||
return self._direct_parent_ is not None
|
return self._direct_parent_ is not None
|
||||||
|
|
@ -73,9 +77,8 @@ class Parentable(object):
|
||||||
self._direct_parent_._notify_parameters_changed()
|
self._direct_parent_._notify_parameters_changed()
|
||||||
|
|
||||||
class Nameable(Parentable):
|
class Nameable(Parentable):
|
||||||
_name = None
|
def __init__(self, name, *a, **kw):
|
||||||
def __init__(self, name, direct_parent=None, parent_index=None):
|
super(Nameable, self).__init__(*a, **kw)
|
||||||
super(Nameable,self).__init__(direct_parent, parent_index)
|
|
||||||
self._name = name or self.__class__.__name__
|
self._name = name or self.__class__.__name__
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -95,108 +98,10 @@ class Nameable(Parentable):
|
||||||
return self._direct_parent_.hirarchy_name() + "." + adjust(self.name)
|
return self._direct_parent_.hirarchy_name() + "." + adjust(self.name)
|
||||||
return adjust(self.name)
|
return adjust(self.name)
|
||||||
|
|
||||||
class Parameterizable(Parentable):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super(Parameterizable, self).__init__(*args, **kwargs)
|
|
||||||
from GPy.core.parameterization.array_core import ParamList
|
|
||||||
_parameters_ = ParamList()
|
|
||||||
self._added_names_ = set()
|
|
||||||
|
|
||||||
def parameter_names(self, add_self=False, adjust_for_printing=False, recursive=True):
|
|
||||||
if adjust_for_printing: adjust = lambda x: adjust_name_for_printing(x)
|
|
||||||
else: adjust = lambda x: x
|
|
||||||
if recursive: names = [xi for x in self._parameters_ for xi in x.parameter_names(add_self=True, adjust_for_printing=adjust_for_printing)]
|
|
||||||
else: names = [adjust(x.name) for x in self._parameters_]
|
|
||||||
if add_self: names = map(lambda x: adjust(self.name) + "." + x, names)
|
|
||||||
return names
|
|
||||||
|
|
||||||
def _add_parameter_name(self, param):
|
|
||||||
pname = adjust_name_for_printing(param.name)
|
|
||||||
# and makes sure to not delete programmatically added parameters
|
|
||||||
if pname in self.__dict__:
|
|
||||||
if not (param is self.__dict__[pname]):
|
|
||||||
if pname in self._added_names_:
|
|
||||||
del self.__dict__[pname]
|
|
||||||
self._add_parameter_name(param)
|
|
||||||
else:
|
|
||||||
self.__dict__[pname] = param
|
|
||||||
self._added_names_.add(pname)
|
|
||||||
|
|
||||||
def _remove_parameter_name(self, param=None, pname=None):
|
|
||||||
assert param is None or pname is None, "can only delete either param by name, or the name of a param"
|
|
||||||
pname = adjust_name_for_printing(pname) or adjust_name_for_printing(param.name)
|
|
||||||
if pname in self._added_names_:
|
|
||||||
del self.__dict__[pname]
|
|
||||||
self._added_names_.remove(pname)
|
|
||||||
self._connect_parameters()
|
|
||||||
|
|
||||||
def _name_changed(self, param, old_name):
|
|
||||||
self._remove_parameter_name(None, old_name)
|
|
||||||
self._add_parameter_name(param)
|
|
||||||
|
|
||||||
def _collect_gradient(self, target):
|
|
||||||
import itertools
|
|
||||||
[p._collect_gradient(target[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
|
||||||
|
|
||||||
def _get_params(self):
|
|
||||||
import numpy as np
|
|
||||||
# don't overwrite this anymore!
|
|
||||||
if not self.size:
|
|
||||||
return np.empty(shape=(0,), dtype=np.float64)
|
|
||||||
return np.hstack([x._get_params() for x in self._parameters_ if x.size > 0])
|
|
||||||
|
|
||||||
def _set_params(self, params, update=True):
|
|
||||||
# don't overwrite this anymore!
|
|
||||||
import itertools
|
|
||||||
[p._set_params(params[s], update=update) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
|
||||||
self.parameters_changed()
|
|
||||||
|
|
||||||
def copy(self):
|
|
||||||
"""Returns a (deep) copy of the current model"""
|
|
||||||
import copy
|
|
||||||
from .index_operations import ParameterIndexOperations, ParameterIndexOperationsView
|
|
||||||
from .array_core import ParamList
|
|
||||||
dc = dict()
|
|
||||||
for k, v in self.__dict__.iteritems():
|
|
||||||
if k not in ['_direct_parent_', '_parameters_', '_parent_index_'] + self.parameter_names():
|
|
||||||
if isinstance(v, (Constrainable, ParameterIndexOperations, ParameterIndexOperationsView)):
|
|
||||||
dc[k] = v.copy()
|
|
||||||
else:
|
|
||||||
dc[k] = copy.deepcopy(v)
|
|
||||||
if k == '_parameters_':
|
|
||||||
params = [p.copy() for p in v]
|
|
||||||
#dc = copy.deepcopy(self.__dict__)
|
|
||||||
dc['_direct_parent_'] = None
|
|
||||||
dc['_parent_index_'] = None
|
|
||||||
dc['_parameters_'] = ParamList()
|
|
||||||
s = self.__new__(self.__class__)
|
|
||||||
s.__dict__ = dc
|
|
||||||
#import ipdb;ipdb.set_trace()
|
|
||||||
for p in params:
|
|
||||||
s.add_parameter(p)
|
|
||||||
#dc._notify_parent_change()
|
|
||||||
return s
|
|
||||||
#return copy.deepcopy(self)
|
|
||||||
|
|
||||||
def _notify_parameters_changed(self):
|
|
||||||
self.parameters_changed()
|
|
||||||
if self.has_parent():
|
|
||||||
self._direct_parent_._notify_parameters_changed()
|
|
||||||
|
|
||||||
def parameters_changed(self):
|
|
||||||
"""
|
|
||||||
This method gets called when parameters have changed.
|
|
||||||
Another way of listening to param changes is to
|
|
||||||
add self as a listener to the param, such that
|
|
||||||
updates get passed through. See :py:function:``GPy.core.param.Observable.add_observer``
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Gradcheckable(Parentable):
|
class Gradcheckable(Parentable):
|
||||||
#===========================================================================
|
def __init__(self, *a, **kw):
|
||||||
# Gradchecking
|
super(Gradcheckable, self).__init__(*a, **kw)
|
||||||
#===========================================================================
|
|
||||||
def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3):
|
def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3):
|
||||||
if self.has_parent():
|
if self.has_parent():
|
||||||
return self._highest_parent_._checkgrad(self, verbose=verbose, step=step, tolerance=tolerance)
|
return self._highest_parent_._checkgrad(self, verbose=verbose, step=step, tolerance=tolerance)
|
||||||
|
|
@ -204,6 +109,7 @@ class Gradcheckable(Parentable):
|
||||||
def _checkgrad(self, param):
|
def _checkgrad(self, param):
|
||||||
raise NotImplementedError, "Need log likelihood to check gradient against"
|
raise NotImplementedError, "Need log likelihood to check gradient against"
|
||||||
|
|
||||||
|
|
||||||
class Indexable(object):
|
class Indexable(object):
|
||||||
def _raveled_index(self):
|
def _raveled_index(self):
|
||||||
raise NotImplementedError, "Need to be able to get the raveled Index"
|
raise NotImplementedError, "Need to be able to get the raveled Index"
|
||||||
|
|
@ -222,9 +128,10 @@ class Indexable(object):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError, "shouldnt happen, raveld index transformation required from non parameterization object?"
|
raise NotImplementedError, "shouldnt happen, raveld index transformation required from non parameterization object?"
|
||||||
|
|
||||||
class Constrainable(Nameable, Indexable, Parentable):
|
|
||||||
def __init__(self, name, default_constraint=None):
|
class Constrainable(Nameable, Indexable):
|
||||||
super(Constrainable,self).__init__(name)
|
def __init__(self, name, default_constraint=None, *a, **kw):
|
||||||
|
super(Constrainable, self).__init__(name=name, *a, **kw)
|
||||||
self._default_constraint_ = default_constraint
|
self._default_constraint_ = default_constraint
|
||||||
from index_operations import ParameterIndexOperations
|
from index_operations import ParameterIndexOperations
|
||||||
self.constraints = ParameterIndexOperations()
|
self.constraints = ParameterIndexOperations()
|
||||||
|
|
@ -275,7 +182,7 @@ class Constrainable(Nameable, Indexable, Parentable):
|
||||||
def _set_unfixed(self, index):
|
def _set_unfixed(self, index):
|
||||||
import numpy as np
|
import numpy as np
|
||||||
if not self._has_fixes(): self._fixes_ = np.ones(self.size, dtype=bool)
|
if not self._has_fixes(): self._fixes_ = np.ones(self.size, dtype=bool)
|
||||||
#rav_i = self._raveled_index_for(param)[index]
|
# rav_i = self._raveled_index_for(param)[index]
|
||||||
self._fixes_[index] = UNFIXED
|
self._fixes_[index] = UNFIXED
|
||||||
if np.all(self._fixes_): self._fixes_ = None # ==UNFIXED
|
if np.all(self._fixes_): self._fixes_ = None # ==UNFIXED
|
||||||
|
|
||||||
|
|
@ -305,7 +212,7 @@ class Constrainable(Nameable, Indexable, Parentable):
|
||||||
"""evaluate the prior"""
|
"""evaluate the prior"""
|
||||||
if self.priors.size > 0:
|
if self.priors.size > 0:
|
||||||
x = self._get_params()
|
x = self._get_params()
|
||||||
return reduce(lambda a,b: a+b, [p.lnpdf(x[ind]).sum() for p, ind in self.priors.iteritems()], 0)
|
return reduce(lambda a, b: a + b, [p.lnpdf(x[ind]).sum() for p, ind in self.priors.iteritems()], 0)
|
||||||
return 0.
|
return 0.
|
||||||
|
|
||||||
def _log_prior_gradients(self):
|
def _log_prior_gradients(self):
|
||||||
|
|
@ -409,7 +316,7 @@ class Constrainable(Nameable, Indexable, Parentable):
|
||||||
if len(transforms) == 0:
|
if len(transforms) == 0:
|
||||||
transforms = which.properties()
|
transforms = which.properties()
|
||||||
import numpy as np
|
import numpy as np
|
||||||
removed = np.empty((0, ), dtype=int)
|
removed = np.empty((0,), dtype=int)
|
||||||
for t in transforms:
|
for t in transforms:
|
||||||
unconstrained = which.remove(t, self._raveled_index())
|
unconstrained = which.remove(t, self._raveled_index())
|
||||||
removed = np.union1d(removed, unconstrained)
|
removed = np.union1d(removed, unconstrained)
|
||||||
|
|
@ -419,5 +326,104 @@ class Constrainable(Nameable, Indexable, Parentable):
|
||||||
return removed
|
return removed
|
||||||
|
|
||||||
|
|
||||||
|
class Parameterizable(Constrainable):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(Parameterizable, self).__init__(*args, **kwargs)
|
||||||
|
from GPy.core.parameterization.array_core import ParamList
|
||||||
|
_parameters_ = ParamList()
|
||||||
|
self._added_names_ = set()
|
||||||
|
|
||||||
|
def parameter_names(self, add_self=False, adjust_for_printing=False, recursive=True):
|
||||||
|
if adjust_for_printing: adjust = lambda x: adjust_name_for_printing(x)
|
||||||
|
else: adjust = lambda x: x
|
||||||
|
if recursive: names = [xi for x in self._parameters_ for xi in x.parameter_names(add_self=True, adjust_for_printing=adjust_for_printing)]
|
||||||
|
else: names = [adjust(x.name) for x in self._parameters_]
|
||||||
|
if add_self: names = map(lambda x: adjust(self.name) + "." + x, names)
|
||||||
|
return names
|
||||||
|
|
||||||
|
def _add_parameter_name(self, param):
|
||||||
|
pname = adjust_name_for_printing(param.name)
|
||||||
|
# and makes sure to not delete programmatically added parameters
|
||||||
|
if pname in self.__dict__:
|
||||||
|
if not (param is self.__dict__[pname]):
|
||||||
|
if pname in self._added_names_:
|
||||||
|
del self.__dict__[pname]
|
||||||
|
self._add_parameter_name(param)
|
||||||
|
else:
|
||||||
|
self.__dict__[pname] = param
|
||||||
|
self._added_names_.add(pname)
|
||||||
|
|
||||||
|
def _remove_parameter_name(self, param=None, pname=None):
|
||||||
|
assert param is None or pname is None, "can only delete either param by name, or the name of a param"
|
||||||
|
pname = adjust_name_for_printing(pname) or adjust_name_for_printing(param.name)
|
||||||
|
if pname in self._added_names_:
|
||||||
|
del self.__dict__[pname]
|
||||||
|
self._added_names_.remove(pname)
|
||||||
|
self._connect_parameters()
|
||||||
|
|
||||||
|
def _name_changed(self, param, old_name):
|
||||||
|
self._remove_parameter_name(None, old_name)
|
||||||
|
self._add_parameter_name(param)
|
||||||
|
|
||||||
|
def _collect_gradient(self, target):
|
||||||
|
import itertools
|
||||||
|
[p._collect_gradient(target[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
||||||
|
|
||||||
|
def _set_gradient(self, g):
|
||||||
|
import itertools
|
||||||
|
[p._set_gradient(g[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
||||||
|
|
||||||
|
def _get_params(self):
|
||||||
|
import numpy as np
|
||||||
|
# don't overwrite this anymore!
|
||||||
|
if not self.size:
|
||||||
|
return np.empty(shape=(0,), dtype=np.float64)
|
||||||
|
return np.hstack([x._get_params() for x in self._parameters_ if x.size > 0])
|
||||||
|
|
||||||
|
def _set_params(self, params, update=True):
|
||||||
|
# don't overwrite this anymore!
|
||||||
|
import itertools
|
||||||
|
[p._set_params(params[s], update=update) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
||||||
|
self.parameters_changed()
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
"""Returns a (deep) copy of the current model"""
|
||||||
|
import copy
|
||||||
|
from .index_operations import ParameterIndexOperations, ParameterIndexOperationsView
|
||||||
|
from .array_core import ParamList
|
||||||
|
dc = dict()
|
||||||
|
for k, v in self.__dict__.iteritems():
|
||||||
|
if k not in ['_direct_parent_', '_parameters_', '_parent_index_'] + self.parameter_names():
|
||||||
|
if isinstance(v, (Constrainable, ParameterIndexOperations, ParameterIndexOperationsView)):
|
||||||
|
dc[k] = v.copy()
|
||||||
|
else:
|
||||||
|
dc[k] = copy.deepcopy(v)
|
||||||
|
if k == '_parameters_':
|
||||||
|
params = [p.copy() for p in v]
|
||||||
|
# dc = copy.deepcopy(self.__dict__)
|
||||||
|
dc['_direct_parent_'] = None
|
||||||
|
dc['_parent_index_'] = None
|
||||||
|
dc['_parameters_'] = ParamList()
|
||||||
|
s = self.__new__(self.__class__)
|
||||||
|
s.__dict__ = dc
|
||||||
|
# import ipdb;ipdb.set_trace()
|
||||||
|
for p in params:
|
||||||
|
s.add_parameter(p)
|
||||||
|
# dc._notify_parent_change()
|
||||||
|
return s
|
||||||
|
# return copy.deepcopy(self)
|
||||||
|
|
||||||
|
def _notify_parameters_changed(self):
|
||||||
|
self.parameters_changed()
|
||||||
|
if self.has_parent():
|
||||||
|
self._direct_parent_._notify_parameters_changed()
|
||||||
|
|
||||||
|
def parameters_changed(self):
|
||||||
|
"""
|
||||||
|
This method gets called when parameters have changed.
|
||||||
|
Another way of listening to param changes is to
|
||||||
|
add self as a listener to the param, such that
|
||||||
|
updates get passed through. See :py:function:``GPy.core.param.Observable.add_observer``
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,11 +7,11 @@ import cPickle
|
||||||
import itertools
|
import itertools
|
||||||
from re import compile, _pattern_type
|
from re import compile, _pattern_type
|
||||||
from param import ParamConcatenation
|
from param import ParamConcatenation
|
||||||
from parameter_core import Constrainable, Pickleable, Observable, Parameterizable, Parentable, adjust_name_for_printing, Gradcheckable
|
from parameter_core import Constrainable, Pickleable, Parentable, Observable, Parameterizable, adjust_name_for_printing, Gradcheckable
|
||||||
from transformations import __fixed__
|
from transformations import __fixed__
|
||||||
from array_core import ParamList
|
from array_core import ParamList
|
||||||
|
|
||||||
class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable, Parameterizable, Parentable):
|
class Parameterized(Parameterizable, Pickleable, Observable, Gradcheckable):
|
||||||
"""
|
"""
|
||||||
Parameterized class
|
Parameterized class
|
||||||
|
|
||||||
|
|
@ -53,8 +53,8 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable, Parame
|
||||||
If you want to operate on all parameters use m[''] to wildcard select all paramters
|
If you want to operate on all parameters use m[''] to wildcard select all paramters
|
||||||
and concatenate them. Printing m[''] will result in printing of all parameters in detail.
|
and concatenate them. Printing m[''] will result in printing of all parameters in detail.
|
||||||
"""
|
"""
|
||||||
def __init__(self, name=None):
|
def __init__(self, name=None, *a, **kw):
|
||||||
super(Parameterized, self).__init__(name=name)
|
super(Parameterized, self).__init__(name=name, parent=None, parent_index=None, *a, **kw)
|
||||||
self._in_init_ = True
|
self._in_init_ = True
|
||||||
self._parameters_ = ParamList()
|
self._parameters_ = ParamList()
|
||||||
self.size = sum(p.size for p in self._parameters_)
|
self.size = sum(p.size for p in self._parameters_)
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from ...core.parameterization.param import Param
|
||||||
|
|
||||||
|
|
||||||
class Kern(Parameterized):
|
class Kern(Parameterized):
|
||||||
def __init__(self, input_dim, name):
|
def __init__(self, input_dim, name, *a, **kw):
|
||||||
"""
|
"""
|
||||||
The base class for a kernel: a positive definite function
|
The base class for a kernel: a positive definite function
|
||||||
which forms of a covariance function (kernel).
|
which forms of a covariance function (kernel).
|
||||||
|
|
@ -19,7 +19,7 @@ class Kern(Parameterized):
|
||||||
|
|
||||||
Do not instantiate.
|
Do not instantiate.
|
||||||
"""
|
"""
|
||||||
super(Kern, self).__init__(name)
|
super(Kern, self).__init__(name=name, *a, **kw)
|
||||||
self.input_dim = input_dim
|
self.input_dim = input_dim
|
||||||
|
|
||||||
def K(self, X, X2):
|
def K(self, X, X2):
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from ...util.linalg import tdot
|
||||||
from ...util.misc import fast_array_equal, param_to_array
|
from ...util.misc import fast_array_equal, param_to_array
|
||||||
from ...core.parameterization import Param
|
from ...core.parameterization import Param
|
||||||
from ...core.parameterization.transformations import Logexp
|
from ...core.parameterization.transformations import Logexp
|
||||||
from ...util.caching import Cacher, cache_this
|
from ...util.caching import cache_this
|
||||||
|
|
||||||
class Linear(Kern):
|
class Linear(Kern):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ class BayesianGPLVM(SparseGP, GPLVM):
|
||||||
assert Z.shape[1] == X.shape[1]
|
assert Z.shape[1] == X.shape[1]
|
||||||
|
|
||||||
if kernel is None:
|
if kernel is None:
|
||||||
kernel = kern.rbf(input_dim) # + kern.white(input_dim)
|
kernel = kern.RBF(input_dim) # + kern.white(input_dim)
|
||||||
|
|
||||||
if likelihood is None:
|
if likelihood is None:
|
||||||
likelihood = Gaussian()
|
likelihood = Gaussian()
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ class Cacher(object):
|
||||||
def __call__(self, *args):
|
def __call__(self, *args):
|
||||||
if self._reset_on_first:
|
if self._reset_on_first:
|
||||||
assert isinstance(args[0], Observable)
|
assert isinstance(args[0], Observable)
|
||||||
args[0].add_observer(self.reset)
|
args[0].add_observer(self, self.reset)
|
||||||
cached_args = args
|
cached_args = args
|
||||||
else:
|
else:
|
||||||
cached_args = args[1:]
|
cached_args = args[1:]
|
||||||
|
|
@ -29,21 +29,21 @@ class Cacher(object):
|
||||||
else:
|
else:
|
||||||
if len(self.cached_inputs) == self.limit:
|
if len(self.cached_inputs) == self.limit:
|
||||||
args_ = self.cached_inputs.pop(0)
|
args_ = self.cached_inputs.pop(0)
|
||||||
[a.remove_observer(self.on_cache_changed) for a in args_]
|
[a.remove_observer(self, self.on_cache_changed) for a in args_]
|
||||||
self.inputs_changed.pop(0)
|
self.inputs_changed.pop(0)
|
||||||
self.cached_outputs.pop(0)
|
self.cached_outputs.pop(0)
|
||||||
|
|
||||||
self.cached_inputs.append(cached_args)
|
self.cached_inputs.append(cached_args)
|
||||||
self.cached_outputs.append(self.operation(*args))
|
self.cached_outputs.append(self.operation(*args))
|
||||||
self.inputs_changed.append(False)
|
self.inputs_changed.append(False)
|
||||||
[a.add_observer(self.on_cache_changed) for a in args]
|
[a.add_observer(self, self.on_cache_changed) for a in args]
|
||||||
return self.cached_outputs[-1]
|
return self.cached_outputs[-1]
|
||||||
|
|
||||||
def on_cache_changed(self, arg):
|
def on_cache_changed(self, arg):
|
||||||
self.inputs_changed = [any([a is arg for a in args]) or old_ic for args, old_ic in zip(self.cached_inputs, self.inputs_changed)]
|
self.inputs_changed = [any([a is arg for a in args]) or old_ic for args, old_ic in zip(self.cached_inputs, self.inputs_changed)]
|
||||||
|
|
||||||
def reset(self, obj):
|
def reset(self, obj):
|
||||||
[[a.remove_observer(self.reset) for a in args] for args in self.cached_inputs]
|
[[a.remove_observer(self, self.reset) for a in args] for args in self.cached_inputs]
|
||||||
self.cached_inputs = []
|
self.cached_inputs = []
|
||||||
self.cached_outputs = []
|
self.cached_outputs = []
|
||||||
self.inputs_changed = []
|
self.inputs_changed = []
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue