mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-05 17:52:39 +02:00
[plotting] latent plotting had dimension mix up in it
This commit is contained in:
parent
8878353fdb
commit
cb3b4ca08d
14 changed files with 49 additions and 25 deletions
|
|
@ -393,7 +393,7 @@ def plot_f(self, plot_limits=None, fixed_inputs=None,
|
|||
apply_link, which_data_ycols, which_data_rows,
|
||||
visible_dims, levels, samples, 0,
|
||||
lower, upper, plot_data, plot_inducing,
|
||||
plot_density, predict_kw, projection, legend)
|
||||
plot_density, predict_kw, projection, legend, **kwargs)
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -123,13 +123,16 @@ def plot_latent_inducing(self,
|
|||
:param kwargs: the kwargs for the scatter plots
|
||||
"""
|
||||
input_1, input_2, input_3 = sig_dims = self.get_most_significant_input_dimensions(which_indices)
|
||||
if input_3 is None: zlabel=None
|
||||
else: zlabel = 'latent dimension %i' % input_3
|
||||
|
||||
|
||||
if 'color' not in kwargs:
|
||||
kwargs['color'] = 'white'
|
||||
canvas, kwargs = pl().new_canvas(projection=projection,
|
||||
xlabel='latent dimension %i' % input_1,
|
||||
ylabel='latent dimension %i' % input_2,
|
||||
zlabel='latent dimension %i' % input_3, **kwargs)
|
||||
zlabel=zlabel, **kwargs)
|
||||
Z = self.Z.values
|
||||
labels = np.array(['inducing'] * Z.shape[0])
|
||||
scatters = _plot_latent_scatter(canvas, Z, sig_dims, labels, marker, num_samples, projection=projection, **kwargs)
|
||||
|
|
@ -195,7 +198,7 @@ def plot_magnification(self, labels=None, which_indices=None,
|
|||
labels = np.ones(self.num_data)
|
||||
legend = False
|
||||
scatters = _plot_latent_scatter(canvas, X, which_indices, labels, marker, num_samples, projection='2d', **scatter_kwargs or {})
|
||||
view = _plot_magnification(self, canvas, which_indices[:2], Xgrid, xmin, xmax, resolution, updates, mean, covariance, kern, **imshow_kwargs)
|
||||
view = _plot_magnification(self, canvas, which_indices, Xgrid, xmin, xmax, resolution, updates, mean, covariance, kern, **imshow_kwargs)
|
||||
retval = pl().add_to_canvas(canvas, dict(scatter=scatters, imshow=view),
|
||||
legend=legend,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -131,6 +131,8 @@ def helper_for_plot_data(self, X, plot_limits, visible_dims, fixed_inputs, resol
|
|||
Xnew, x, y, xmin, xmax = x_frame2D(X[:,free_dims], plot_limits, resolution)
|
||||
Xgrid = np.zeros((Xnew.shape[0], self.input_dim))
|
||||
Xgrid[:,free_dims] = Xnew
|
||||
#xmin = Xgrid.min(0)[free_dims]
|
||||
#xmax = Xgrid.max(0)[free_dims]
|
||||
for i,v in fixed_inputs:
|
||||
Xgrid[:,i] = v
|
||||
else:
|
||||
|
|
@ -305,7 +307,7 @@ def get_free_dims(model, visible_dims, fixed_dims):
|
|||
visible_dims = np.arange(model.input_dim)
|
||||
dims = np.asanyarray(visible_dims)
|
||||
if fixed_dims is not None:
|
||||
dims = np.setdiff1d(dims, fixed_dims)
|
||||
dims = [dim for dim in dims if dim not in fixed_dims]
|
||||
return np.asanyarray([dim for dim in dims if dim is not None])
|
||||
|
||||
|
||||
|
|
@ -357,7 +359,7 @@ def x_frame2D(X,plot_limits=None,resolution=None):
|
|||
"""
|
||||
assert X.shape[1]==2, "x_frame2D is defined for two-dimensional inputs"
|
||||
if plot_limits is None:
|
||||
xmin, xmax = X.min(0),X.max(0)
|
||||
xmin, xmax = X.min(0), X.max(0)
|
||||
xmin, xmax = xmin-0.075*(xmax-xmin), xmax+0.075*(xmax-xmin)
|
||||
elif len(plot_limits) == 2:
|
||||
xmin, xmax = plot_limits
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue