[matplotlib] plot updates and testing

This commit is contained in:
mzwiessele 2015-10-06 18:20:54 +01:00
parent 116ad8762c
commit 486def6e0c
21 changed files with 217 additions and 204 deletions

View file

@ -100,6 +100,8 @@ def helper_for_plot_data(self, plot_limits, visible_dims, fixed_inputs, resoluti
"""
Figure out the data, free_dims and create an Xgrid for
the prediction.
This is only implemented for two dimensions for now!
"""
X, Xvar, Y = get_x_y_var(self)
@ -107,13 +109,13 @@ def helper_for_plot_data(self, plot_limits, visible_dims, fixed_inputs, resoluti
if fixed_inputs is None:
fixed_inputs = []
fixed_dims = get_fixed_dims(self, fixed_inputs)
free_dims = get_free_dims(self, visible_dims, fixed_dims)
free_dims = get_free_dims(self, visible_dims, fixed_dims)[:2]
if len(free_dims) == 1:
#define the frame on which to plot
resolution = resolution or 200
Xnew, xmin, xmax = x_frame1D(X[:,free_dims], plot_limits=plot_limits, resolution=resolution)
Xgrid = np.empty((Xnew.shape[0],self.input_dim))
Xgrid = np.zeros((Xnew.shape[0],self.input_dim))
Xgrid[:,free_dims] = Xnew
for i,v in fixed_inputs:
Xgrid[:,i] = v
@ -123,10 +125,12 @@ def helper_for_plot_data(self, plot_limits, visible_dims, fixed_inputs, resoluti
#define the frame for plotting on
resolution = resolution or 50
Xnew, x, y, xmin, xmax = x_frame2D(X[:,free_dims], plot_limits, resolution)
Xgrid = np.empty((Xnew.shape[0],self.input_dim))
Xgrid = np.zeros((Xnew.shape[0], self.input_dim))
Xgrid[:,free_dims] = Xnew
for i,v in fixed_inputs:
Xgrid[:,i] = v
Xgrid[:,i] = v
else:
raise TypeError("calculated free_dims {} from visible_dims {} and fixed_dims {} is neither 1D nor 2D".format(free_dims, visible_dims, fixed_dims))
return X, Xvar, Y, fixed_dims, free_dims, Xgrid, x, y, xmin, xmax, resolution
def scatter_label_generator(labels, X, visible_dims, marker=None):
@ -140,7 +144,16 @@ def scatter_label_generator(labels, X, visible_dims, marker=None):
else:
m = None
input_1, input_2, input_3 = visible_dims
try:
input_1, input_2, input_3 = visible_dims
except:
try:
# tuple or int?
input_1, input_2 = visible_dims
input_3 = None
except:
input_1 = visible_dims
input_2 = input_3 = None
for ul in ulabels:
if type(ul) is np.string_:
@ -280,10 +293,11 @@ def get_free_dims(model, visible_dims, fixed_dims):
"""
if visible_dims is None:
visible_dims = np.arange(model.input_dim)
visible_dims = np.asanyarray(visible_dims)
dims = np.asanyarray(visible_dims)
if fixed_dims is not None:
return np.setdiff1d(visible_dims, fixed_dims)
return visible_dims
dims = np.setdiff1d(dims, fixed_dims)
return np.asanyarray([dim for dim in dims if dim is not None])
def get_fixed_dims(model, fixed_inputs):
"""