[mrd] plotting, init, inference etc.

This commit is contained in:
Max Zwiessele 2014-05-16 11:21:08 +01:00
parent efc1f4413c
commit 94c84a23a3
8 changed files with 236 additions and 72 deletions

View file

@ -180,40 +180,80 @@ class GP(Model):
return Ysim
def plot_f(self, *args, **kwargs):
def plot_f(self, plot_limits=None, which_data_rows='all',
which_data_ycols='all', fixed_inputs=[],
levels=20, samples=0, fignum=None, ax=None, resolution=None,
plot_raw=True,
linecol=None,fillcol=None, Y_metadata=None, data_symbol='kx'):
"""
Plot the GP's view of the world, where the data is normalized and
before applying a likelihood.
This is a convenience function: arguments are passed to
GPy.plotting.matplot_dep.models_plots.plot_f_fit
Plot the GP's view of the world, where the data is normalized and before applying a likelihood.
This is a call to plot with plot_raw=True.
Data will not be plotted in this, as the GP's view of the world
may live in another space, or units then the data.
"""
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import models_plots
return models_plots.plot_fit_f(self,*args,**kwargs)
kw = {}
if linecol is not None:
kw['linecol'] = linecol
if fillcol is not None:
kw['fillcol'] = fillcol
return models_plots.plot_fit(self, plot_limits, which_data_rows,
which_data_ycols, fixed_inputs,
levels, samples, fignum, ax, resolution,
plot_raw=plot_raw, Y_metadata=Y_metadata,
data_symbol=data_symbol, **kw)
def plot(self, *args, **kwargs):
def plot(self, plot_limits=None, which_data_rows='all',
which_data_ycols='all', fixed_inputs=[],
levels=20, samples=0, fignum=None, ax=None, resolution=None,
plot_raw=False,
linecol=None,fillcol=None, Y_metadata=None, data_symbol='kx'):
"""
Plot the posterior of the GP.
- In one dimension, the function is plotted with a shaded region
identifying two standard deviations.
- In two dimsensions, a contour-plot shows the mean predicted
function
- In higher dimensions, use fixed_inputs to plot the GP with some of
the inputs fixed.
- In one dimension, the function is plotted with a shaded region identifying two standard deviations.
- In two dimsensions, a contour-plot shows the mean predicted function
- In higher dimensions, use fixed_inputs to plot the GP with some of the inputs fixed.
Can plot only part of the data and part of the posterior functions
using which_data_rows which_data_ycols and which_parts
This is a convenience function: arguments are passed to
GPy.plotting.matplot_dep.models_plots.plot_fit
using which_data_rowsm which_data_ycols.
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
:type plot_limits: np.array
:param which_data_rows: which of the training data to plot (default all)
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
:type which_data_rows: 'all' or a list of integers
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
:type fixed_inputs: a list of tuples
:param resolution: the number of intervals to sample the GP on. Defaults to 200 in 1D and 50 (a 50x50 grid) in 2D
:type resolution: int
:param levels: number of levels to plot in a contour plot.
:type levels: int
:param samples: the number of a posteriori samples to plot
:type samples: int
:param fignum: figure to plot on.
:type fignum: figure number
:param ax: axes to plot on.
:type ax: axes handle
:type output: integer (first output is 0)
:param linecol: color of line to plot [Tango.colorsHex['darkBlue']]
:type linecol:
:param fillcol: color of fill [Tango.colorsHex['lightBlue']]
:param levels: for 2D plotting, the number of contour levels to use is ax is None, create a new figure
"""
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import models_plots
return models_plots.plot_fit(self,*args,**kwargs)
kw = {}
if linecol is not None:
kw['linecol'] = linecol
if fillcol is not None:
kw['fillcol'] = fillcol
return models_plots.plot_fit(self, plot_limits, which_data_rows,
which_data_ycols, fixed_inputs,
levels, samples, fignum, ax, resolution,
plot_raw=plot_raw, Y_metadata=Y_metadata,
data_symbol=data_symbol, **kw)
def input_sensitivity(self):
"""

View file

@ -272,8 +272,11 @@ class Parameterized(Parameterizable):
def __setattr__(self, name, val):
# override the default behaviour, if setting a param, so broadcasting can by used
if hasattr(self, "parameters"):
pnames = self.parameter_names(False, adjust_for_printing=True, recursive=False)
if name in pnames: self.parameters[pnames.index(name)][:] = val; return
try:
pnames = self.parameter_names(False, adjust_for_printing=True, recursive=False)
if name in pnames: self.parameters[pnames.index(name)][:] = val; return
except AttributeError:
pass
object.__setattr__(self, name, val);
#===========================================================================
@ -281,11 +284,15 @@ class Parameterized(Parameterizable):
#===========================================================================
def __setstate__(self, state):
super(Parameterized, self).__setstate__(state)
self._connect_parameters()
self._connect_fixes()
self._notify_parent_change()
try:
self._connect_parameters()
self._connect_fixes()
self._notify_parent_change()
self.parameters_changed()
except Exception as e:
print "WARNING: caught exception {!s}, trying to continue".format(e)
self.parameters_changed()
def copy(self):
c = super(Parameterized, self).copy()
c._connect_parameters()

View file

@ -66,7 +66,11 @@ class SparseGP(GP):
#gradients wrt Z
self.Z.gradient = self.kern.gradients_X(dL_dKmm, self.Z)
self.Z.gradient += self.kern.gradients_Z_expectations(
self.grad_dict['dL_dpsi0'], self.grad_dict['dL_dpsi1'], self.grad_dict['dL_dpsi2'], Z=self.Z, variational_posterior=self.X)
self.grad_dict['dL_dpsi0'],
self.grad_dict['dL_dpsi1'],
self.grad_dict['dL_dpsi2'],
Z=self.Z,
variational_posterior=self.X)
else:
#gradients wrt kernel
self.kern.update_gradients_diag(self.grad_dict['dL_dKdiag'], self.X)