mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-21 14:05:14 +02:00
[inference] less constant jitter, and jitter adjustements
Conflicts: GPy/util/linalg.py
This commit is contained in:
parent
51e48f7508
commit
8c80fb9c52
4 changed files with 30 additions and 19 deletions
|
|
@ -60,10 +60,14 @@ class SparseGP(GP):
|
|||
self.likelihood.update_gradients(self.grad_dict['dL_dthetaL'])
|
||||
if isinstance(self.X, VariationalPosterior):
|
||||
#gradients wrt kernel
|
||||
dL_dKmm = self.grad_dict.pop('dL_dKmm')
|
||||
dL_dKmm = self.grad_dict['dL_dKmm']
|
||||
self.kern.update_gradients_full(dL_dKmm, self.Z, None)
|
||||
target = self.kern.gradient.copy()
|
||||
self.kern.update_gradients_expectations(variational_posterior=self.X, Z=self.Z, dL_dpsi0=self.grad_dict['dL_dpsi0'], dL_dpsi1=self.grad_dict['dL_dpsi1'], dL_dpsi2=self.grad_dict['dL_dpsi2'])
|
||||
self.kern.update_gradients_expectations(variational_posterior=self.X,
|
||||
Z=self.Z,
|
||||
dL_dpsi0=self.grad_dict['dL_dpsi0'],
|
||||
dL_dpsi1=self.grad_dict['dL_dpsi1'],
|
||||
dL_dpsi2=self.grad_dict['dL_dpsi2'])
|
||||
self.kern.gradient += target
|
||||
|
||||
#gradients wrt Z
|
||||
|
|
|
|||
|
|
@ -194,7 +194,7 @@ class VarDTC(LatentFunctionInference):
|
|||
return post, log_marginal, grad_dict
|
||||
|
||||
class VarDTCMissingData(LatentFunctionInference):
|
||||
const_jitter = 1e-6
|
||||
const_jitter = 1e-10
|
||||
def __init__(self, limit=1, inan=None):
|
||||
from ...util.caching import Cacher
|
||||
self._Y = Cacher(self._subarray_computations, limit)
|
||||
|
|
@ -289,13 +289,6 @@ class VarDTCMissingData(LatentFunctionInference):
|
|||
Lm = jitchol(Kmm)
|
||||
if uncertain_inputs: LmInv = dtrtri(Lm)
|
||||
|
||||
#VVT_factor_all = np.empty(Y.shape)
|
||||
#full_VVT_factor = VVT_factor_all.shape[1] == Y.shape[1]
|
||||
#if not full_VVT_factor:
|
||||
# psi1V = np.dot(Y.T*beta_all, psi1_all).T
|
||||
|
||||
#logger.info('computing dimension-wise likelihood and derivatives')
|
||||
#size = len(Ys)
|
||||
size = Y.shape[1]
|
||||
next_ten = 0
|
||||
for i, [y, v, trYYT] in enumerate(itertools.izip(Ys, self._inan.T, traces)):
|
||||
|
|
@ -348,7 +341,6 @@ class VarDTCMissingData(LatentFunctionInference):
|
|||
VVT_factor, Cpsi1Vf, DBi_plus_BiPBi,
|
||||
psi1, het_noise, uncertain_inputs)
|
||||
|
||||
#import ipdb;ipdb.set_trace()
|
||||
dL_dpsi0_all[v] += dL_dpsi0
|
||||
dL_dpsi1_all[v, :] += dL_dpsi1
|
||||
if uncertain_inputs:
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ def most_significant_input_dimensions(model, which_indices):
|
|||
def plot_latent(model, labels=None, which_indices=None,
|
||||
resolution=50, ax=None, marker='o', s=40,
|
||||
fignum=None, plot_inducing=False, legend=True,
|
||||
plot_limits=None,
|
||||
plot_limits=None,
|
||||
aspect='auto', updates=False, predict_kwargs={}, imshow_kwargs={}):
|
||||
"""
|
||||
:param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc)
|
||||
|
|
@ -84,6 +84,7 @@ def plot_latent(model, labels=None, which_indices=None,
|
|||
cmap=pb.cm.binary, **imshow_kwargs)
|
||||
|
||||
# make sure labels are in order of input:
|
||||
labels = np.asarray(labels)
|
||||
ulabels = []
|
||||
for lab in labels:
|
||||
if not lab in ulabels:
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ import scipy
|
|||
import warnings
|
||||
import os
|
||||
from config import *
|
||||
import logging
|
||||
|
||||
_scipyversion = np.float64((scipy.__version__).split('.')[:2])
|
||||
_fix_dpotri_scipy_bug = True
|
||||
|
|
@ -93,14 +94,20 @@ def jitchol(A, maxtries=5):
|
|||
raise linalg.LinAlgError, "not pd: non-positive diagonal elements"
|
||||
jitter = diagA.mean() * 1e-6
|
||||
while maxtries > 0 and np.isfinite(jitter):
|
||||
print 'Warning: adding jitter of {:.10e}'.format(jitter)
|
||||
try:
|
||||
return linalg.cholesky(A + np.eye(A.shape[0]).T * jitter, lower=True)
|
||||
L = linalg.cholesky(A + np.eye(A.shape[0]) * jitter, lower=True)
|
||||
except:
|
||||
jitter *= 10
|
||||
finally:
|
||||
maxtries -= 1
|
||||
raise linalg.LinAlgError, "not positive definite, even with jitter."
|
||||
import traceback
|
||||
try: raise
|
||||
except:
|
||||
logging.warning('\n'.join(['Added jitter of {:.10e}'.format(jitter),
|
||||
' in '+traceback.format_list(traceback.extract_stack(limit=2)[-2:-1])[0][2:]]))
|
||||
import ipdb;ipdb.set_trace()
|
||||
return L
|
||||
|
||||
|
||||
|
||||
|
|
@ -110,7 +117,7 @@ def jitchol(A, maxtries=5):
|
|||
# """
|
||||
# Wrapper for lapack dtrtri function
|
||||
# Inverse of L
|
||||
#
|
||||
#
|
||||
# :param L: Triangular Matrix L
|
||||
# :param lower: is matrix lower (true) or upper (false)
|
||||
# :returns: Li, info
|
||||
|
|
@ -122,10 +129,17 @@ def dtrtrs(A, B, lower=1, trans=0, unitdiag=0):
|
|||
"""
|
||||
Wrapper for lapack dtrtrs function
|
||||
|
||||
DTRTRS solves a triangular system of the form
|
||||
|
||||
A * X = B or A**T * X = B,
|
||||
|
||||
where A is a triangular matrix of order N, and B is an N-by-NRHS
|
||||
matrix. A check is made to verify that A is nonsingular.
|
||||
|
||||
:param A: Matrix A(triangular)
|
||||
:param B: Matrix B
|
||||
:param lower: is matrix lower (true) or upper (false)
|
||||
:returns:
|
||||
:returns: Solution to A * X = B or A**T * X = B
|
||||
|
||||
"""
|
||||
A = np.asfortranarray(A)
|
||||
|
|
@ -146,11 +160,11 @@ def dpotrs(A, B, lower=1):
|
|||
def dpotri(A, lower=1):
|
||||
"""
|
||||
Wrapper for lapack dpotri function
|
||||
|
||||
|
||||
DPOTRI - compute the inverse of a real symmetric positive
|
||||
definite matrix A using the Cholesky factorization A =
|
||||
U**T*U or A = L*L**T computed by DPOTRF
|
||||
|
||||
|
||||
:param A: Matrix A
|
||||
:param lower: is matrix lower (true) or upper (false)
|
||||
:returns: A inverse
|
||||
|
|
@ -159,7 +173,7 @@ def dpotri(A, lower=1):
|
|||
if _fix_dpotri_scipy_bug:
|
||||
assert lower==1, "scipy linalg behaviour is very weird. please use lower, fortran ordered arrays"
|
||||
lower = 0
|
||||
|
||||
|
||||
A = force_F_ordered(A)
|
||||
R, info = lapack.dpotri(A, lower=lower) #needs to be zero here, seems to be a scipy bug
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue