[classification] sparse gp classification and dtc update

This commit is contained in:
Max Zwiessele 2015-09-11 15:08:30 +01:00
parent 4ea5ebaa68
commit 1d354f5cce
14 changed files with 208 additions and 369 deletions

View file

@ -42,8 +42,12 @@ def plot_data(model, which_data_rows='all',
fig = plt.figure(num=fignum)
ax = fig.add_subplot(111)
#data
X = model.X
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
X = model.X.mean
X_variance = model.X.variance
else:
X = model.X
X_variance = None
Y = model.Y
#work out what the inputs are for plotting (1D or 2D)
@ -54,9 +58,14 @@ def plot_data(model, which_data_rows='all',
plots = {}
#one dimensional plotting
if len(free_dims) == 1:
plots['dataplot'] = []
if X_variance is not None: plots['xerrorbar'] = []
for d in which_data_ycols:
plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=mew)
plots['dataplot'].append(ax.plot(X[which_data_rows, free_dims], Y[which_data_rows, d], data_symbol, mew=mew))
if X_variance is not None:
plots['xerrorbar'] = ax.errorbar(X[which_data_rows, free_dims].flatten(), Y[which_data_rows, which_data_ycols].flatten(),
xerr=2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()),
ecolor='k', fmt='none', elinewidth=.5, alpha=.5)
#2D plotting
elif len(free_dims) == 2:
@ -219,10 +228,6 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
plots['xerrorbar'] = ax.errorbar(X[which_data_rows, free_dims].flatten(), m_X[which_data_rows, which_data_ycols].flatten(),
xerr=2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()),
ecolor='k', fmt=None, elinewidth=.5, alpha=.5)
else:
plots['xerrorbar'] = ax.errorbar(X[which_data_rows, free_dims].flatten(), Y[which_data_rows, which_data_ycols].flatten(),
xerr=2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()),
ecolor='k', fmt=None, elinewidth=.5, alpha=.5)
#set the limits of the plot to some sensible values
ymin, ymax = min(np.append(Y[which_data_rows, which_data_ycols].flatten(), lower)), max(np.append(Y[which_data_rows, which_data_ycols].flatten(), upper))