[magnification] plot_magnification expanded

This commit is contained in:
Max Zwiessele 2015-09-03 09:33:07 +01:00
parent ca60ad3195
commit 839e3dc6f0
8 changed files with 64 additions and 60 deletions

View file

@ -371,7 +371,7 @@ class GP(Model):
var_jac = compute_cov_inner(self.posterior.woodbury_inv)
return mean_jac, var_jac
def predict_wishard_embedding(self, Xnew, kern=None):
def predict_wishard_embedding(self, Xnew, kern=None, mean=True, covariance=True):
"""
Predict the wishard embedding G of the GP. This is the density of the
input of the GP defined by the probabilistic function mapping f.
@ -391,13 +391,16 @@ class GP(Model):
mumuT = np.einsum('iqd,ipd->iqp', mu_jac, mu_jac)
if var_jac.ndim == 3:
Sigma = np.einsum('iqd,ipd->iqp', var_jac, var_jac)
G = mumuT + Sigma
else:
Sigma = np.einsum('iq,ip->iqp', var_jac, var_jac)
G = mumuT + self.output_dim*Sigma
Sigma = self.output_dim*np.einsum('iq,ip->iqp', var_jac, var_jac)
G = 0.
if mean:
G += mumuT
if covariance:
G += Sigma
return G
def predict_magnification(self, Xnew, kern=None):
def predict_magnification(self, Xnew, kern=None, mean=True, covariance=True):
"""
Predict the magnification factor as
@ -405,7 +408,7 @@ class GP(Model):
for each point N in Xnew
"""
G = self.predict_wishard_embedding(Xnew, kern)
G = self.predict_wishard_embedding(Xnew, kern, mean, covariance)
from ..util.linalg import jitchol
return np.array([np.sqrt(np.exp(2*np.sum(np.log(np.diag(jitchol(G[n, :, :])))))) for n in range(Xnew.shape[0])])
#return np.array([np.sqrt(np.linalg.det(G[n, :, :])) for n in range(Xnew.shape[0])])
@ -569,7 +572,7 @@ class GP(Model):
resolution=50, ax=None, marker='o', s=40,
fignum=None, legend=True,
plot_limits=None,
aspect='auto', updates=False, **kwargs):
aspect='auto', updates=False, plot_inducing=True, kern=None, **kwargs):
import sys
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
@ -577,7 +580,7 @@ class GP(Model):
return dim_reduction_plots.plot_magnification(self, labels, which_indices,
resolution, ax, marker, s,
fignum, False, legend,
fignum, plot_inducing, legend,
plot_limits, aspect, updates, **kwargs)