Fix the bug in the prediction of full covariance matrix (#702)

* fix the bug in the prediction of full covariance matrix

* Fix the test cases.

* update the version requirement for matplotlib
This commit is contained in:
Zhenwen Dai 2018-10-30 07:43:44 +00:00 committed by GitHub
parent 4f3047e035
commit a86723c8ab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 6 additions and 5 deletions

View file

@ -245,6 +245,7 @@ class Posterior(object):
for i in range(var.shape[1]):
var[:, i] = (Kxx - (np.sum(np.dot(woodbury_inv[:, :, i].T, Kx) * Kx, 0)))
var = var
var = np.clip(var, 1e-15, np.inf)
else:
psi0_star = kern.psi0(pred_var, Xnew)
psi1_star = kern.psi1(pred_var, Xnew)
@ -265,7 +266,7 @@ class Posterior(object):
var += -psi2_star.reshape(N, -1).dot(woodbury_inv.flat)[:, None]
else:
var += -psi2_star.reshape(N, -1).dot(woodbury_inv.reshape(-1, D))
var = np.clip(var, 1e-15, np.inf)
var = np.clip(var, 1e-15, np.inf)
return mu, var

View file

@ -122,7 +122,7 @@ class MatplotlibPlots(AbstractPlottingLibrary):
def barplot(self, ax, x, height, width=0.8, bottom=0, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs):
if 'align' not in kwargs:
kwargs['align'] = 'center'
return ax.bar(left=x, height=height, width=width,
return ax.bar(x=x, height=height, width=width,
bottom=bottom, label=label, color=color,
**kwargs)

View file

@ -238,7 +238,7 @@ class MiscTests(unittest.TestCase):
# Not easy to check if woodbury_inv is correct in itself as it requires a large derivation and expression
Kinv = m.posterior.woodbury_inv
K_hat = k.K(self.X_new) - k.K(self.X_new, Z).dot(Kinv).dot(k.K(Z, self.X_new))
K_hat = np.clip(K_hat, 1e-15, np.inf)
# K_hat = np.clip(K_hat, 1e-15, np.inf)
mu, covar = m.predict_noiseless(self.X_new, full_cov=True)
self.assertEquals(mu.shape, (self.N_new, self.D))

View file

@ -164,7 +164,7 @@ setup(name = 'GPy',
'optional':['mpi4py',
'ipython>=4.0.0',
],
'plotting':['matplotlib >= 1.3',
'plotting':['matplotlib >= 3.0',
'plotly >= 1.8.6'],
'notebook':['jupyter_client >= 4.0.6',
'ipywidgets >= 4.0.3',
@ -185,7 +185,7 @@ setup(name = 'GPy',
'Intended Audience :: Developers',
'Topic :: Software Development',
'Topic :: Software Development :: Libraries :: Python Modules',
]
)