mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-02 14:45:15 +02:00
minor bugfixes in plotting: quantiles are now computed using predict_kw
correctly
This commit is contained in:
parent
097b048100
commit
ea3bfbb597
4 changed files with 9 additions and 5 deletions
|
|
@ -249,7 +249,7 @@ class GP(Model):
|
|||
mean, var = self.likelihood.predictive_values(mu, var, full_cov, Y_metadata=Y_metadata)
|
||||
return mean, var
|
||||
|
||||
def predict_quantiles(self, X, quantiles=(2.5, 97.5), Y_metadata=None):
|
||||
def predict_quantiles(self, X, quantiles=(2.5, 97.5), Y_metadata=None, kern=None):
|
||||
"""
|
||||
Get the predictive quantiles around the prediction at X
|
||||
|
||||
|
|
@ -257,10 +257,12 @@ class GP(Model):
|
|||
:type X: np.ndarray (Xnew x self.input_dim)
|
||||
:param quantiles: tuple of quantiles, default is (2.5, 97.5) which is the 95% interval
|
||||
:type quantiles: tuple
|
||||
:param kern: optional kernel to use for prediction
|
||||
:type predict_kw: dict
|
||||
:returns: list of quantiles for each X and predictive quantiles for interval combination
|
||||
:rtype: [np.ndarray (Xnew x self.output_dim), np.ndarray (Xnew x self.output_dim)]
|
||||
"""
|
||||
m, v = self._raw_predict(X, full_cov=False)
|
||||
m, v = self._raw_predict(X, full_cov=False, kern=kern)
|
||||
if self.normalizer is not None:
|
||||
m, v = self.normalizer.inverse_mean(m), self.normalizer.inverse_variance(v)
|
||||
return self.likelihood.predictive_quantiles(m, v, quantiles, Y_metadata=Y_metadata)
|
||||
|
|
|
|||
|
|
@ -60,7 +60,8 @@ class GPVariationalGaussianApproximation(Model):
|
|||
var = np.diag(Sigma).reshape(-1,1)
|
||||
|
||||
F, dF_dm, dF_dv, dF_dthetaL = self.likelihood.variational_expectations(self.Y, m, var, Y_metadata=self.Y_metadata)
|
||||
self.likelihood.gradient = dF_dthetaL.sum(1).sum(1)
|
||||
if dF_dthetaL is not None:
|
||||
self.likelihood.gradient = dF_dthetaL.sum(1).sum(1)
|
||||
dF_da = np.dot(K, dF_dm)
|
||||
SigmaB = Sigma*self.beta
|
||||
dF_db = -np.diag(Sigma.dot(np.diag(dF_dv.flatten())).dot(SigmaB))*2
|
||||
|
|
|
|||
|
|
@ -110,7 +110,8 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
|||
else:
|
||||
Y_metadata['output_index'] = extra_data
|
||||
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=Y_metadata, **predict_kw)
|
||||
lower, upper = model.predict_quantiles(Xgrid, Y_metadata=Y_metadata)
|
||||
fmu, fv = model._raw_predict(Xgrid, full_cov=False, **predict_kw)
|
||||
lower, upper = model.likelihood.predictive_quantiles(fmu, fv, (2.5, 97.5), Y_metadata=Y_metadata)
|
||||
|
||||
|
||||
for d in which_data_ycols:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ The test cases for various inference algorithms
|
|||
import unittest, itertools
|
||||
import numpy as np
|
||||
import GPy
|
||||
|
||||
#np.seterr(invalid='raise')
|
||||
|
||||
class InferenceXTestCase(unittest.TestCase):
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue