[bgplvm&mrd] missing data greatly improved, still not there yet

This commit is contained in:
Max Zwiessele 2014-05-21 16:32:06 +01:00
parent b520eb212c
commit 156ba00719
6 changed files with 111 additions and 55 deletions

View file

@@ -10,6 +10,7 @@ from ..util import linalg
from ..core.parameterization.variational import NormalPosterior, NormalPrior, VariationalPosterior
from ..inference.latent_function_inference.var_dtc_parallel import update_gradients
from ..inference.latent_function_inference.var_dtc_gpu import VarDTC_GPU
import logging
class BayesianGPLVM(SparseGP):
"""
@@ -25,8 +26,10 @@ class BayesianGPLVM(SparseGP):
"""
def __init__(self, Y, input_dim, X=None, X_variance=None, init='PCA', num_inducing=10,
Z=None, kernel=None, inference_method=None, likelihood=None, name='bayesian gplvm', **kwargs):
self.logger = logging.getLogger("Bayesian GPLVM <{}>".format(hex(id(self))))
if X is None:
from ..util.initialization import initialize_latent
self.logger.info("initializing latent space X with method {}".format(init))
X, fracs = initialize_latent(init, input_dim, Y)
else:
fracs = np.ones(input_dim)
@@ -36,7 +39,6 @@
if X_variance is None:
X_variance = np.random.uniform(0,.1,X.shape)
if Z is None:
Z = np.random.permutation(X.copy())[:num_inducing]
assert Z.shape[1] == X.shape[1]
@@ -52,11 +54,14 @@
X = NormalPosterior(X, X_variance)
if inference_method is None:
- if np.any(np.isnan(Y)):
+ inan = np.isnan(Y)
+ if np.any(inan):
from ..inference.latent_function_inference.var_dtc import VarDTCMissingData
- inference_method = VarDTCMissingData()
+ self.logger.debug("creating inference_method with var_dtc missing data")
+ inference_method = VarDTCMissingData(inan=inan)
else:
from ..inference.latent_function_inference.var_dtc import VarDTC
+ self.logger.debug("creating inference_method var_dtc")
inference_method = VarDTC()
SparseGP.__init__(self, X, Y, Z, kernel, likelihood, inference_method, name, **kwargs)
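For context, a minimal usage sketch of the missing-data path above (hypothetical data and shapes; assumes the model is exposed as GPy.models.BayesianGPLVM, as in GPy's public API):

    import numpy as np
    from GPy.models import BayesianGPLVM

    Y = np.random.randn(100, 12)                # 100 observations, 12 output dimensions
    Y[np.random.rand(100, 12) < 0.1] = np.nan   # knock out ~10% of the entries

    # Because Y contains NaNs, __init__ computes inan = np.isnan(Y) and
    # selects VarDTCMissingData(inan=inan) as the inference method.
    # Passing X explicitly avoids PCA initialization on data with NaNs.
    m = BayesianGPLVM(Y, input_dim=3, X=np.random.randn(100, 3), num_inducing=10)
    m.optimize()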

View file

@@ -2,10 +2,8 @@
# Licensed under the BSD 3-clause license (see LICENSE.txt)
import numpy as np
- import itertools
- import pylab
+ import itertools, logging
from ..core import Model
from ..kern import Kern
from ..core.parameterization.variational import NormalPosterior, NormalPrior
from ..core.parameterization import Param, Parameterized
@@ -61,15 +59,18 @@
inference_method=None, likelihoods=None, name='mrd', Ynames=None):
super(GP, self).__init__(name)
self.logger = logging.getLogger("MRD <{}>".format(hex(id(self))))
self.input_dim = input_dim
self.num_inducing = num_inducing
if isinstance(Ylist, dict):
Ynames, Ylist = zip(*Ylist.items())
self.logger.debug("creating observable arrays")
self.Ylist = [ObsAr(Y) for Y in Ylist]
if Ynames is None:
self.logger.debug("creating Ynames")
Ynames = ['Y{}'.format(i) for i in range(len(Ylist))]
self.names = Ynames
assert len(self.names) == len(self.Ylist), "one name per dataset, or None if Ylist is a dict"
@@ -81,13 +82,15 @@
inan = np.isnan(y)
if np.any(inan):
if not warned:
print "WARING: NaN values detected, make sure initx method can cope with NaN values or provide starting latent space X"
self.logger.warn("WARNING: NaN values detected, make sure initx method can cope with NaN values or provide starting latent space X")
warned = True
self.inference_method.append(VarDTCMissingData(limit=1, inan=inan))
else:
self.inference_method.append(VarDTC(limit=1))
self.logger.debug("created inference method <{}>".format(hex(id(self.inference_method[-1]))))
else:
if not isinstance(inference_method, InferenceMethodList):
self.logger.debug("making inference_method an InferenceMethodList")
inference_method = InferenceMethodList(inference_method)
self.inference_method = inference_method
@@ -101,6 +104,7 @@
self.num_inducing = self.Z.shape[0] # ensure M==N if M>N
# sort out the kernels
self.logger.info("building kernels")
if kernel is None:
from ..kern import RBF
self.kernels = [RBF(input_dim, ARD=1, lengthscale=fracs[i]) for i in range(len(Ylist))]
@@ -124,6 +128,7 @@
self.likelihoods = [Gaussian(name='Gaussian_noise_{}'.format(i)) for i in range(len(Ylist))]
else: self.likelihoods = likelihoods
self.logger.info("adding X and Z")
self.add_parameters(self.X, self.Z)
self.bgplvms = []
@@ -141,6 +146,7 @@
self.bgplvms.append(p)
self.posterior = None
self.logger.info("init done")
self._in_init_ = False
def parameters_changed(self):
@@ -148,17 +154,19 @@
self.posteriors = []
self.Z.gradient[:] = 0.
self.X.gradient[:] = 0.
for y, k, l, i in itertools.izip(self.Ylist, self.kernels, self.likelihoods, self.inference_method):
self.logger.info('working on im <{}>'.format(hex(id(i))))
posterior, lml, grad_dict = i.inference(k, self.X, self.Z, l, y)
self.posteriors.append(posterior)
self._log_marginal_likelihood += lml
# likelihoods gradients
self.logger.info("likelihood gradients")
l.update_gradients(grad_dict.pop('dL_dthetaL'))
#gradients wrt kernel
self.logger.info("kernel gradients")
dL_dKmm = grad_dict.pop('dL_dKmm')
k.update_gradients_full(dL_dKmm, self.Z, None)
target = k.gradient.copy()
@@ -166,6 +174,7 @@
k.gradient += target
#gradients wrt Z
self.logger.info("Z gradients")
self.Z.gradient += k.gradients_X(dL_dKmm, self.Z)
self.Z.gradient += k.gradients_Z_expectations(
grad_dict['dL_dpsi0'],
@@ -173,6 +182,7 @@
grad_dict['dL_dpsi2'],
Z=self.Z, variational_posterior=self.X)
self.logger.info("X gradients")
dL_dmean, dL_dS = k.gradients_qX_expectations(variational_posterior=self.X, Z=self.Z, **grad_dict)
self.X.mean.gradient += dL_dmean
self.X.variance.gradient += dL_dS
@@ -219,8 +229,9 @@
return Z
def _handle_plotting(self, fignum, axes, plotf, sharex=False, sharey=False):
+ import matplotlib.pyplot as plt
if axes is None:
- fig = pylab.figure(num=fignum)
+ fig = plt.figure(num=fignum)
sharex_ax = None
sharey_ax = None
plots = []
@@ -242,8 +253,8 @@
raise ValueError("Need one axes per latent dimension input_dim")
plots.append(plotf(i, g, ax))
if sharey_ax is not None:
- pylab.setp(ax.get_yticklabels(), visible=False)
- pylab.draw()
+ plt.setp(ax.get_yticklabels(), visible=False)
+ plt.draw()
if axes is None:
try:
fig.tight_layout()
@@ -300,11 +311,12 @@
"""
import sys
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
import matplotlib.pyplot as plt
from ..plotting.matplot_dep import dim_reduction_plots
if "Yindex" not in predict_kwargs:
predict_kwargs['Yindex'] = 0
if ax is None:
- fig = pylab.figure(num=fignum)
+ fig = plt.figure(num=fignum)
ax = fig.add_subplot(111)
else:
fig = ax.figure
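Analogously, a minimal MRD sketch (hypothetical data; view names are taken from the dict keys, and the view containing NaNs gets VarDTCMissingData while the complete one gets VarDTC, per the constructor logic above):

    import numpy as np
    from GPy.models import MRD

    Y1 = np.random.randn(100, 8)
    Y2 = np.random.randn(100, 5)
    Y2[np.random.rand(100, 5) < 0.2] = np.nan   # one view with missing entries

    # Providing X sidesteps the NaN warning about the initx method above.
    m = MRD({'view1': Y1, 'view2': Y2}, input_dim=4, num_inducing=15,
            X=np.random.randn(100, 4))
    m.optimize()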

View file

@@ -94,22 +94,22 @@ class MiscTests(unittest.TestCase):
np.testing.assert_equal(m.log_likelihood(), m2.log_likelihood())
m.kern.lengthscale.randomize()
- m._trigger_params_changed()
+ m.update_model()
m2.kern.lengthscale = m.kern.lengthscale
np.testing.assert_equal(m.log_likelihood(), m2.log_likelihood())
m.kern.lengthscale.randomize()
- m._trigger_params_changed()
+ m.update_model()
m2['.*lengthscale'] = m.kern.lengthscale
np.testing.assert_equal(m.log_likelihood(), m2.log_likelihood())
m.kern.lengthscale.randomize()
- m._trigger_params_changed()
+ m.update_model()
m2['.*lengthscale'] = m.kern['.*lengthscale']
np.testing.assert_equal(m.log_likelihood(), m2.log_likelihood())
m.kern.lengthscale.randomize()
- m._trigger_params_changed()
+ m.update_model()
m2.kern.lengthscale = m.kern['.*lengthscale']
np.testing.assert_equal(m.log_likelihood(), m2.log_likelihood())
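These assertions exercise GPy's regex-style parameter indexing; a short sketch of the pattern on a hypothetical model:

    import numpy as np, GPy

    X, Y = np.random.randn(20, 1), np.random.randn(20, 1)
    m = GPy.models.GPRegression(X, Y)
    print(m['.*lengthscale'])    # view of all parameters whose names match the regex
    m['.*lengthscale'] = 2.      # assignment through the view updates the model in place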

View file

@@ -94,12 +94,12 @@ class Test(unittest.TestCase):
def test_set_params(self):
self.assertEqual(self.par.params_changed_count, 0, 'no params changed yet')
self.par.param_array[:] = 1
- self.par._trigger_params_changed()
+ self.par.update_model()
self.assertEqual(self.par.params_changed_count, 1, 'now params changed')
self.assertEqual(self.parent.parent_changed_count, self.par.params_changed_count)
self.par.param_array[:] = 2
- self.par._trigger_params_changed()
+ self.par.update_model()
self.assertEqual(self.par.params_changed_count, 2, 'now params changed')
self.assertEqual(self.parent.parent_changed_count, self.par.params_changed_count)

View file

@@ -1,9 +1,7 @@
from ..core.parameterization.parameter_core import Observable
- import itertools, collections, weakref
+ import collections, weakref, logging
class Cacher(object):
def __init__(self, operation, limit=5, ignore_args=(), force_kwargs=()):
"""
Parameters:
@@ -12,6 +10,7 @@ class Cacher(object):
:param int limit: depth of cacher
:param [int] ignore_args: list of indices, pointing at arguments to ignore in *args of operation(*args). This includes self!
:param [str] force_kwargs: list of kwarg names (strings). If a kwarg with one of these names is given, the cacher will force a recompute and won't cache anything.
:param int verbose: verbosity level. 0: no print outs, 1: casual print outs, 2: debug level print outs
"""
self.limit = int(limit)
self.ignore_args = ignore_args
@@ -19,6 +18,7 @@
self.operation=operation
self.order = collections.deque()
self.cached_inputs = {} # point from cache_ids to a list of [ind_ids], which were used in cache cache_id
self.logger = logging.getLogger("cache")
#=======================================================================
# point from each ind_id to [ref(obj), cache_ids]
@@ -30,78 +30,104 @@ class Cacher(object):
self.cached_outputs = {} # point from cache_ids to outputs
self.inputs_changed = {} # point from cache_ids to bools
def id(self, obj):
"""returns the self.id of an object, to be used in caching individual self.ids"""
return hex(id(obj))
def combine_inputs(self, args, kw):
"Combines the args and kw in a unique way, such that ordering of kwargs does not lead to recompute"
self.logger.debug("combining args and kw")
return args + tuple(c[1] for c in sorted(kw.items(), key=lambda x: x[0]))
def prepare_cache_id(self, combined_args_kw, ignore_args):
"get the cacheid (conc. string of argument ids in order) ignoring ignore_args"
return "".join(str(id(a)) for i,a in enumerate(combined_args_kw) if i not in ignore_args)
"get the cacheid (conc. string of argument self.ids in order) ignoring ignore_args"
cache_id = "".join(self.id(a) for i,a in enumerate(combined_args_kw) if i not in ignore_args)
self.logger.debug("cache_id={} was created".format(cache_id))
return cache_id
def ensure_cache_length(self, cache_id):
"Ensures the cache is within its limits and has one place free"
self.logger.debug("cache length gets ensured")
if len(self.order) == self.limit:
self.logger.debug("cache limit of l={} was reached".format(self.limit))
# we have reached the limit, so lets release one element
cache_id = self.order.popleft()
self.logger.debug("cach_id '{}' gets removed".format(cache_id))
combined_args_kw = self.cached_inputs[cache_id]
for ind in combined_args_kw:
- ind_id = id(ind)
- ref, cache_ids = self.cached_input_ids[ind_id]
- if len(cache_ids) == 1 and ref() is not None:
-     ref().remove_observer(self, self.on_cache_changed)
-     del self.cached_input_ids[ind_id]
- else:
-     cache_ids.remove(cache_id)
-     self.cached_input_ids[ind_id] = [ref, cache_ids]
+ if ind is not None:
+     ind_id = self.id(ind)
+     ref, cache_ids = self.cached_input_ids[ind_id]
+     if len(cache_ids) == 1 and ref() is not None:
+         ref().remove_observer(self, self.on_cache_changed)
+         del self.cached_input_ids[ind_id]
+     else:
+         cache_ids.remove(cache_id)
+         self.cached_input_ids[ind_id] = [ref, cache_ids]
self.logger.debug("removing caches")
del self.cached_outputs[cache_id]
del self.inputs_changed[cache_id]
del self.cached_inputs[cache_id]
- def add_to_cache(self, cache_id, combined_args_kw, output):
+ def add_to_cache(self, cache_id, inputs, output):
"""This adds cache_id to the cache, with inputs and output"""
self.inputs_changed[cache_id] = False
self.cached_outputs[cache_id] = output
self.order.append(cache_id)
- self.cached_inputs[cache_id] = combined_args_kw
- for a in combined_args_kw:
-     ind_id = id(a)
-     v = self.cached_input_ids.get(ind_id, [weakref.ref(a), []])
-     v[1].append(cache_id)
-     if len(v[1]) == 1:
-         a.add_observer(self, self.on_cache_changed)
-     self.cached_input_ids[ind_id] = v
+ self.cached_inputs[cache_id] = inputs
+ for a in inputs:
+     if a is not None:
+         ind_id = self.id(a)
+         v = self.cached_input_ids.get(ind_id, [weakref.ref(a), []])
+         self.logger.debug("cache_id '{}' gets stored".format(cache_id))
+         v[1].append(cache_id)
+         if len(v[1]) == 1:
+             self.logger.debug("adding observer to object {}".format(repr(a)))
+             a.add_observer(self, self.on_cache_changed)
+         self.cached_input_ids[ind_id] = v
def __call__(self, *args, **kw):
"""
A wrapper function for self.operation.
"""
#=======================================================================
# !WARNING CACHE OFFSWITCH!
# return self.operation(*args, **kw)
#=======================================================================
# 1: Check whether we have forced recompute arguments:
if len(self.force_kwargs) != 0:
for k in self.force_kwargs:
if k in kw and kw[k] is not None:
return self.operation(*args, **kw)
- # 2: prepare_cache_id and get the unique id string for this call
+ # 2: prepare_cache_id and get the unique id string (via self.id) for this call
inputs = self.combine_inputs(args, kw)
cache_id = self.prepare_cache_id(inputs, self.ignore_args)
# 2: if anything is not cachable, we will just return the operation, without caching
- if reduce(lambda a,b: a or (not isinstance(b, Observable)), inputs, False):
+ if reduce(lambda a,b: a or (not (isinstance(b, Observable) or b is None)), inputs, False):
self.logger.info("some inputs are not observable: returning without caching")
self.logger.info(str(map(lambda x: isinstance(x, Observable) or x is None, inputs)))
self.logger.info(str(map(repr, inputs)))
return self.operation(*args, **kw)
# 3&4: check whether this cache_id has been cached, then has it changed?
try:
if(self.inputs_changed[cache_id]):
- # 4: This happens, when elements have changed for this cache id
+ self.logger.debug("{} already seen, but inputs changed. refreshing cacher".format(cache_id))
+ # 4: This happens when elements have changed for this cache_id
self.inputs_changed[cache_id] = False
self.cached_outputs[cache_id] = self.operation(*args, **kw)
except KeyError:
self.logger.info("{} never seen, creating cache entry".format(cache_id))
# 3: This happens when we have never seen this cache_id:
self.ensure_cache_length(cache_id)
self.add_to_cache(cache_id, inputs, self.operation(*args, **kw))
except:
self.logger.error("an error occurred while trying to run caching for {}, resetting".format(cache_id))
self.reset()
raise
# 5: We have seen this cache_id and it is cached:
self.logger.info("returning cache {}".format(cache_id))
return self.cached_outputs[cache_id]
def on_cache_changed(self, direct, which=None):
@@ -110,10 +136,13 @@ class Cacher(object):
this function gets 'hooked up' to the inputs when we cache them, and upon their elements being changed we update here.
"""
- for ind_id in [id(direct), id(which)]:
-     _, cache_ids = self.cached_input_ids.get(ind_id, [None, []])
-     for cache_id in cache_ids:
-         self.inputs_changed[cache_id] = True
+ for what in [direct, which]:
+     if what is not None:
+         ind_id = self.id(what)
+         _, cache_ids = self.cached_input_ids.get(ind_id, [None, []])
+         for cache_id in cache_ids:
+             self.logger.info("callback from {} changed inputs from {}".format(ind_id, self.inputs_changed[cache_id]))
+             self.inputs_changed[cache_id] = True
def reset(self):
"""

View file

@@ -4,9 +4,9 @@
.. moduleauthor:: Max Zwiessele <ibinbei@gmail.com>
'''
- __updated__ = '2014-05-20'
+ __updated__ = '2014-05-21'
- import numpy as np
+ import numpy as np, logging
def common_subarrays(X, axis=0):
"""
@@ -14,11 +14,11 @@ def common_subarrays(X, axis=0):
Common subarrays are returned as a dictionary of <subarray, [index]> pairs, where
the subarray is a tuple representing the subarray and the index is the index
for the subarray in X, where index is the index to the remaining axis.
:param :class:`np.ndarray` X: 2d array to check for common subarrays in
:param int axis: axis over which to detect common subarrays.
When axis is 0, rows are compared; otherwise, columns.
Examples:
=========
@@ -48,7 +48,17 @@
assert X.ndim == 2 and axis in (0,1), "Only implemented for 2D arrays"
subarrays = defaultdict(list)
cnt = count()
- np.apply_along_axis(lambda x: iadd(subarrays[tuple(x)], [cnt.next()]), 1-axis, X)
+ logger = logging.getLogger("common_subarrays")
+ def accumulate(x, s, c):
+     logger.debug("creating tuple")
+     t = tuple(x)
+     logger.debug("tuple done")
+     col = c.next()
+     iadd(s[t], [col])
+     logger.debug("added col {}".format(col))
+     return None
+ if axis == 0: [accumulate(x, subarrays, cnt) for x in X]
+ else: [accumulate(x, subarrays, cnt) for x in X.T]
return subarrays
if __name__ == '__main__':
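A short usage sketch of the rewritten function (grouping as described in the docstring; keys are tuples, values are index lists):

    import numpy as np
    from GPy.util.subarray_and_sorting import common_subarrays  # module path as in GPy

    X = np.array([[1., 0.],
                  [0., 1.],
                  [1., 0.]])
    # axis=0 compares rows: the two copies of [1., 0.] end up under one key
    print(common_subarrays(X, axis=0))   # defaultdict with {(1.0, 0.0): [0, 2], (0.0, 1.0): [1]}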