mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-04 01:02:39 +02:00
[matplotlib] plot updates and testing
This commit is contained in:
parent
116ad8762c
commit
486def6e0c
21 changed files with 217 additions and 204 deletions
|
|
@ -47,85 +47,19 @@ def _wait_for_updates(view, updates):
|
|||
pass
|
||||
|
||||
|
||||
def plot_prediction_fit(self, plot_limits=None,
|
||||
which_data_rows='all', which_data_ycols='all',
|
||||
fixed_inputs=None, resolution=None,
|
||||
plot_raw=False, apply_link=False, visible_dims=None,
|
||||
predict_kw=None, scatter_kwargs=None, **plot_kwargs):
|
||||
"""
|
||||
Plot the fit of the (Bayesian)GPLVM latent space prediction to the outputs.
|
||||
This scatters two output dimensions against each other and a line
|
||||
from the prediction in two dimensions between them.
|
||||
|
||||
Give the Y_metadata in the predict_kw if you need it.
|
||||
|
||||
:param which_data_rows: which of the training data to plot (default all)
|
||||
:type which_data_rows: 'all' or a slice object to slice self.X, self.Y
|
||||
:param array-like which_data_ycols: which columns of y to plot (array-like or list of ints)
|
||||
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input dimension i should be set to value v.
|
||||
:type fixed_inputs: a list of tuples
|
||||
:param int resolution: The resolution of the prediction [defaults are 1D:200, 2D:50]
|
||||
:param bool plot_raw: plot the latent function (usually denoted f) only?
|
||||
:param bool apply_link: whether to apply the link function of the GP to the raw prediction.
|
||||
:param array-like visible_dims: which columns of the input X (!) to plot (array-like or list of ints)
|
||||
:param dict predict_kw: the keyword arguments for the prediction. If you want to plot a specific kernel give dict(kern=<specific kernel>) in here
|
||||
:param dict sactter_kwargs: kwargs for the scatter plot, specific for the plotting library you are using
|
||||
:param kwargs plot_kwargs: kwargs for the data plot for the plotting library you are using
|
||||
"""
|
||||
canvas, kwargs = pl.get_new_canvas(**plot_kwargs)
|
||||
plots = _plot_prediction_fit(self, canvas, plot_limits, which_data_rows, which_data_ycols,
|
||||
fixed_inputs, resolution, plot_raw,
|
||||
apply_link, visible_dims,
|
||||
predict_kw, scatter_kwargs, **kwargs)
|
||||
return pl.show_canvas(canvas, plots)
|
||||
|
||||
def _plot_prediction_fit(self, canvas, plot_limits=None,
|
||||
which_data_rows='all', which_data_ycols='all',
|
||||
fixed_inputs=None, resolution=None,
|
||||
plot_raw=False, apply_link=False, visible_dims=False,
|
||||
predict_kw=None, scatter_kwargs=None, **plot_kwargs):
|
||||
|
||||
ycols = get_which_data_ycols(self, which_data_ycols)
|
||||
rows = get_which_data_rows(self, which_data_rows)
|
||||
|
||||
if visible_dims is None:
|
||||
visible_dims = self.get_most_significant_input_dimensions()[:1]
|
||||
|
||||
X, _, Y, _, free_dims, Xgrid, _, _, _, _, resolution = helper_for_plot_data(self, plot_limits, visible_dims, fixed_inputs, resolution)
|
||||
|
||||
plots = {}
|
||||
|
||||
if len(free_dims)<2:
|
||||
if len(free_dims)==1:
|
||||
if scatter_kwargs is None:
|
||||
scatter_kwargs = {}
|
||||
update_not_existing_kwargs(scatter_kwargs, pl.defaults.data_y_1d) # @UndefinedVariable
|
||||
plots['output'] = pl.scatter(canvas, Y[rows, ycols[0]], Y[rows, ycols[1]],
|
||||
color=X[rows, free_dims[0]],
|
||||
**scatter_kwargs)
|
||||
if predict_kw is None:
|
||||
predict_kw = {}
|
||||
mu, _, _ = helper_predict_with_model(self, Xgrid, plot_raw,
|
||||
apply_link, None,
|
||||
ycols, predict_kw)
|
||||
update_not_existing_kwargs(plot_kwargs, pl.defaults.data_y_1d_plot) # @UndefinedVariable
|
||||
plots['output_fit'] = pl.plot(canvas, mu[:, 0], mu[:, 1], **plot_kwargs)
|
||||
else:
|
||||
pass #Nothing to plot!
|
||||
else:
|
||||
raise NotImplementedError("Cannot plot in more then one dimension.")
|
||||
return plots
|
||||
|
||||
def _plot_latent_scatter(self, canvas, X, visible_dims, labels, marker, num_samples, projection='2d', **kwargs):
|
||||
def _plot_latent_scatter(canvas, X, visible_dims, labels, marker, num_samples, projection='2d', **kwargs):
|
||||
from .. import Tango
|
||||
Tango.reset()
|
||||
if labels is None:
|
||||
labels = np.ones(self.num_data)
|
||||
X, labels = subsample_X(X, labels, num_samples)
|
||||
scatters = []
|
||||
scatters = []
|
||||
generate_colors = 'color' not in kwargs
|
||||
for x, y, z, this_label, _, m in scatter_label_generator(labels, X, visible_dims, marker):
|
||||
update_not_existing_kwargs(kwargs, pl.defaults.latent_scatter)
|
||||
scatters.append(pl.scatter(canvas, x, y, Z=z, marker=m, color=Tango.nextMedium(), label=this_label, **kwargs))
|
||||
if generate_colors:
|
||||
kwargs['color'] = Tango.nextMedium()
|
||||
if projection == '3d':
|
||||
scatters.append(pl.scatter(canvas, x, y, Z=z, marker=m, label=this_label, **kwargs))
|
||||
else: scatters.append(pl.scatter(canvas, x, y, marker=m, label=this_label, **kwargs))
|
||||
return scatters
|
||||
|
||||
def plot_latent_scatter(self, labels=None,
|
||||
|
|
@ -147,43 +81,81 @@ def plot_latent_scatter(self, labels=None,
|
|||
:param str marker: markers to use - cycle if more labels then markers are given
|
||||
:param kwargs: the kwargs for the scatter plots
|
||||
"""
|
||||
sig_dims = self.get_most_significant_input_dimensions(which_indices)
|
||||
input_1, input_2, input_3 = [i for i in sig_dims if i is not None]
|
||||
|
||||
input_1, input_2, input_3 = sig_dims = self.get_most_significant_input_dimensions(which_indices)
|
||||
|
||||
canvas, kwargs = pl.get_new_canvas(projection=projection, **kwargs)
|
||||
X, _, _ = get_x_y_var(self)
|
||||
scatters = _plot_latent_scatter(self, canvas, X, sig_dims, labels, marker, num_samples, projection=projection, **kwargs)
|
||||
if labels is None:
|
||||
labels = np.ones(self.num_data)
|
||||
legend = False
|
||||
else:
|
||||
legend = find_best_layout_for_subplots(len(np.unique(labels)))
|
||||
scatters = _plot_latent_scatter(canvas, X, sig_dims, labels, marker, num_samples, projection=projection, **kwargs)
|
||||
if projection == '3d':
|
||||
return pl.show_canvas(canvas, dict(scatter=scatters), legend=legend and (labels is not None),
|
||||
return pl.show_canvas(canvas, dict(scatter=scatters), legend=legend,
|
||||
xlabel='latent dimension %i' % input_1,
|
||||
ylabel='latent dimension %i' % input_2,
|
||||
zlabel='latent dimension %i' % input_3)
|
||||
else:
|
||||
return pl.show_canvas(canvas, dict(scatter=scatters), legend=legend and (labels is not None),
|
||||
return pl.show_canvas(canvas, dict(scatter=scatters), legend=legend,
|
||||
xlabel='latent dimension %i' % input_1,
|
||||
ylabel='latent dimension %i' % input_2,
|
||||
#zlabel='latent dimension %i' % input_3
|
||||
)
|
||||
def plot_latent_inducing(self,
|
||||
which_indices=None,
|
||||
legend=False,
|
||||
plot_limits=None,
|
||||
marker='^',
|
||||
num_samples=1000,
|
||||
projection='2d',
|
||||
**kwargs):
|
||||
"""
|
||||
Plot a scatter plot of the inducing inputs.
|
||||
|
||||
:param array-like labels: a label for each data point (row) of the inputs
|
||||
:param (int, int) which_indices: which input dimensions to plot against each other
|
||||
:param bool legend: whether to plot the legend on the figure
|
||||
:param plot_limits: the plot limits for the plot
|
||||
:type plot_limits: (xmin, xmax, ymin, ymax) or ((xmin, xmax), (ymin, ymax))
|
||||
:param str marker: markers to use - cycle if more labels then markers are given
|
||||
:param kwargs: the kwargs for the scatter plots
|
||||
"""
|
||||
input_1, input_2, input_3 = sig_dims = self.get_most_significant_input_dimensions(which_indices)
|
||||
|
||||
if 'color' not in kwargs:
|
||||
kwargs['color'] = 'white'
|
||||
canvas, kwargs = pl.get_new_canvas(projection=projection, **kwargs)
|
||||
X, _, _ = get_x_y_var(self)
|
||||
labels = np.ones(self.num_data)
|
||||
scatters = _plot_latent_scatter(canvas, X, sig_dims, labels, marker, num_samples, projection=projection, **kwargs)
|
||||
if projection == '3d':
|
||||
return pl.show_canvas(canvas, dict(scatter=scatters), legend=legend,
|
||||
xlabel='latent dimension %i' % input_1,
|
||||
ylabel='latent dimension %i' % input_2,
|
||||
zlabel='latent dimension %i' % input_3)
|
||||
else:
|
||||
return pl.show_canvas(canvas, dict(scatter=scatters), legend=legend,
|
||||
xlabel='latent dimension %i' % input_1,
|
||||
ylabel='latent dimension %i' % input_2,
|
||||
#zlabel='latent dimension %i' % input_3
|
||||
)
|
||||
|
||||
|
||||
def _plot_magnification(self, canvas, input_1, input_2, Xgrid,
|
||||
def _plot_magnification(self, canvas, which_indices, Xgrid,
|
||||
xmin, xmax, resolution,
|
||||
mean=True, covariance=True,
|
||||
kern=None,
|
||||
**imshow_kwargs):
|
||||
def plot_function(x):
|
||||
Xtest_full = np.zeros((x.shape[0], Xgrid.shape[1]))
|
||||
Xtest_full[:, [input_1, input_2]] = x
|
||||
Xtest_full[:, which_indices] = x
|
||||
mf = self.predict_magnification(Xtest_full, kern=kern, mean=mean, covariance=covariance)
|
||||
return mf.reshape(resolution, resolution).T
|
||||
imshow_kwargs = update_not_existing_kwargs(imshow_kwargs, pl.defaults.magnification)
|
||||
Y = plot_function(Xgrid[:, [input_1, input_2]])
|
||||
view = pl.imshow(canvas, Y,
|
||||
(xmin[0], xmax[0], xmin[1], xmax[1]),
|
||||
None, plot_function, resolution,
|
||||
vmin=Y.min(), vmax=Y.max(),
|
||||
**imshow_kwargs)
|
||||
return view
|
||||
try:
|
||||
return pl.imshow_interact(canvas, plot_function, (xmin[0], xmax[0], xmin[1], xmax[1]), resolution=resolution, **imshow_kwargs)
|
||||
except NotImplementedError:
|
||||
return pl.imshow(canvas, plot_function(Xgrid), (xmin[0], xmax[0], xmin[1], xmax[1]), **imshow_kwargs)
|
||||
|
||||
def plot_magnification(self, labels=None, which_indices=None,
|
||||
resolution=60, marker='<>^vsd', legend=True,
|
||||
|
|
@ -211,13 +183,16 @@ def plot_magnification(self, labels=None, which_indices=None,
|
|||
:param imshow_kwargs: the kwargs for the imshow (magnification factor)
|
||||
:param kwargs: the kwargs for the scatter plots
|
||||
"""
|
||||
input_1, input_2 = self.get_most_significant_input_dimensions(which_indices)
|
||||
input_1, input_2 = which_indices = self.get_most_significant_input_dimensions(which_indices)[:2]
|
||||
canvas, imshow_kwargs = pl.get_new_canvas(**imshow_kwargs)
|
||||
X, _, _, _, _, Xgrid, _, _, xmin, xmax, resolution = helper_for_plot_data(self, plot_limits, (input_1, input_2), None, resolution)
|
||||
scatters = _plot_latent_scatter(self, canvas, X, input_1, input_2, labels, marker, num_samples, **scatter_kwargs or {})
|
||||
view = _plot_magnification(self, canvas, input_1, input_2, Xgrid, xmin, xmax, resolution, mean, covariance, kern, **imshow_kwargs)
|
||||
if (legend is True) and (labels is not None):
|
||||
X, _, _, _, _, Xgrid, _, _, xmin, xmax, resolution = helper_for_plot_data(self, plot_limits, which_indices, None, resolution)
|
||||
if (labels is not None):
|
||||
legend = find_best_layout_for_subplots(len(np.unique(labels)))[1]
|
||||
else:
|
||||
labels = np.ones(self.num_data)
|
||||
legend = False
|
||||
scatters = _plot_latent_scatter(canvas, X, which_indices, labels, marker, num_samples, projection='2d', **scatter_kwargs or {})
|
||||
view = _plot_magnification(self, canvas, which_indices[:2], Xgrid, xmin, xmax, resolution, mean, covariance, kern, **imshow_kwargs)
|
||||
plots = pl.show_canvas(canvas, dict(scatter=scatters, imshow=view),
|
||||
legend=legend,
|
||||
xlim=(xmin[0], xmax[0]), ylim=(xmin[1], xmax[1]),
|
||||
|
|
@ -228,24 +203,21 @@ def plot_magnification(self, labels=None, which_indices=None,
|
|||
|
||||
|
||||
|
||||
def _plot_latent(self, canvas, input_1, input_2, Xgrid,
|
||||
def _plot_latent(self, canvas, which_indices, Xgrid,
|
||||
xmin, xmax, resolution,
|
||||
kern=None,
|
||||
**imshow_kwargs):
|
||||
def plot_function(x):
|
||||
Xtest_full = np.zeros((x.shape[0], Xgrid.shape[1]))
|
||||
Xtest_full[:, [input_1, input_2]] = x
|
||||
Xtest_full[:, which_indices] = x
|
||||
mf = np.log(self.predict(Xtest_full, kern=kern)[1])
|
||||
return mf.reshape(resolution, resolution).T
|
||||
|
||||
imshow_kwargs = update_not_existing_kwargs(imshow_kwargs, pl.defaults.latent)
|
||||
Y = plot_function(Xgrid[:, [input_1, input_2]]).reshape(resolution, resolution).T
|
||||
view = pl.imshow(canvas, Y,
|
||||
(xmin[0], xmax[0], xmin[1], xmax[1]),
|
||||
None, plot_function, resolution,
|
||||
vmin=Y.min(), vmax=Y.max(),
|
||||
**imshow_kwargs)
|
||||
return view
|
||||
try:
|
||||
return pl.imshow_interact(canvas, plot_function, (xmin[0], xmax[0], xmin[1], xmax[1]), resolution=resolution, **imshow_kwargs)
|
||||
except NotImplementedError:
|
||||
return pl.imshow(canvas, plot_function(Xgrid), (xmin[0], xmax[0], xmin[1], xmax[1]), **imshow_kwargs)
|
||||
|
||||
def plot_latent(self, labels=None, which_indices=None,
|
||||
resolution=60, legend=True,
|
||||
|
|
@ -272,13 +244,16 @@ def plot_latent(self, labels=None, which_indices=None,
|
|||
:param imshow_kwargs: the kwargs for the imshow (magnification factor)
|
||||
:param scatter_kwargs: the kwargs for the scatter plots
|
||||
"""
|
||||
input_1, input_2 = self.get_most_significant_input_dimensions(which_indices)
|
||||
input_1, input_2 = which_indices = self.get_most_significant_input_dimensions(which_indices)[:2]
|
||||
canvas, imshow_kwargs = pl.get_new_canvas(**imshow_kwargs)
|
||||
X, _, _, _, _, Xgrid, _, _, xmin, xmax, resolution = helper_for_plot_data(self, plot_limits, (input_1, input_2), None, resolution)
|
||||
scatters = _plot_latent_scatter(self, canvas, X, input_1, input_2, labels, marker, num_samples, **scatter_kwargs or {})
|
||||
view = _plot_latent(self, canvas, input_1, input_2, Xgrid, xmin, xmax, resolution, kern, **imshow_kwargs)
|
||||
if (legend is True) and (labels is not None):
|
||||
X, _, _, _, _, Xgrid, _, _, xmin, xmax, resolution = helper_for_plot_data(self, plot_limits, which_indices, None, resolution)
|
||||
if (labels is not None):
|
||||
legend = find_best_layout_for_subplots(len(np.unique(labels)))[1]
|
||||
else:
|
||||
labels = np.ones(self.num_data)
|
||||
legend = False
|
||||
scatters = _plot_latent_scatter(canvas, X, which_indices, labels, marker, num_samples, projection='2d', **scatter_kwargs or {})
|
||||
view = _plot_latent(self, canvas, which_indices, Xgrid, xmin, xmax, resolution, kern, **imshow_kwargs)
|
||||
plots = pl.show_canvas(canvas, dict(scatter=scatters, imshow=view),
|
||||
legend=legend,
|
||||
xlim=(xmin[0], xmax[0]), ylim=(xmin[1], xmax[1]),
|
||||
|
|
@ -286,25 +261,25 @@ def plot_latent(self, labels=None, which_indices=None,
|
|||
_wait_for_updates(view, updates)
|
||||
return plots
|
||||
|
||||
def _plot_steepest_gradient_map(self, canvas, input_1, input_2, Xgrid,
|
||||
def _plot_steepest_gradient_map(self, canvas, which_indices, Xgrid,
|
||||
xmin, xmax, resolution, output_labels,
|
||||
kern=None, annotation_kwargs=None,
|
||||
**imshow_kwargs):
|
||||
if output_labels is None:
|
||||
output_labels = range(self.output_dim)
|
||||
def plot_function(x):
|
||||
Xgrid[:, [input_1, input_2]] = x
|
||||
dmu_dX = self.predictive_gradients(Xgrid, kern=kern)[0].sum(1)
|
||||
Xgrid[:, which_indices] = x
|
||||
dmu_dX = np.sqrt(((self.predictive_gradients(Xgrid, kern=kern)[0])**2).sum(1))
|
||||
#dmu_dX = self.predictive_gradients(Xgrid, kern=kern)[0].sum(1)
|
||||
argmax = np.argmax(dmu_dX, 1).astype(int)
|
||||
return dmu_dX.max(1).reshape(resolution, resolution).T, np.array(output_labels)[argmax].reshape(resolution, resolution)
|
||||
Y, annotation = plot_function(Xgrid[:, [input_1, input_2]])
|
||||
return dmu_dX.max(1).reshape(resolution, resolution).T, np.array(output_labels)[argmax].reshape(resolution, resolution).T
|
||||
annotation_kwargs = update_not_existing_kwargs(annotation_kwargs or {}, pl.defaults.annotation)
|
||||
imshow_kwargs = update_not_existing_kwargs(imshow_kwargs or {}, pl.defaults.gradient)
|
||||
imshow, annotation = pl.annotation_heatmap(canvas, Y, annotation,
|
||||
(xmin[0], xmax[0], xmin[1], xmax[1]),
|
||||
None, plot_function, resolution,
|
||||
imshow_kwargs=imshow_kwargs, **annotation_kwargs)
|
||||
return dict(heatmap=imshow, annotation=annotation)
|
||||
try:
|
||||
return dict(annotation=pl.annotation_heatmap_interact(canvas, plot_function, (xmin[0], xmax[0], xmin[1], xmax[1]), resolution=resolution, imshow_kwargs=imshow_kwargs, **annotation_kwargs))
|
||||
except NotImplementedError:
|
||||
imshow, annotation = pl.annotation_heatmap(canvas, *plot_function(Xgrid), extent=(xmin[0], xmax[0], xmin[1], xmax[1]), imshow_kwargs=imshow_kwargs, **annotation_kwargs)
|
||||
return dict(heatmap=imshow, annotation=annotation)
|
||||
|
||||
def plot_steepest_gradient_map(self, output_labels=None, data_labels=None, which_indices=None,
|
||||
resolution=15, legend=True,
|
||||
|
|
@ -333,13 +308,16 @@ def plot_steepest_gradient_map(self, output_labels=None, data_labels=None, which
|
|||
:param annotation_kwargs: the kwargs for the annotation plot
|
||||
:param scatter_kwargs: the kwargs for the scatter plots
|
||||
"""
|
||||
input_1, input_2 = self.get_most_significant_input_dimensions(which_indices)
|
||||
input_1, input_2 = which_indices = self.get_most_significant_input_dimensions(which_indices)[:2]
|
||||
canvas, imshow_kwargs = pl.get_new_canvas(**imshow_kwargs)
|
||||
X, _, _, _, _, Xgrid, _, _, xmin, xmax, resolution = helper_for_plot_data(self, plot_limits, (input_1, input_2), None, resolution)
|
||||
plots = dict(scatter=_plot_latent_scatter(self, canvas, X, input_1, input_2, data_labels, marker, num_samples, **scatter_kwargs or {}))
|
||||
plots.update(_plot_steepest_gradient_map(self, canvas, input_1, input_2, Xgrid, xmin, xmax, resolution, output_labels, kern, annotation_kwargs=annotation_kwargs, **imshow_kwargs))
|
||||
if (legend is True) and (data_labels is not None):
|
||||
X, _, _, _, _, Xgrid, _, _, xmin, xmax, resolution = helper_for_plot_data(self, plot_limits, which_indices, None, resolution)
|
||||
if (data_labels is not None):
|
||||
legend = find_best_layout_for_subplots(len(np.unique(data_labels)))[1]
|
||||
else:
|
||||
data_labels = np.ones(self.num_data)
|
||||
legend = False
|
||||
plots = dict(scatter=_plot_latent_scatter(canvas, X, which_indices, data_labels, marker, num_samples, **scatter_kwargs or {}))
|
||||
plots.update(_plot_steepest_gradient_map(self, canvas, which_indices, Xgrid, xmin, xmax, resolution, output_labels, kern, annotation_kwargs=annotation_kwargs, **imshow_kwargs))
|
||||
pl.show_canvas(canvas, plots, legend=legend,
|
||||
xlim=(xmin[0], xmax[0]), ylim=(xmin[1], xmax[1]),
|
||||
xlabel='latent dimension %i' % input_1, ylabel='latent dimension %i' % input_2)
|
||||
|
|
|
|||
|
|
@ -100,6 +100,8 @@ def helper_for_plot_data(self, plot_limits, visible_dims, fixed_inputs, resoluti
|
|||
"""
|
||||
Figure out the data, free_dims and create an Xgrid for
|
||||
the prediction.
|
||||
|
||||
This is only implemented for two dimensions for now!
|
||||
"""
|
||||
X, Xvar, Y = get_x_y_var(self)
|
||||
|
||||
|
|
@ -107,13 +109,13 @@ def helper_for_plot_data(self, plot_limits, visible_dims, fixed_inputs, resoluti
|
|||
if fixed_inputs is None:
|
||||
fixed_inputs = []
|
||||
fixed_dims = get_fixed_dims(self, fixed_inputs)
|
||||
free_dims = get_free_dims(self, visible_dims, fixed_dims)
|
||||
free_dims = get_free_dims(self, visible_dims, fixed_dims)[:2]
|
||||
|
||||
if len(free_dims) == 1:
|
||||
#define the frame on which to plot
|
||||
resolution = resolution or 200
|
||||
Xnew, xmin, xmax = x_frame1D(X[:,free_dims], plot_limits=plot_limits, resolution=resolution)
|
||||
Xgrid = np.empty((Xnew.shape[0],self.input_dim))
|
||||
Xgrid = np.zeros((Xnew.shape[0],self.input_dim))
|
||||
Xgrid[:,free_dims] = Xnew
|
||||
for i,v in fixed_inputs:
|
||||
Xgrid[:,i] = v
|
||||
|
|
@ -123,10 +125,12 @@ def helper_for_plot_data(self, plot_limits, visible_dims, fixed_inputs, resoluti
|
|||
#define the frame for plotting on
|
||||
resolution = resolution or 50
|
||||
Xnew, x, y, xmin, xmax = x_frame2D(X[:,free_dims], plot_limits, resolution)
|
||||
Xgrid = np.empty((Xnew.shape[0],self.input_dim))
|
||||
Xgrid = np.zeros((Xnew.shape[0], self.input_dim))
|
||||
Xgrid[:,free_dims] = Xnew
|
||||
for i,v in fixed_inputs:
|
||||
Xgrid[:,i] = v
|
||||
Xgrid[:,i] = v
|
||||
else:
|
||||
raise TypeError("calculated free_dims {} from visible_dims {} and fixed_dims {} is neither 1D nor 2D".format(free_dims, visible_dims, fixed_dims))
|
||||
return X, Xvar, Y, fixed_dims, free_dims, Xgrid, x, y, xmin, xmax, resolution
|
||||
|
||||
def scatter_label_generator(labels, X, visible_dims, marker=None):
|
||||
|
|
@ -140,7 +144,16 @@ def scatter_label_generator(labels, X, visible_dims, marker=None):
|
|||
else:
|
||||
m = None
|
||||
|
||||
input_1, input_2, input_3 = visible_dims
|
||||
try:
|
||||
input_1, input_2, input_3 = visible_dims
|
||||
except:
|
||||
try:
|
||||
# tuple or int?
|
||||
input_1, input_2 = visible_dims
|
||||
input_3 = None
|
||||
except:
|
||||
input_1 = visible_dims
|
||||
input_2 = input_3 = None
|
||||
|
||||
for ul in ulabels:
|
||||
if type(ul) is np.string_:
|
||||
|
|
@ -280,10 +293,11 @@ def get_free_dims(model, visible_dims, fixed_dims):
|
|||
"""
|
||||
if visible_dims is None:
|
||||
visible_dims = np.arange(model.input_dim)
|
||||
visible_dims = np.asanyarray(visible_dims)
|
||||
dims = np.asanyarray(visible_dims)
|
||||
if fixed_dims is not None:
|
||||
return np.setdiff1d(visible_dims, fixed_dims)
|
||||
return visible_dims
|
||||
dims = np.setdiff1d(dims, fixed_dims)
|
||||
return np.asanyarray([dim for dim in dims if dim is not None])
|
||||
|
||||
|
||||
def get_fixed_dims(model, fixed_inputs):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue