mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 11:32:39 +02:00
fiddling with plotting
This commit is contained in:
parent
432d9668a6
commit
80629e00b6
5 changed files with 20 additions and 13 deletions
|
|
@ -58,6 +58,7 @@ class GP(Model):
|
|||
self.parameters_changed()
|
||||
|
||||
def parameters_changed(self):
|
||||
print self.kern
|
||||
self.posterior, self._log_marginal_likelihood, grad_dict = self.inference_method.inference(self.kern, self.X, self.likelihood, self.Y)
|
||||
self._dL_dK = grad_dict['dL_dK']
|
||||
|
||||
|
|
@ -75,8 +76,8 @@ class GP(Model):
|
|||
|
||||
"""
|
||||
Kx = self.kern.K(_Xnew, self.X, which_parts=which_parts).T
|
||||
LiKx, _ = dtrtrs(self.posterior._woodbury_chol, np.asfortranarray(Kx), lower=1)
|
||||
mu = np.dot(Kx.T, self.posterior._woodbury_vector)
|
||||
LiKx, _ = dtrtrs(self.posterior.woodbury_chol, np.asfortranarray(Kx), lower=1)
|
||||
mu = np.dot(Kx.T, self.posterior.woodbury_vector)
|
||||
if full_cov:
|
||||
Kxx = self.kern.K(_Xnew, which_parts=which_parts)
|
||||
var = Kxx - tdot(LiKx.T)
|
||||
|
|
|
|||
|
|
@ -214,6 +214,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
|||
return
|
||||
i = 0
|
||||
sizes = [0]
|
||||
self._param_slices_ = []
|
||||
for p in self._parameters_:
|
||||
p._direct_parent_ = self
|
||||
p._highest_parent_ = self
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class Posterior(object):
|
|||
@property
|
||||
def mean(self):
|
||||
if self._mean is None:
|
||||
self._mean = np.dot(self._K, self._woodbury_vector)
|
||||
self._mean = np.dot(self._K, self.woodbury_vector)
|
||||
return self._mean
|
||||
|
||||
@property
|
||||
|
|
@ -93,10 +93,14 @@ class Posterior(object):
|
|||
@property
|
||||
def woodbury_chol(self):
|
||||
if self._woodbury_chol is None:
|
||||
B = self._K - self._covariance
|
||||
tmp, _ = dpotrs(self._K_chol, B)
|
||||
Wi, _ = dpotrs(self._K_chol, tmp.T)
|
||||
_, _, self._woodbury_chol, _ = pdinv(Wi)
|
||||
#try computing woodbury chol from cov
|
||||
if self._woodbury_inv is not None:
|
||||
_, _, self._woodbury_chol, _ = pdinv(self._woodbury_inv)
|
||||
elif self._covariance is not None:
|
||||
B = self._K - self._covariance
|
||||
tmp, _ = dpotrs(self.K_chol, B)
|
||||
self._woodbury_inv, _ = dpotrs(self.K_chol, tmp.T)
|
||||
_, _, self._woodbury_chol, _ = pdinv(self._woodbury_inv)
|
||||
return self._woodbury_chol
|
||||
|
||||
@property
|
||||
|
|
@ -109,7 +113,7 @@ class Posterior(object):
|
|||
@property
|
||||
def woodbury_vector(self):
|
||||
if self._woodbury_vector is None:
|
||||
self._woodbury_vector, _ = dpotrs(self._K_chol, self.mean)
|
||||
self._woodbury_vector, _ = dpotrs(self.K_chol, self.mean)
|
||||
return self._woodbury_vector
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ class Bernoulli(Likelihood):
|
|||
|
||||
return Z_hat, mu_hat, sigma2_hat
|
||||
|
||||
def _predictive_mean_analytical(self, mu, variance):
|
||||
def predictive_mean(self, mu, variance):
|
||||
|
||||
if isinstance(self.gp_link, link_functions.Probit):
|
||||
return stats.norm.cdf(mu/np.sqrt(1+variance))
|
||||
|
|
@ -89,12 +89,13 @@ class Bernoulli(Likelihood):
|
|||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def _predictive_variance_analytical(self, mu, variance, pred_mean):
|
||||
def predictive_variance(self, mu, variance, pred_mean):
|
||||
|
||||
if isinstance(self.gp_link, link_functions.Heaviside):
|
||||
return 0.
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return np.nan
|
||||
#raise NotImplementedError
|
||||
|
||||
def pdf_link(self, link_f, y, extra_data=None):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -156,11 +156,11 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
|||
raise NotImplementedError, "Cannot define a frame with more than two input dimensions"
|
||||
|
||||
|
||||
def plot_f_fit(model, *args, **kwargs):
|
||||
def plot_fit_f(model, *args, **kwargs):
|
||||
"""
|
||||
Plot the GP's view of the world, where the data is normalized and before applying a likelihood.
|
||||
|
||||
All args and kwargs are passed on to models_plots.plot.
|
||||
"""
|
||||
kwargs['plot_raw'] = True
|
||||
plot(model,*args, **kwargs)
|
||||
plot_fit(model,*args, **kwargs)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue