[inference] less constant jitter, and jitter adjustements

Conflicts:
	GPy/util/linalg.py
This commit is contained in:
mzwiessele 2014-07-22 09:21:01 -07:00
parent 51e48f7508
commit 8c80fb9c52
4 changed files with 30 additions and 19 deletions

View file

@ -60,10 +60,14 @@ class SparseGP(GP):
self.likelihood.update_gradients(self.grad_dict['dL_dthetaL'])
if isinstance(self.X, VariationalPosterior):
#gradients wrt kernel
dL_dKmm = self.grad_dict.pop('dL_dKmm')
dL_dKmm = self.grad_dict['dL_dKmm']
self.kern.update_gradients_full(dL_dKmm, self.Z, None)
target = self.kern.gradient.copy()
self.kern.update_gradients_expectations(variational_posterior=self.X, Z=self.Z, dL_dpsi0=self.grad_dict['dL_dpsi0'], dL_dpsi1=self.grad_dict['dL_dpsi1'], dL_dpsi2=self.grad_dict['dL_dpsi2'])
self.kern.update_gradients_expectations(variational_posterior=self.X,
Z=self.Z,
dL_dpsi0=self.grad_dict['dL_dpsi0'],
dL_dpsi1=self.grad_dict['dL_dpsi1'],
dL_dpsi2=self.grad_dict['dL_dpsi2'])
self.kern.gradient += target
#gradients wrt Z

View file

@ -194,7 +194,7 @@ class VarDTC(LatentFunctionInference):
return post, log_marginal, grad_dict
class VarDTCMissingData(LatentFunctionInference):
const_jitter = 1e-6
const_jitter = 1e-10
def __init__(self, limit=1, inan=None):
from ...util.caching import Cacher
self._Y = Cacher(self._subarray_computations, limit)
@ -289,13 +289,6 @@ class VarDTCMissingData(LatentFunctionInference):
Lm = jitchol(Kmm)
if uncertain_inputs: LmInv = dtrtri(Lm)
#VVT_factor_all = np.empty(Y.shape)
#full_VVT_factor = VVT_factor_all.shape[1] == Y.shape[1]
#if not full_VVT_factor:
# psi1V = np.dot(Y.T*beta_all, psi1_all).T
#logger.info('computing dimension-wise likelihood and derivatives')
#size = len(Ys)
size = Y.shape[1]
next_ten = 0
for i, [y, v, trYYT] in enumerate(itertools.izip(Ys, self._inan.T, traces)):
@ -348,7 +341,6 @@ class VarDTCMissingData(LatentFunctionInference):
VVT_factor, Cpsi1Vf, DBi_plus_BiPBi,
psi1, het_noise, uncertain_inputs)
#import ipdb;ipdb.set_trace()
dL_dpsi0_all[v] += dL_dpsi0
dL_dpsi1_all[v, :] += dL_dpsi1
if uncertain_inputs:

View file

@ -30,7 +30,7 @@ def most_significant_input_dimensions(model, which_indices):
def plot_latent(model, labels=None, which_indices=None,
resolution=50, ax=None, marker='o', s=40,
fignum=None, plot_inducing=False, legend=True,
plot_limits=None,
plot_limits=None,
aspect='auto', updates=False, predict_kwargs={}, imshow_kwargs={}):
"""
:param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc)
@ -84,6 +84,7 @@ def plot_latent(model, labels=None, which_indices=None,
cmap=pb.cm.binary, **imshow_kwargs)
# make sure labels are in order of input:
labels = np.asarray(labels)
ulabels = []
for lab in labels:
if not lab in ulabels:

View file

@ -15,6 +15,7 @@ import scipy
import warnings
import os
from config import *
import logging
_scipyversion = np.float64((scipy.__version__).split('.')[:2])
_fix_dpotri_scipy_bug = True
@ -93,14 +94,20 @@ def jitchol(A, maxtries=5):
raise linalg.LinAlgError, "not pd: non-positive diagonal elements"
jitter = diagA.mean() * 1e-6
while maxtries > 0 and np.isfinite(jitter):
print 'Warning: adding jitter of {:.10e}'.format(jitter)
try:
return linalg.cholesky(A + np.eye(A.shape[0]).T * jitter, lower=True)
L = linalg.cholesky(A + np.eye(A.shape[0]) * jitter, lower=True)
except:
jitter *= 10
finally:
maxtries -= 1
raise linalg.LinAlgError, "not positive definite, even with jitter."
import traceback
try: raise
except:
logging.warning('\n'.join(['Added jitter of {:.10e}'.format(jitter),
' in '+traceback.format_list(traceback.extract_stack(limit=2)[-2:-1])[0][2:]]))
import ipdb;ipdb.set_trace()
return L
@ -110,7 +117,7 @@ def jitchol(A, maxtries=5):
# """
# Wrapper for lapack dtrtri function
# Inverse of L
#
#
# :param L: Triangular Matrix L
# :param lower: is matrix lower (true) or upper (false)
# :returns: Li, info
@ -122,10 +129,17 @@ def dtrtrs(A, B, lower=1, trans=0, unitdiag=0):
"""
Wrapper for lapack dtrtrs function
DTRTRS solves a triangular system of the form
A * X = B or A**T * X = B,
where A is a triangular matrix of order N, and B is an N-by-NRHS
matrix. A check is made to verify that A is nonsingular.
:param A: Matrix A(triangular)
:param B: Matrix B
:param lower: is matrix lower (true) or upper (false)
:returns:
:returns: Solution to A * X = B or A**T * X = B
"""
A = np.asfortranarray(A)
@ -146,11 +160,11 @@ def dpotrs(A, B, lower=1):
def dpotri(A, lower=1):
"""
Wrapper for lapack dpotri function
DPOTRI - compute the inverse of a real symmetric positive
definite matrix A using the Cholesky factorization A =
U**T*U or A = L*L**T computed by DPOTRF
:param A: Matrix A
:param lower: is matrix lower (true) or upper (false)
:returns: A inverse
@ -159,7 +173,7 @@ def dpotri(A, lower=1):
if _fix_dpotri_scipy_bug:
assert lower==1, "scipy linalg behaviour is very weird. please use lower, fortran ordered arrays"
lower = 0
A = force_F_ordered(A)
R, info = lapack.dpotri(A, lower=lower) #needs to be zero here, seems to be a scipy bug