mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-15 06:52:39 +02:00
[bgplvm&mrd] missing data greatly improved, still not there yet
This commit is contained in:
parent
b520eb212c
commit
156ba00719
6 changed files with 111 additions and 55 deletions
|
|
@ -10,6 +10,7 @@ from ..util import linalg
|
|||
from ..core.parameterization.variational import NormalPosterior, NormalPrior, VariationalPosterior
|
||||
from ..inference.latent_function_inference.var_dtc_parallel import update_gradients
|
||||
from ..inference.latent_function_inference.var_dtc_gpu import VarDTC_GPU
|
||||
import logging
|
||||
|
||||
class BayesianGPLVM(SparseGP):
|
||||
"""
|
||||
|
|
@ -25,8 +26,10 @@ class BayesianGPLVM(SparseGP):
|
|||
"""
|
||||
def __init__(self, Y, input_dim, X=None, X_variance=None, init='PCA', num_inducing=10,
|
||||
Z=None, kernel=None, inference_method=None, likelihood=None, name='bayesian gplvm', **kwargs):
|
||||
self.logger = logging.getLogger("Bayesian GPLVM <{}>".format(hex(id(self))))
|
||||
if X == None:
|
||||
from ..util.initialization import initialize_latent
|
||||
self.logger.info("initializing latent space X with method {}".format(init))
|
||||
X, fracs = initialize_latent(init, input_dim, Y)
|
||||
else:
|
||||
fracs = np.ones(input_dim)
|
||||
|
|
@ -36,7 +39,6 @@ class BayesianGPLVM(SparseGP):
|
|||
if X_variance is None:
|
||||
X_variance = np.random.uniform(0,.1,X.shape)
|
||||
|
||||
|
||||
if Z is None:
|
||||
Z = np.random.permutation(X.copy())[:num_inducing]
|
||||
assert Z.shape[1] == X.shape[1]
|
||||
|
|
@ -52,11 +54,14 @@ class BayesianGPLVM(SparseGP):
|
|||
X = NormalPosterior(X, X_variance)
|
||||
|
||||
if inference_method is None:
|
||||
if np.any(np.isnan(Y)):
|
||||
inan = np.isnan(Y)
|
||||
if np.any(inan):
|
||||
from ..inference.latent_function_inference.var_dtc import VarDTCMissingData
|
||||
inference_method = VarDTCMissingData()
|
||||
self.logger.debug("creating inference_method with var_dtc missing data")
|
||||
inference_method = VarDTCMissingData(inan=inan)
|
||||
else:
|
||||
from ..inference.latent_function_inference.var_dtc import VarDTC
|
||||
self.logger.debug("creating inference_method var_dtc")
|
||||
inference_method = VarDTC()
|
||||
|
||||
SparseGP.__init__(self, X, Y, Z, kernel, likelihood, inference_method, name, **kwargs)
|
||||
|
|
|
|||
|
|
@ -2,10 +2,8 @@
|
|||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||
|
||||
import numpy as np
|
||||
import itertools
|
||||
import pylab
|
||||
import itertools, logging
|
||||
|
||||
from ..core import Model
|
||||
from ..kern import Kern
|
||||
from ..core.parameterization.variational import NormalPosterior, NormalPrior
|
||||
from ..core.parameterization import Param, Parameterized
|
||||
|
|
@ -61,15 +59,18 @@ class MRD(SparseGP):
|
|||
inference_method=None, likelihoods=None, name='mrd', Ynames=None):
|
||||
super(GP, self).__init__(name)
|
||||
|
||||
self.logger = logging.getLogger("MRD <{}>".format(hex(id(self))))
|
||||
self.input_dim = input_dim
|
||||
self.num_inducing = num_inducing
|
||||
|
||||
if isinstance(Ylist, dict):
|
||||
Ynames, Ylist = zip(*Ylist.items())
|
||||
|
||||
self.logger.debug("creating observable arrays")
|
||||
self.Ylist = [ObsAr(Y) for Y in Ylist]
|
||||
|
||||
if Ynames is None:
|
||||
self.logger.debug("creating Ynames")
|
||||
Ynames = ['Y{}'.format(i) for i in range(len(Ylist))]
|
||||
self.names = Ynames
|
||||
assert len(self.names) == len(self.Ylist), "one name per dataset, or None if Ylist is a dict"
|
||||
|
|
@ -81,13 +82,15 @@ class MRD(SparseGP):
|
|||
inan = np.isnan(y)
|
||||
if np.any(inan):
|
||||
if not warned:
|
||||
print "WARING: NaN values detected, make sure initx method can cope with NaN values or provide starting latent space X"
|
||||
self.logger.warn("WARNING: NaN values detected, make sure initx method can cope with NaN values or provide starting latent space X")
|
||||
warned = True
|
||||
self.inference_method.append(VarDTCMissingData(limit=1, inan=inan))
|
||||
else:
|
||||
self.inference_method.append(VarDTC(limit=1))
|
||||
self.logger.debug("created inference method <{}>".format(hex(id(self.inference_method[-1]))))
|
||||
else:
|
||||
if not isinstance(inference_method, InferenceMethodList):
|
||||
self.logger.debug("making inference_method an InferenceMethodList")
|
||||
inference_method = InferenceMethodList(inference_method)
|
||||
self.inference_method = inference_method
|
||||
|
||||
|
|
@ -101,6 +104,7 @@ class MRD(SparseGP):
|
|||
self.num_inducing = self.Z.shape[0] # ensure M==N if M>N
|
||||
|
||||
# sort out the kernels
|
||||
self.logger.info("building kernels")
|
||||
if kernel is None:
|
||||
from ..kern import RBF
|
||||
self.kernels = [RBF(input_dim, ARD=1, lengthscale=fracs[i]) for i in range(len(Ylist))]
|
||||
|
|
@ -124,6 +128,7 @@ class MRD(SparseGP):
|
|||
self.likelihoods = [Gaussian(name='Gaussian_noise'.format(i)) for i in range(len(Ylist))]
|
||||
else: self.likelihoods = likelihoods
|
||||
|
||||
self.logger.info("adding X and Z")
|
||||
self.add_parameters(self.X, self.Z)
|
||||
|
||||
self.bgplvms = []
|
||||
|
|
@ -141,6 +146,7 @@ class MRD(SparseGP):
|
|||
self.bgplvms.append(p)
|
||||
|
||||
self.posterior = None
|
||||
self.logger.info("init done")
|
||||
self._in_init_ = False
|
||||
|
||||
def parameters_changed(self):
|
||||
|
|
@ -148,17 +154,19 @@ class MRD(SparseGP):
|
|||
self.posteriors = []
|
||||
self.Z.gradient[:] = 0.
|
||||
self.X.gradient[:] = 0.
|
||||
|
||||
for y, k, l, i in itertools.izip(self.Ylist, self.kernels, self.likelihoods, self.inference_method):
|
||||
self.logger.info('working on im <{}>'.format(hex(id(i))))
|
||||
posterior, lml, grad_dict = i.inference(k, self.X, self.Z, l, y)
|
||||
|
||||
self.posteriors.append(posterior)
|
||||
self._log_marginal_likelihood += lml
|
||||
|
||||
# likelihoods gradients
|
||||
self.logger.info("likelihood gradients")
|
||||
l.update_gradients(grad_dict.pop('dL_dthetaL'))
|
||||
|
||||
#gradients wrt kernel
|
||||
self.logger.info("kernel gradients")
|
||||
dL_dKmm = grad_dict.pop('dL_dKmm')
|
||||
k.update_gradients_full(dL_dKmm, self.Z, None)
|
||||
target = k.gradient.copy()
|
||||
|
|
@ -166,6 +174,7 @@ class MRD(SparseGP):
|
|||
k.gradient += target
|
||||
|
||||
#gradients wrt Z
|
||||
self.logger.info("Z gradients")
|
||||
self.Z.gradient += k.gradients_X(dL_dKmm, self.Z)
|
||||
self.Z.gradient += k.gradients_Z_expectations(
|
||||
grad_dict['dL_dpsi0'],
|
||||
|
|
@ -173,6 +182,7 @@ class MRD(SparseGP):
|
|||
grad_dict['dL_dpsi2'],
|
||||
Z=self.Z, variational_posterior=self.X)
|
||||
|
||||
self.logger.info("X gradients")
|
||||
dL_dmean, dL_dS = k.gradients_qX_expectations(variational_posterior=self.X, Z=self.Z, **grad_dict)
|
||||
self.X.mean.gradient += dL_dmean
|
||||
self.X.variance.gradient += dL_dS
|
||||
|
|
@ -219,8 +229,9 @@ class MRD(SparseGP):
|
|||
return Z
|
||||
|
||||
def _handle_plotting(self, fignum, axes, plotf, sharex=False, sharey=False):
|
||||
import matplotlib.pyplot as plt
|
||||
if axes is None:
|
||||
fig = pylab.figure(num=fignum)
|
||||
fig = plt.figure(num=fignum)
|
||||
sharex_ax = None
|
||||
sharey_ax = None
|
||||
plots = []
|
||||
|
|
@ -242,8 +253,8 @@ class MRD(SparseGP):
|
|||
raise ValueError("Need one axes per latent dimension input_dim")
|
||||
plots.append(plotf(i, g, ax))
|
||||
if sharey_ax is not None:
|
||||
pylab.setp(ax.get_yticklabels(), visible=False)
|
||||
pylab.draw()
|
||||
plt.setp(ax.get_yticklabels(), visible=False)
|
||||
plt.draw()
|
||||
if axes is None:
|
||||
try:
|
||||
fig.tight_layout()
|
||||
|
|
@ -300,11 +311,12 @@ class MRD(SparseGP):
|
|||
"""
|
||||
import sys
|
||||
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
||||
import matplotlib.pyplot as plt
|
||||
from ..plotting.matplot_dep import dim_reduction_plots
|
||||
if "Yindex" not in predict_kwargs:
|
||||
predict_kwargs['Yindex'] = 0
|
||||
if ax is None:
|
||||
fig = pylab.figure(num=fignum)
|
||||
fig = plt.figure(num=fignum)
|
||||
ax = fig.add_subplot(111)
|
||||
else:
|
||||
fig = ax.figure
|
||||
|
|
|
|||
|
|
@ -94,22 +94,22 @@ class MiscTests(unittest.TestCase):
|
|||
np.testing.assert_equal(m.log_likelihood(), m2.log_likelihood())
|
||||
|
||||
m.kern.lengthscale.randomize()
|
||||
m._trigger_params_changed()
|
||||
m.update_model()
|
||||
m2.kern.lengthscale = m.kern.lengthscale
|
||||
np.testing.assert_equal(m.log_likelihood(), m2.log_likelihood())
|
||||
|
||||
m.kern.lengthscale.randomize()
|
||||
m._trigger_params_changed()
|
||||
m.update_model()
|
||||
m2['.*lengthscale'] = m.kern.lengthscale
|
||||
np.testing.assert_equal(m.log_likelihood(), m2.log_likelihood())
|
||||
|
||||
m.kern.lengthscale.randomize()
|
||||
m._trigger_params_changed()
|
||||
m.update_model()
|
||||
m2['.*lengthscale'] = m.kern['.*lengthscale']
|
||||
np.testing.assert_equal(m.log_likelihood(), m2.log_likelihood())
|
||||
|
||||
m.kern.lengthscale.randomize()
|
||||
m._trigger_params_changed()
|
||||
m.update_model()
|
||||
m2.kern.lengthscale = m.kern['.*lengthscale']
|
||||
np.testing.assert_equal(m.log_likelihood(), m2.log_likelihood())
|
||||
|
||||
|
|
|
|||
|
|
@ -94,12 +94,12 @@ class Test(unittest.TestCase):
|
|||
def test_set_params(self):
|
||||
self.assertEqual(self.par.params_changed_count, 0, 'no params changed yet')
|
||||
self.par.param_array[:] = 1
|
||||
self.par._trigger_params_changed()
|
||||
self.par.update_model()
|
||||
self.assertEqual(self.par.params_changed_count, 1, 'now params changed')
|
||||
self.assertEqual(self.parent.parent_changed_count, self.par.params_changed_count)
|
||||
|
||||
self.par.param_array[:] = 2
|
||||
self.par._trigger_params_changed()
|
||||
self.par.update_model()
|
||||
self.assertEqual(self.par.params_changed_count, 2, 'now params changed')
|
||||
self.assertEqual(self.parent.parent_changed_count, self.par.params_changed_count)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
from ..core.parameterization.parameter_core import Observable
|
||||
import itertools, collections, weakref
|
||||
import collections, weakref, logging
|
||||
|
||||
class Cacher(object):
|
||||
|
||||
|
||||
def __init__(self, operation, limit=5, ignore_args=(), force_kwargs=()):
|
||||
"""
|
||||
Parameters:
|
||||
|
|
@ -12,6 +10,7 @@ class Cacher(object):
|
|||
:param int limit: depth of cacher
|
||||
:param [int] ignore_args: list of indices, pointing at arguments to ignore in *args of operation(*args). This includes self!
|
||||
:param [str] force_kwargs: list of kwarg names (strings). If a kwarg with that name is given, the cacher will force recompute and won't cache anything.
|
||||
:param int verbose: verbosity level. 0: no print outs, 1: casual print outs, 2: debug level print outs
|
||||
"""
|
||||
self.limit = int(limit)
|
||||
self.ignore_args = ignore_args
|
||||
|
|
@ -19,6 +18,7 @@ class Cacher(object):
|
|||
self.operation=operation
|
||||
self.order = collections.deque()
|
||||
self.cached_inputs = {} # point from cache_ids to a list of [ind_ids], which where used in cache cache_id
|
||||
self.logger = logging.getLogger("cache")
|
||||
|
||||
#=======================================================================
|
||||
# point from each ind_id to [ref(obj), cache_ids]
|
||||
|
|
@ -30,78 +30,104 @@ class Cacher(object):
|
|||
self.cached_outputs = {} # point from cache_ids to outputs
|
||||
self.inputs_changed = {} # point from cache_ids to bools
|
||||
|
||||
def id(self, obj):
|
||||
"""returns the self.id of an object, to be used in caching individual self.ids"""
|
||||
return hex(id(obj))
|
||||
|
||||
def combine_inputs(self, args, kw):
|
||||
"Combines the args and kw in a unique way, such that ordering of kwargs does not lead to recompute"
|
||||
self.logger.debug("combining args and kw")
|
||||
return args + tuple(c[1] for c in sorted(kw.items(), key=lambda x: x[0]))
|
||||
|
||||
def prepare_cache_id(self, combined_args_kw, ignore_args):
|
||||
"get the cacheid (conc. string of argument ids in order) ignoring ignore_args"
|
||||
return "".join(str(id(a)) for i,a in enumerate(combined_args_kw) if i not in ignore_args)
|
||||
"get the cacheid (conc. string of argument self.ids in order) ignoring ignore_args"
|
||||
cache_id = "".join(self.id(a) for i,a in enumerate(combined_args_kw) if i not in ignore_args)
|
||||
self.logger.debug("cache_id={} was created".format(cache_id))
|
||||
return cache_id
|
||||
|
||||
def ensure_cache_length(self, cache_id):
|
||||
"Ensures the cache is within its limits and has one place free"
|
||||
self.logger.debug("cache length gets ensured")
|
||||
if len(self.order) == self.limit:
|
||||
self.logger.debug("cache limit of l={} was reached".format(self.limit))
|
||||
# we have reached the limit, so let's release one element
|
||||
cache_id = self.order.popleft()
|
||||
self.logger.debug("cach_id '{}' gets removed".format(cache_id))
|
||||
combined_args_kw = self.cached_inputs[cache_id]
|
||||
for ind in combined_args_kw:
|
||||
ind_id = id(ind)
|
||||
ref, cache_ids = self.cached_input_ids[ind_id]
|
||||
if len(cache_ids) == 1 and ref() is not None:
|
||||
ref().remove_observer(self, self.on_cache_changed)
|
||||
del self.cached_input_ids[ind_id]
|
||||
else:
|
||||
cache_ids.remove(cache_id)
|
||||
self.cached_input_ids[ind_id] = [ref, cache_ids]
|
||||
if ind is not None:
|
||||
ind_id = self.id(ind)
|
||||
ref, cache_ids = self.cached_input_ids[ind_id]
|
||||
if len(cache_ids) == 1 and ref() is not None:
|
||||
ref().remove_observer(self, self.on_cache_changed)
|
||||
del self.cached_input_ids[ind_id]
|
||||
else:
|
||||
cache_ids.remove(cache_id)
|
||||
self.cached_input_ids[ind_id] = [ref, cache_ids]
|
||||
self.logger.debug("removing caches")
|
||||
del self.cached_outputs[cache_id]
|
||||
del self.inputs_changed[cache_id]
|
||||
del self.cached_inputs[cache_id]
|
||||
|
||||
def add_to_cache(self, cache_id, combined_args_kw, output):
|
||||
def add_to_cache(self, cache_id, inputs, output):
|
||||
"""This adds cache_id to the cache, with inputs and output"""
|
||||
self.inputs_changed[cache_id] = False
|
||||
self.cached_outputs[cache_id] = output
|
||||
self.order.append(cache_id)
|
||||
self.cached_inputs[cache_id] = combined_args_kw
|
||||
for a in combined_args_kw:
|
||||
ind_id = id(a)
|
||||
v = self.cached_input_ids.get(ind_id, [weakref.ref(a), []])
|
||||
v[1].append(cache_id)
|
||||
if len(v[1]) == 1:
|
||||
a.add_observer(self, self.on_cache_changed)
|
||||
self.cached_input_ids[ind_id] = v
|
||||
self.cached_inputs[cache_id] = inputs
|
||||
for a in inputs:
|
||||
if a is not None:
|
||||
ind_id = self.id(a)
|
||||
v = self.cached_input_ids.get(ind_id, [weakref.ref(a), []])
|
||||
self.logger.debug("cache_id '{}' gets stored".format(cache_id))
|
||||
v[1].append(cache_id)
|
||||
if len(v[1]) == 1:
|
||||
self.logger.debug("adding observer to object {}".format(repr(a)))
|
||||
a.add_observer(self, self.on_cache_changed)
|
||||
self.cached_input_ids[ind_id] = v
|
||||
|
||||
def __call__(self, *args, **kw):
|
||||
"""
|
||||
A wrapper function for self.operation,
|
||||
"""
|
||||
|
||||
#=======================================================================
|
||||
# !WARNING CACHE OFFSWITCH!
|
||||
# return self.operation(*args, **kw)
|
||||
#=======================================================================
|
||||
|
||||
# 1: Check whether we have forced recompute arguments:
|
||||
if len(self.force_kwargs) != 0:
|
||||
for k in self.force_kwargs:
|
||||
if k in kw and kw[k] is not None:
|
||||
return self.operation(*args, **kw)
|
||||
|
||||
# 2: prepare_cache_id and get the unique id string for this call
|
||||
|
||||
# 2: prepare_cache_id and get the unique self.id string for this call
|
||||
inputs = self.combine_inputs(args, kw)
|
||||
cache_id = self.prepare_cache_id(inputs, self.ignore_args)
|
||||
|
||||
# 2: if anything is not cachable, we will just return the operation, without caching
|
||||
if reduce(lambda a,b: a or (not isinstance(b, Observable)), inputs, False):
|
||||
if reduce(lambda a,b: a or (not (isinstance(b, Observable) or b is None)), inputs, False):
|
||||
self.logger.info("some inputs are not observable: returning without caching")
|
||||
self.logger.info(str(map(lambda x: isinstance(x, Observable) or x is None, inputs)))
|
||||
self.logger.info(str(map(repr, inputs)))
|
||||
return self.operation(*args, **kw)
|
||||
# 3&4: check whether this cache_id has been cached, then has it changed?
|
||||
try:
|
||||
if(self.inputs_changed[cache_id]):
|
||||
# 4: This happens, when elements have changed for this cache id
|
||||
self.logger.debug("{} already seen, but inputs changed. refreshing cacher".format(cache_id))
|
||||
# 4: This happens, when elements have changed for this cache self.id
|
||||
self.inputs_changed[cache_id] = False
|
||||
self.cached_outputs[cache_id] = self.operation(*args, **kw)
|
||||
except KeyError:
|
||||
self.logger.info("{} never seen, creating cache entry".format(cache_id))
|
||||
# 3: This is when we never saw this cache_id:
|
||||
self.ensure_cache_length(cache_id)
|
||||
self.add_to_cache(cache_id, inputs, self.operation(*args, **kw))
|
||||
except:
|
||||
self.logger.error("an error occurred while trying to run caching for {}, resetting".format(cache_id))
|
||||
self.reset()
|
||||
raise
|
||||
# 5: We have seen this cache_id and it is cached:
|
||||
self.logger.info("returning cache {}".format(cache_id))
|
||||
return self.cached_outputs[cache_id]
|
||||
|
||||
def on_cache_changed(self, direct, which=None):
|
||||
|
|
@ -110,10 +136,13 @@ class Cacher(object):
|
|||
|
||||
this function gets 'hooked up' to the inputs when we cache them, and upon their elements being changed we update here.
|
||||
"""
|
||||
for ind_id in [id(direct), id(which)]:
|
||||
_, cache_ids = self.cached_input_ids.get(ind_id, [None, []])
|
||||
for cache_id in cache_ids:
|
||||
self.inputs_changed[cache_id] = True
|
||||
for what in [direct, which]:
|
||||
if what is not None:
|
||||
ind_id = self.id(what)
|
||||
_, cache_ids = self.cached_input_ids.get(ind_id, [None, []])
|
||||
for cache_id in cache_ids:
|
||||
self.logger.info("callback from {} changed inputs from {}".format(ind_id, self.inputs_changed[cache_id]))
|
||||
self.inputs_changed[cache_id] = True
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@
|
|||
.. moduleauthor:: Max Zwiessele <ibinbei@gmail.com>
|
||||
|
||||
'''
|
||||
__updated__ = '2014-05-20'
|
||||
__updated__ = '2014-05-21'
|
||||
|
||||
import numpy as np
|
||||
import numpy as np, logging
|
||||
|
||||
def common_subarrays(X, axis=0):
|
||||
"""
|
||||
|
|
@ -14,11 +14,11 @@ def common_subarrays(X, axis=0):
|
|||
Common subarrays are returned as a dictionary of <subarray, [index]> pairs, where
|
||||
the subarray is a tuple representing the subarray and the index is the index
|
||||
for the subarray in X, where index is the index to the remaining axis.
|
||||
|
||||
|
||||
:param :class:`np.ndarray` X: 2d array to check for common subarrays in
|
||||
:param int axis: axis to apply subarray detection over.
|
||||
When axis is 0, rows are compared; otherwise, columns are compared.
|
||||
|
||||
|
||||
Examples:
|
||||
=========
|
||||
|
||||
|
|
@ -48,7 +48,17 @@ def common_subarrays(X, axis=0):
|
|||
assert X.ndim == 2 and axis in (0,1), "Only implemented for 2D arrays"
|
||||
subarrays = defaultdict(list)
|
||||
cnt = count()
|
||||
np.apply_along_axis(lambda x: iadd(subarrays[tuple(x)], [cnt.next()]), 1-axis, X)
|
||||
logger = logging.getLogger("common_subarrays")
|
||||
def accumulate(x, s, c):
|
||||
logger.debug("creating tuple")
|
||||
t = tuple(x)
|
||||
logger.debug("tuple done")
|
||||
col = c.next()
|
||||
iadd(s[t], [col])
|
||||
logger.debug("added col {}".format(col))
|
||||
return None
|
||||
if axis == 0: [accumulate(x, subarrays, cnt) for x in X]
|
||||
else: [accumulate(x, subarrays, cnt) for x in X.T]
|
||||
return subarrays
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue