mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-04 01:02:39 +02:00
[kernel] plotting ard for prod and covariance plots added
This commit is contained in:
parent
59ba858aba
commit
12dba962f1
13 changed files with 130 additions and 116 deletions
|
|
@ -102,19 +102,17 @@ def helper_predict_with_model(self, Xgrid, plot_raw, apply_link, percentiles, wh
|
|||
fsamples[:, s] = self.likelihood.gp_link.transf(fsamples[:, s])
|
||||
return retmu, percs, fsamples
|
||||
|
||||
def helper_for_plot_data(self, plot_limits, visible_dims, fixed_inputs, resolution):
|
||||
def helper_for_plot_data(self, X, plot_limits, visible_dims, fixed_inputs, resolution):
|
||||
"""
|
||||
Figure out the data, free_dims and create an Xgrid for
|
||||
the prediction.
|
||||
|
||||
This is only implemented for two dimensions for now!
|
||||
"""
|
||||
X, Xvar, Y = get_x_y_var(self)
|
||||
|
||||
#work out what the inputs are for plotting (1D or 2D)
|
||||
if fixed_inputs is None:
|
||||
fixed_inputs = []
|
||||
fixed_dims = get_fixed_dims(self, fixed_inputs)
|
||||
fixed_dims = get_fixed_dims(fixed_inputs)
|
||||
free_dims = get_free_dims(self, visible_dims, fixed_dims)
|
||||
|
||||
if len(free_dims) == 1:
|
||||
|
|
@ -129,7 +127,7 @@ def helper_for_plot_data(self, plot_limits, visible_dims, fixed_inputs, resoluti
|
|||
y = None
|
||||
elif len(free_dims) == 2:
|
||||
#define the frame for plotting on
|
||||
resolution = resolution or 50
|
||||
resolution = resolution or 35
|
||||
Xnew, x, y, xmin, xmax = x_frame2D(X[:,free_dims], plot_limits, resolution)
|
||||
Xgrid = np.zeros((Xnew.shape[0], self.input_dim))
|
||||
Xgrid[:,free_dims] = Xnew
|
||||
|
|
@ -137,7 +135,7 @@ def helper_for_plot_data(self, plot_limits, visible_dims, fixed_inputs, resoluti
|
|||
Xgrid[:,i] = v
|
||||
else:
|
||||
raise TypeError("calculated free_dims {} from visible_dims {} and fixed_dims {} is neither 1D nor 2D".format(free_dims, visible_dims, fixed_dims))
|
||||
return X, Xvar, Y, fixed_dims, free_dims, Xgrid, x, y, xmin, xmax, resolution
|
||||
return fixed_dims, free_dims, Xgrid, x, y, xmin, xmax, resolution
|
||||
|
||||
def scatter_label_generator(labels, X, visible_dims, marker=None):
|
||||
ulabels = []
|
||||
|
|
@ -271,13 +269,16 @@ def update_not_existing_kwargs(to_update, update_from):
|
|||
|
||||
def get_x_y_var(model):
|
||||
"""
|
||||
The the data from a model as
|
||||
Either the the data from a model as
|
||||
X the inputs,
|
||||
X_variance the variance of the inputs ([default: None])
|
||||
and Y the outputs
|
||||
|
||||
If (X, X_variance, Y) is given, this just returns.
|
||||
|
||||
:returns: (X, X_variance, Y)
|
||||
"""
|
||||
# model given
|
||||
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
|
||||
X = model.X.mean.values
|
||||
X_variance = model.X.variance.values
|
||||
|
|
@ -305,7 +306,7 @@ def get_free_dims(model, visible_dims, fixed_dims):
|
|||
return np.asanyarray([dim for dim in dims if dim is not None])
|
||||
|
||||
|
||||
def get_fixed_dims(model, fixed_inputs):
|
||||
def get_fixed_dims(fixed_inputs):
|
||||
"""
|
||||
Work out the fixed dimensions from the fixed_inputs list of tuples.
|
||||
"""
|
||||
|
|
@ -339,7 +340,7 @@ def x_frame1D(X,plot_limits=None,resolution=None):
|
|||
else:
|
||||
xmin,xmax = X.min(0),X.max(0)
|
||||
xmin, xmax = xmin-0.25*(xmax-xmin), xmax+0.25*(xmax-xmin)
|
||||
elif len(plot_limits)==2:
|
||||
elif len(plot_limits) == 2:
|
||||
xmin, xmax = plot_limits
|
||||
else:
|
||||
raise ValueError("Bad limits for plotting")
|
||||
|
|
@ -355,9 +356,15 @@ def x_frame2D(X,plot_limits=None,resolution=None):
|
|||
if plot_limits is None:
|
||||
xmin, xmax = X.min(0),X.max(0)
|
||||
xmin, xmax = xmin-0.075*(xmax-xmin), xmax+0.075*(xmax-xmin)
|
||||
elif len(plot_limits)==2:
|
||||
elif len(plot_limits) == 2:
|
||||
xmin, xmax = plot_limits
|
||||
elif len(plot_limits)==4:
|
||||
try:
|
||||
xmin = xmin[0], xmin[1]
|
||||
except:
|
||||
# only one limit given, copy over to other lim
|
||||
xmin = [plot_limits[0], plot_limits[0]]
|
||||
xmax = [plot_limits[1], plot_limits[1]]
|
||||
elif len(plot_limits) == 4:
|
||||
xmin, xmax = (plot_limits[0], plot_limits[2]), (plot_limits[1], plot_limits[3])
|
||||
else:
|
||||
raise ValueError("Bad limits for plotting")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue