mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-13 05:52:38 +02:00
[mrd] plotting, init, inference etc.
This commit is contained in:
parent
efc1f4413c
commit
94c84a23a3
8 changed files with 236 additions and 72 deletions
|
|
@ -180,40 +180,80 @@ class GP(Model):
|
|||
|
||||
return Ysim
|
||||
|
||||
def plot_f(self, *args, **kwargs):
|
||||
def plot_f(self, plot_limits=None, which_data_rows='all',
|
||||
which_data_ycols='all', fixed_inputs=[],
|
||||
levels=20, samples=0, fignum=None, ax=None, resolution=None,
|
||||
plot_raw=True,
|
||||
linecol=None,fillcol=None, Y_metadata=None, data_symbol='kx'):
|
||||
"""
|
||||
|
||||
Plot the GP's view of the world, where the data is normalized and
|
||||
before applying a likelihood.
|
||||
|
||||
This is a convenience function: arguments are passed to
|
||||
GPy.plotting.matplot_dep.models_plots.plot_f_fit
|
||||
|
||||
Plot the GP's view of the world, where the data is normalized and before applying a likelihood.
|
||||
This is a call to plot with plot_raw=True.
|
||||
Data will not be plotted in this, as the GP's view of the world
|
||||
may live in another space, or units then the data.
|
||||
"""
|
||||
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
||||
from ..plotting.matplot_dep import models_plots
|
||||
return models_plots.plot_fit_f(self,*args,**kwargs)
|
||||
kw = {}
|
||||
if linecol is not None:
|
||||
kw['linecol'] = linecol
|
||||
if fillcol is not None:
|
||||
kw['fillcol'] = fillcol
|
||||
return models_plots.plot_fit(self, plot_limits, which_data_rows,
|
||||
which_data_ycols, fixed_inputs,
|
||||
levels, samples, fignum, ax, resolution,
|
||||
plot_raw=plot_raw, Y_metadata=Y_metadata,
|
||||
data_symbol=data_symbol, **kw)
|
||||
|
||||
def plot(self, *args, **kwargs):
|
||||
def plot(self, plot_limits=None, which_data_rows='all',
|
||||
which_data_ycols='all', fixed_inputs=[],
|
||||
levels=20, samples=0, fignum=None, ax=None, resolution=None,
|
||||
plot_raw=False,
|
||||
linecol=None,fillcol=None, Y_metadata=None, data_symbol='kx'):
|
||||
"""
|
||||
Plot the posterior of the GP.
|
||||
- In one dimension, the function is plotted with a shaded region
|
||||
identifying two standard deviations.
|
||||
- In two dimsensions, a contour-plot shows the mean predicted
|
||||
function
|
||||
- In higher dimensions, use fixed_inputs to plot the GP with some of
|
||||
the inputs fixed.
|
||||
- In one dimension, the function is plotted with a shaded region identifying two standard deviations.
|
||||
- In two dimsensions, a contour-plot shows the mean predicted function
|
||||
- In higher dimensions, use fixed_inputs to plot the GP with some of the inputs fixed.
|
||||
|
||||
Can plot only part of the data and part of the posterior functions
|
||||
using which_data_rows which_data_ycols and which_parts
|
||||
|
||||
This is a convenience function: arguments are passed to
|
||||
GPy.plotting.matplot_dep.models_plots.plot_fit
|
||||
using which_data_rowsm which_data_ycols.
|
||||
|
||||
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
|
||||
:type plot_limits: np.array
|
||||
:param which_data_rows: which of the training data to plot (default all)
|
||||
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
|
||||
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
|
||||
:type which_data_rows: 'all' or a list of integers
|
||||
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
|
||||
:type fixed_inputs: a list of tuples
|
||||
:param resolution: the number of intervals to sample the GP on. Defaults to 200 in 1D and 50 (a 50x50 grid) in 2D
|
||||
:type resolution: int
|
||||
:param levels: number of levels to plot in a contour plot.
|
||||
:type levels: int
|
||||
:param samples: the number of a posteriori samples to plot
|
||||
:type samples: int
|
||||
:param fignum: figure to plot on.
|
||||
:type fignum: figure number
|
||||
:param ax: axes to plot on.
|
||||
:type ax: axes handle
|
||||
:type output: integer (first output is 0)
|
||||
:param linecol: color of line to plot [Tango.colorsHex['darkBlue']]
|
||||
:type linecol:
|
||||
:param fillcol: color of fill [Tango.colorsHex['lightBlue']]
|
||||
:param levels: for 2D plotting, the number of contour levels to use is ax is None, create a new figure
|
||||
"""
|
||||
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
||||
from ..plotting.matplot_dep import models_plots
|
||||
return models_plots.plot_fit(self,*args,**kwargs)
|
||||
kw = {}
|
||||
if linecol is not None:
|
||||
kw['linecol'] = linecol
|
||||
if fillcol is not None:
|
||||
kw['fillcol'] = fillcol
|
||||
return models_plots.plot_fit(self, plot_limits, which_data_rows,
|
||||
which_data_ycols, fixed_inputs,
|
||||
levels, samples, fignum, ax, resolution,
|
||||
plot_raw=plot_raw, Y_metadata=Y_metadata,
|
||||
data_symbol=data_symbol, **kw)
|
||||
|
||||
def input_sensitivity(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -272,8 +272,11 @@ class Parameterized(Parameterizable):
|
|||
def __setattr__(self, name, val):
|
||||
# override the default behaviour, if setting a param, so broadcasting can by used
|
||||
if hasattr(self, "parameters"):
|
||||
pnames = self.parameter_names(False, adjust_for_printing=True, recursive=False)
|
||||
if name in pnames: self.parameters[pnames.index(name)][:] = val; return
|
||||
try:
|
||||
pnames = self.parameter_names(False, adjust_for_printing=True, recursive=False)
|
||||
if name in pnames: self.parameters[pnames.index(name)][:] = val; return
|
||||
except AttributeError:
|
||||
pass
|
||||
object.__setattr__(self, name, val);
|
||||
|
||||
#===========================================================================
|
||||
|
|
@ -281,11 +284,15 @@ class Parameterized(Parameterizable):
|
|||
#===========================================================================
|
||||
def __setstate__(self, state):
|
||||
super(Parameterized, self).__setstate__(state)
|
||||
self._connect_parameters()
|
||||
self._connect_fixes()
|
||||
self._notify_parent_change()
|
||||
try:
|
||||
self._connect_parameters()
|
||||
self._connect_fixes()
|
||||
self._notify_parent_change()
|
||||
|
||||
self.parameters_changed()
|
||||
except Exception as e:
|
||||
print "WARNING: caught exception {!s}, trying to continue".format(e)
|
||||
|
||||
self.parameters_changed()
|
||||
def copy(self):
|
||||
c = super(Parameterized, self).copy()
|
||||
c._connect_parameters()
|
||||
|
|
|
|||
|
|
@ -66,7 +66,11 @@ class SparseGP(GP):
|
|||
#gradients wrt Z
|
||||
self.Z.gradient = self.kern.gradients_X(dL_dKmm, self.Z)
|
||||
self.Z.gradient += self.kern.gradients_Z_expectations(
|
||||
self.grad_dict['dL_dpsi0'], self.grad_dict['dL_dpsi1'], self.grad_dict['dL_dpsi2'], Z=self.Z, variational_posterior=self.X)
|
||||
self.grad_dict['dL_dpsi0'],
|
||||
self.grad_dict['dL_dpsi1'],
|
||||
self.grad_dict['dL_dpsi2'],
|
||||
Z=self.Z,
|
||||
variational_posterior=self.X)
|
||||
else:
|
||||
#gradients wrt kernel
|
||||
self.kern.update_gradients_diag(self.grad_dict['dL_dKdiag'], self.X)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue