mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-11 13:02:38 +02:00
Plots for multiple outputs
This commit is contained in:
parent
70c44b2cdd
commit
7e1e8de5e4
2 changed files with 54 additions and 3 deletions
|
|
@ -184,3 +184,29 @@ class GP(GPBase):
|
||||||
# now push through likelihood
|
# now push through likelihood
|
||||||
mean, var, _025pm, _975pm = self.likelihood.predictive_values(mu, var, full_cov, noise_model = output)
|
mean, var, _025pm, _975pm = self.likelihood.predictive_values(mu, var, full_cov, noise_model = output)
|
||||||
return mean, var, _025pm, _975pm
|
return mean, var, _025pm, _975pm
|
||||||
|
|
||||||
|
def _raw_predict_single_output(self, _Xnew, output=0, which_parts='all', full_cov=False,stop=False):
|
||||||
|
"""
|
||||||
|
Internal helper function for making predictions, does not account
|
||||||
|
for normalization or likelihood
|
||||||
|
"""
|
||||||
|
assert isinstance(self.likelihood,EP_Mixed_Noise)
|
||||||
|
index = np.ones_like(_Xnew)*output
|
||||||
|
_Xnew = np.hstack((_Xnew,index))
|
||||||
|
|
||||||
|
Kx = self.kern.K(_Xnew,self.X,which_parts=which_parts).T
|
||||||
|
#KiKx = np.dot(self.Ki, Kx)
|
||||||
|
KiKx, _ = dpotrs(self.L, np.asfortranarray(Kx), lower=1)
|
||||||
|
mu = np.dot(KiKx.T, self.likelihood.Y)
|
||||||
|
if full_cov:
|
||||||
|
Kxx = self.kern.K(_Xnew, which_parts=which_parts)
|
||||||
|
var = Kxx - np.dot(KiKx.T, Kx)
|
||||||
|
else:
|
||||||
|
Kxx = self.kern.Kdiag(_Xnew, which_parts=which_parts)
|
||||||
|
var = Kxx - np.sum(np.multiply(KiKx, Kx), 0)
|
||||||
|
var = var[:, None]
|
||||||
|
if stop:
|
||||||
|
debug_this # @UndefinedVariable
|
||||||
|
return mu, var
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ class GPBase(Model):
|
||||||
# All leaf nodes should call self._set_params(self._get_params()) at
|
# All leaf nodes should call self._set_params(self._get_params()) at
|
||||||
# the end
|
# the end
|
||||||
|
|
||||||
def plot_f(self, samples=0, plot_limits=None, which_data='all', which_parts='all', resolution=None, full_cov=False, fignum=None, ax=None):
|
def plot_f(self, samples=0, plot_limits=None, which_data='all', which_parts='all', resolution=None, full_cov=False, fignum=None, ax=None,output=None):
|
||||||
"""
|
"""
|
||||||
Plot the GP's view of the world, where the data is normalized and the
|
Plot the GP's view of the world, where the data is normalized and the
|
||||||
likelihood is Gaussian.
|
likelihood is Gaussian.
|
||||||
|
|
@ -62,7 +62,7 @@ class GPBase(Model):
|
||||||
fig = pb.figure(num=fignum)
|
fig = pb.figure(num=fignum)
|
||||||
ax = fig.add_subplot(111)
|
ax = fig.add_subplot(111)
|
||||||
|
|
||||||
if self.X.shape[1] == 1:
|
if self.X.shape[1] == 1 and not isinstance(self.likelihood,EP_Mixed_Noise):
|
||||||
Xnew, xmin, xmax = x_frame1D(self.X, plot_limits=plot_limits)
|
Xnew, xmin, xmax = x_frame1D(self.X, plot_limits=plot_limits)
|
||||||
if samples == 0:
|
if samples == 0:
|
||||||
m, v = self._raw_predict(Xnew, which_parts=which_parts)
|
m, v = self._raw_predict(Xnew, which_parts=which_parts)
|
||||||
|
|
@ -80,7 +80,7 @@ class GPBase(Model):
|
||||||
ymin, ymax = ymin - 0.1 * (ymax - ymin), ymax + 0.1 * (ymax - ymin)
|
ymin, ymax = ymin - 0.1 * (ymax - ymin), ymax + 0.1 * (ymax - ymin)
|
||||||
ax.set_ylim(ymin, ymax)
|
ax.set_ylim(ymin, ymax)
|
||||||
|
|
||||||
elif self.X.shape[1] == 2:
|
elif self.X.shape[1] == 2 and not isinstance(self.likelihood,EP_Mixed_Noise):
|
||||||
resolution = resolution or 50
|
resolution = resolution or 50
|
||||||
Xnew, xmin, xmax, xx, yy = x_frame2D(self.X, plot_limits, resolution)
|
Xnew, xmin, xmax, xx, yy = x_frame2D(self.X, plot_limits, resolution)
|
||||||
m, v = self._raw_predict(Xnew, which_parts=which_parts)
|
m, v = self._raw_predict(Xnew, which_parts=which_parts)
|
||||||
|
|
@ -89,6 +89,31 @@ class GPBase(Model):
|
||||||
ax.scatter(self.X[:, 0], self.X[:, 1], 40, self.likelihood.Y, linewidth=0, cmap=pb.cm.jet, vmin=m.min(), vmax=m.max()) # @UndefinedVariable
|
ax.scatter(self.X[:, 0], self.X[:, 1], 40, self.likelihood.Y, linewidth=0, cmap=pb.cm.jet, vmin=m.min(), vmax=m.max()) # @UndefinedVariable
|
||||||
ax.set_xlim(xmin[0], xmax[0])
|
ax.set_xlim(xmin[0], xmax[0])
|
||||||
ax.set_ylim(xmin[1], xmax[1])
|
ax.set_ylim(xmin[1], xmax[1])
|
||||||
|
|
||||||
|
|
||||||
|
elif self.X.shape[1] == 2 and isinstance(self.likelihood,EP_Mixed_Noise):
|
||||||
|
Xu = self.X[self.X[:,-1]==output ,0:1]
|
||||||
|
Xnew, xmin, xmax = x_frame1D(Xu, plot_limits=plot_limits)
|
||||||
|
|
||||||
|
if samples == 0:
|
||||||
|
m, v = self._raw_predict_single_output(Xnew, output=output, which_parts=which_parts)
|
||||||
|
gpplot(Xnew, m, m - 2 * np.sqrt(v), m + 2 * np.sqrt(v), axes=ax)
|
||||||
|
#ax.plot(self.X[which_data], self.likelihood.Y[which_data], 'kx', mew=1.5)
|
||||||
|
ax.plot(Xu[which_data], self.likelihood.Y[self.likelihood.index==output][:,None], 'kx', mew=1.5)
|
||||||
|
else:
|
||||||
|
m, v = self._raw_predict_single_output(Xnew, output=output, which_parts=which_parts, full_cov=True)
|
||||||
|
Ysim = np.random.multivariate_normal(m.flatten(), v, samples)
|
||||||
|
gpplot(Xnew, m, m - 2 * np.sqrt(np.diag(v)[:, None]), m + 2 * np.sqrt(np.diag(v))[:, None, ], axes=ax)
|
||||||
|
for i in range(samples):
|
||||||
|
ax.plot(Xnew, Ysim[i, :], Tango.colorsHex['darkBlue'], linewidth=0.25)
|
||||||
|
#ax.plot(self.X[which_data], self.likelihood.Y[which_data], 'kx', mew=1.5)
|
||||||
|
#ax.plot(Xu[which_data], self.likelihood.Y[self.likelihood.index==output][:,None], 'kx', mew=1.5)
|
||||||
|
ax.set_xlim(xmin, xmax)
|
||||||
|
ymin, ymax = min(np.append(self.likelihood.Y, m - 2 * np.sqrt(np.diag(v)[:, None]))), max(np.append(self.likelihood.Y, m + 2 * np.sqrt(np.diag(v)[:, None])))
|
||||||
|
ymin, ymax = ymin - 0.1 * (ymax - ymin), ymax + 0.1 * (ymax - ymin)
|
||||||
|
ax.set_ylim(ymin, ymax)
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError, "Cannot define a frame with more than two input dimensions"
|
raise NotImplementedError, "Cannot define a frame with more than two input dimensions"
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue