mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
[classification] sparse gp classification and dtc update
This commit is contained in:
parent
4ea5ebaa68
commit
1d354f5cce
14 changed files with 208 additions and 369 deletions
|
|
@ -42,8 +42,12 @@ def plot_data(model, which_data_rows='all',
|
|||
fig = plt.figure(num=fignum)
|
||||
ax = fig.add_subplot(111)
|
||||
|
||||
#data
|
||||
X = model.X
|
||||
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
|
||||
X = model.X.mean
|
||||
X_variance = model.X.variance
|
||||
else:
|
||||
X = model.X
|
||||
X_variance = None
|
||||
Y = model.Y
|
||||
|
||||
#work out what the inputs are for plotting (1D or 2D)
|
||||
|
|
@ -54,9 +58,14 @@ def plot_data(model, which_data_rows='all',
|
|||
plots = {}
|
||||
#one dimensional plotting
|
||||
if len(free_dims) == 1:
|
||||
|
||||
plots['dataplot'] = []
|
||||
if X_variance is not None: plots['xerrorbar'] = []
|
||||
for d in which_data_ycols:
|
||||
plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=mew)
|
||||
plots['dataplot'].append(ax.plot(X[which_data_rows, free_dims], Y[which_data_rows, d], data_symbol, mew=mew))
|
||||
if X_variance is not None:
|
||||
plots['xerrorbar'] = ax.errorbar(X[which_data_rows, free_dims].flatten(), Y[which_data_rows, which_data_ycols].flatten(),
|
||||
xerr=2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()),
|
||||
ecolor='k', fmt='none', elinewidth=.5, alpha=.5)
|
||||
|
||||
#2D plotting
|
||||
elif len(free_dims) == 2:
|
||||
|
|
@ -219,10 +228,6 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
|||
plots['xerrorbar'] = ax.errorbar(X[which_data_rows, free_dims].flatten(), m_X[which_data_rows, which_data_ycols].flatten(),
|
||||
xerr=2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()),
|
||||
ecolor='k', fmt=None, elinewidth=.5, alpha=.5)
|
||||
else:
|
||||
plots['xerrorbar'] = ax.errorbar(X[which_data_rows, free_dims].flatten(), Y[which_data_rows, which_data_ycols].flatten(),
|
||||
xerr=2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()),
|
||||
ecolor='k', fmt=None, elinewidth=.5, alpha=.5)
|
||||
|
||||
#set the limits of the plot to some sensible values
|
||||
ymin, ymax = min(np.append(Y[which_data_rows, which_data_ycols].flatten(), lower)), max(np.append(Y[which_data_rows, which_data_ycols].flatten(), upper))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue