mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-05 01:32:40 +02:00
Merge branch 'devel' of github.com:SheffieldML/GPy into devel
This commit is contained in:
commit
4dfdce9b80
4 changed files with 266 additions and 85 deletions
|
|
@ -516,8 +516,8 @@ class GP(Model):
|
||||||
def plot(self, plot_limits=None, which_data_rows='all',
|
def plot(self, plot_limits=None, which_data_rows='all',
|
||||||
which_data_ycols='all', fixed_inputs=[],
|
which_data_ycols='all', fixed_inputs=[],
|
||||||
levels=20, samples=0, fignum=None, ax=None, resolution=None,
|
levels=20, samples=0, fignum=None, ax=None, resolution=None,
|
||||||
plot_raw=False,
|
plot_raw=False, linecol=None,fillcol=None, Y_metadata=None,
|
||||||
linecol=None,fillcol=None, Y_metadata=None, data_symbol='kx', predict_kw=None):
|
data_symbol='kx', predict_kw=None, plot_training_data=True):
|
||||||
"""
|
"""
|
||||||
Plot the posterior of the GP.
|
Plot the posterior of the GP.
|
||||||
- In one dimension, the function is plotted with a shaded region identifying two standard deviations.
|
- In one dimension, the function is plotted with a shaded region identifying two standard deviations.
|
||||||
|
|
@ -554,6 +554,8 @@ class GP(Model):
|
||||||
:type Y_metadata: dict
|
:type Y_metadata: dict
|
||||||
:param data_symbol: symbol as used matplotlib, by default this is a black cross ('kx')
|
:param data_symbol: symbol as used matplotlib, by default this is a black cross ('kx')
|
||||||
:type data_symbol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) alongside marker type, as is standard in matplotlib.
|
:type data_symbol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) alongside marker type, as is standard in matplotlib.
|
||||||
|
:param plot_training_data: whether or not to plot the training points
|
||||||
|
:type plot_training_data: boolean
|
||||||
"""
|
"""
|
||||||
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
||||||
from ..plotting.matplot_dep import models_plots
|
from ..plotting.matplot_dep import models_plots
|
||||||
|
|
@ -566,7 +568,85 @@ class GP(Model):
|
||||||
which_data_ycols, fixed_inputs,
|
which_data_ycols, fixed_inputs,
|
||||||
levels, samples, fignum, ax, resolution,
|
levels, samples, fignum, ax, resolution,
|
||||||
plot_raw=plot_raw, Y_metadata=Y_metadata,
|
plot_raw=plot_raw, Y_metadata=Y_metadata,
|
||||||
data_symbol=data_symbol, predict_kw=predict_kw, **kw)
|
data_symbol=data_symbol, predict_kw=predict_kw,
|
||||||
|
plot_training_data=plot_training_data, **kw)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_data(self, which_data_rows='all',
|
||||||
|
which_data_ycols='all', visible_dims=None,
|
||||||
|
fignum=None, ax=None, data_symbol='kx'):
|
||||||
|
"""
|
||||||
|
Plot the training data
|
||||||
|
- For higher dimensions than two, use fixed_inputs to plot the data points with some of the inputs fixed.
|
||||||
|
|
||||||
|
Can plot only part of the data
|
||||||
|
using which_data_rows and which_data_ycols.
|
||||||
|
|
||||||
|
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
|
||||||
|
:type plot_limits: np.array
|
||||||
|
:param which_data_rows: which of the training data to plot (default all)
|
||||||
|
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
|
||||||
|
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
|
||||||
|
:type which_data_ycols: 'all' or a list of integers
|
||||||
|
:param visible_dims: an array specifying the input dimensions to plot (maximum two)
|
||||||
|
:type visible_dims: a numpy array
|
||||||
|
:param resolution: the number of intervals to sample the GP on. Defaults to 200 in 1D and 50 (a 50x50 grid) in 2D
|
||||||
|
:type resolution: int
|
||||||
|
:param levels: number of levels to plot in a contour plot.
|
||||||
|
:param levels: for 2D plotting, the number of contour levels to use is ax is None, create a new figure
|
||||||
|
:type levels: int
|
||||||
|
:param samples: the number of a posteriori samples to plot
|
||||||
|
:type samples: int
|
||||||
|
:param fignum: figure to plot on.
|
||||||
|
:type fignum: figure number
|
||||||
|
:param ax: axes to plot on.
|
||||||
|
:type ax: axes handle
|
||||||
|
:param linecol: color of line to plot [Tango.colorsHex['darkBlue']]
|
||||||
|
:type linecol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
|
||||||
|
:param fillcol: color of fill [Tango.colorsHex['lightBlue']]
|
||||||
|
:type fillcol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
|
||||||
|
:param data_symbol: symbol as used matplotlib, by default this is a black cross ('kx')
|
||||||
|
:type data_symbol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) alongside marker type, as is standard in matplotlib.
|
||||||
|
"""
|
||||||
|
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
||||||
|
from ..plotting.matplot_dep import models_plots
|
||||||
|
kw = {}
|
||||||
|
return models_plots.plot_data(self, which_data_rows,
|
||||||
|
which_data_ycols, visible_dims,
|
||||||
|
fignum, ax, data_symbol, **kw)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_fit_errorbars(self, which_data_rows='all',
|
||||||
|
which_data_ycols='all', fixed_inputs=[], fignum=None, ax=None,
|
||||||
|
linecol=None, data_symbol='kx', predict_kw=None, plot_training_data=True):
|
||||||
|
|
||||||
|
"""
|
||||||
|
Plot the posterior error bars corresponding to the training data
|
||||||
|
- For higher dimensions than two, use fixed_inputs to plot the data points with some of the inputs fixed.
|
||||||
|
|
||||||
|
Can plot only part of the data
|
||||||
|
using which_data_rows and which_data_ycols.
|
||||||
|
|
||||||
|
:param which_data_rows: which of the training data to plot (default all)
|
||||||
|
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
|
||||||
|
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
|
||||||
|
:type which_data_rows: 'all' or a list of integers
|
||||||
|
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
|
||||||
|
:type fixed_inputs: a list of tuples
|
||||||
|
:param fignum: figure to plot on.
|
||||||
|
:type fignum: figure number
|
||||||
|
:param ax: axes to plot on.
|
||||||
|
:type ax: axes handle
|
||||||
|
:param plot_training_data: whether or not to plot the training points
|
||||||
|
:type plot_training_data: boolean
|
||||||
|
"""
|
||||||
|
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
||||||
|
from ..plotting.matplot_dep import models_plots
|
||||||
|
kw = {}
|
||||||
|
return models_plots.plot_fit_errorbars(self, which_data_rows, which_data_ycols, fixed_inputs,
|
||||||
|
fignum, ax, linecol, data_symbol,
|
||||||
|
predict_kw, plot_training_data, **kw)
|
||||||
|
|
||||||
|
|
||||||
def plot_magnification(self, labels=None, which_indices=None,
|
def plot_magnification(self, labels=None, which_indices=None,
|
||||||
resolution=50, ax=None, marker='o', s=40,
|
resolution=50, ax=None, marker='o', s=40,
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,27 @@ def gpplot(x, mu, lower, upper, edgecol='#3300FF', fillcol='#33CCFF', ax=None, f
|
||||||
return plots
|
return plots
|
||||||
|
|
||||||
|
|
||||||
|
def gperrors(x, mu, lower, upper, edgecol=None, ax=None, fignum=None, **kwargs):
|
||||||
|
_, axes = ax_default(fignum, ax)
|
||||||
|
|
||||||
|
mu = mu.flatten()
|
||||||
|
x = x.flatten()
|
||||||
|
lower = lower.flatten()
|
||||||
|
upper = upper.flatten()
|
||||||
|
|
||||||
|
plots = []
|
||||||
|
|
||||||
|
if edgecol is None:
|
||||||
|
edgecol='#3300FF'
|
||||||
|
|
||||||
|
if not 'alpha' in kwargs.keys():
|
||||||
|
kwargs['alpha'] = 0.3
|
||||||
|
|
||||||
|
plots.append(axes.errorbar(x,mu,yerr=np.vstack([mu-lower,upper-mu]),color=edgecol,**kwargs))
|
||||||
|
plots[-1][0].remove()
|
||||||
|
return plots
|
||||||
|
|
||||||
|
|
||||||
def removeRightTicks(ax=None):
|
def removeRightTicks(ax=None):
|
||||||
ax = ax or pb.gca()
|
ax = ax or pb.gca()
|
||||||
for i, line in enumerate(ax.get_yticklines()):
|
for i, line in enumerate(ax.get_yticklines()):
|
||||||
|
|
|
||||||
|
|
@ -3,19 +3,79 @@
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from . import Tango
|
from . import Tango
|
||||||
from base_plots import gpplot, x_frame1D, x_frame2D
|
from base_plots import gpplot, x_frame1D, x_frame2D,gperrors
|
||||||
from ...models.gp_coregionalized_regression import GPCoregionalizedRegression
|
from ...models.gp_coregionalized_regression import GPCoregionalizedRegression
|
||||||
from ...models.sparse_gp_coregionalized_regression import SparseGPCoregionalizedRegression
|
from ...models.sparse_gp_coregionalized_regression import SparseGPCoregionalizedRegression
|
||||||
from scipy import sparse
|
from scipy import sparse
|
||||||
from ...core.parameterization.variational import VariationalPosterior
|
from ...core.parameterization.variational import VariationalPosterior
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
def plot_data(model, which_data_rows='all',
|
||||||
|
which_data_ycols='all', visible_dims=None,
|
||||||
|
fignum=None, ax=None, data_symbol='kx',mew=1.5):
|
||||||
|
"""
|
||||||
|
Plot the training data
|
||||||
|
- For higher dimensions than two, use fixed_inputs to plot the data points with some of the inputs fixed.
|
||||||
|
|
||||||
|
Can plot only part of the data
|
||||||
|
using which_data_rows and which_data_ycols.
|
||||||
|
|
||||||
|
:param which_data_rows: which of the training data to plot (default all)
|
||||||
|
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
|
||||||
|
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
|
||||||
|
:type which_data_rows: 'all' or a list of integers
|
||||||
|
:param visible_dims: an array specifying the input dimensions to plot (maximum two)
|
||||||
|
:type visible_dims: a numpy array
|
||||||
|
:param fignum: figure to plot on.
|
||||||
|
:type fignum: figure number
|
||||||
|
:param ax: axes to plot on.
|
||||||
|
:type ax: axes handle
|
||||||
|
"""
|
||||||
|
#deal with optional arguments
|
||||||
|
if which_data_rows == 'all':
|
||||||
|
which_data_rows = slice(None)
|
||||||
|
if which_data_ycols == 'all':
|
||||||
|
which_data_ycols = np.arange(model.output_dim)
|
||||||
|
|
||||||
|
if ax is None:
|
||||||
|
fig = plt.figure(num=fignum)
|
||||||
|
ax = fig.add_subplot(111)
|
||||||
|
|
||||||
|
#data
|
||||||
|
X = model.X
|
||||||
|
Y = model.Y
|
||||||
|
|
||||||
|
#work out what the inputs are for plotting (1D or 2D)
|
||||||
|
if visible_dims is None:
|
||||||
|
visible_dims = np.arange(model.input_dim)
|
||||||
|
assert visible_dims.size <= 2, "Visible inputs cannot be larger than two"
|
||||||
|
free_dims = visible_dims
|
||||||
|
plots = {}
|
||||||
|
#one dimensional plotting
|
||||||
|
if len(free_dims) == 1:
|
||||||
|
|
||||||
|
for d in which_data_ycols:
|
||||||
|
plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=mew)
|
||||||
|
|
||||||
|
#2D plotting
|
||||||
|
elif len(free_dims) == 2:
|
||||||
|
|
||||||
|
for d in which_data_ycols:
|
||||||
|
plots['dataplot'] = ax.scatter(X[which_data_rows, free_dims[0]], X[which_data_rows, free_dims[1]], 40,
|
||||||
|
Y[which_data_rows, d], cmap=plt.cm.jet, vmin=Y.min(), vmax=Y.max(), linewidth=0.)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Cannot define a frame with more than two input dimensions")
|
||||||
|
return plots
|
||||||
|
|
||||||
|
|
||||||
def plot_fit(model, plot_limits=None, which_data_rows='all',
|
def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
which_data_ycols='all', fixed_inputs=[],
|
which_data_ycols='all', fixed_inputs=[],
|
||||||
levels=20, samples=0, fignum=None, ax=None, resolution=None,
|
levels=20, samples=0, fignum=None, ax=None, resolution=None,
|
||||||
plot_raw=False,
|
plot_raw=False,
|
||||||
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'], Y_metadata=None, data_symbol='kx',
|
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'], Y_metadata=None, data_symbol='kx',
|
||||||
apply_link=False, samples_f=0, plot_uncertain_inputs=True, predict_kw=None):
|
apply_link=False, samples_f=0, plot_uncertain_inputs=True, predict_kw=None, plot_training_data=True):
|
||||||
"""
|
"""
|
||||||
Plot the posterior of the GP.
|
Plot the posterior of the GP.
|
||||||
- In one dimension, the function is plotted with a shaded region identifying two standard deviations.
|
- In one dimension, the function is plotted with a shaded region identifying two standard deviations.
|
||||||
|
|
@ -43,7 +103,6 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
:type fignum: figure number
|
:type fignum: figure number
|
||||||
:param ax: axes to plot on.
|
:param ax: axes to plot on.
|
||||||
:type ax: axes handle
|
:type ax: axes handle
|
||||||
:type output: integer (first output is 0)
|
|
||||||
:param linecol: color of line to plot.
|
:param linecol: color of line to plot.
|
||||||
:type linecol:
|
:type linecol:
|
||||||
:param fillcol: color of fill
|
:param fillcol: color of fill
|
||||||
|
|
@ -52,6 +111,8 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
:type apply_link: boolean
|
:type apply_link: boolean
|
||||||
:param samples_f: the number of posteriori f samples to plot p(f*|y)
|
:param samples_f: the number of posteriori f samples to plot p(f*|y)
|
||||||
:type samples_f: int
|
:type samples_f: int
|
||||||
|
:param plot_training_data: whether or not to plot the training points
|
||||||
|
:type plot_training_data: boolean
|
||||||
"""
|
"""
|
||||||
#deal with optional arguments
|
#deal with optional arguments
|
||||||
if which_data_rows == 'all':
|
if which_data_rows == 'all':
|
||||||
|
|
@ -116,7 +177,11 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
|
|
||||||
for d in which_data_ycols:
|
for d in which_data_ycols:
|
||||||
plots['gpplot'] = gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], ax=ax, edgecol=linecol, fillcol=fillcol)
|
plots['gpplot'] = gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], ax=ax, edgecol=linecol, fillcol=fillcol)
|
||||||
if not plot_raw: plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=1.5)
|
#if not plot_raw: plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=1.5)
|
||||||
|
if not plot_raw and plot_training_data:
|
||||||
|
plots['dataplot'] = plot_data(model=model, which_data_rows=which_data_rows,
|
||||||
|
visible_dims=free_dims, data_symbol=data_symbol, mew=1.5, ax=ax, fignum=fignum)
|
||||||
|
|
||||||
|
|
||||||
#optionally plot some samples
|
#optionally plot some samples
|
||||||
if samples: #NOTE not tested with fixed_inputs
|
if samples: #NOTE not tested with fixed_inputs
|
||||||
|
|
@ -196,7 +261,9 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
for d in which_data_ycols:
|
for d in which_data_ycols:
|
||||||
m_d = m[:,d].reshape(resolution, resolution).T
|
m_d = m[:,d].reshape(resolution, resolution).T
|
||||||
plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=plt.cm.jet)
|
plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=plt.cm.jet)
|
||||||
if not plot_raw: plots['dataplot'] = ax.scatter(X[which_data_rows, free_dims[0]], X[which_data_rows, free_dims[1]], 40, Y[which_data_rows, d], cmap=plt.cm.jet, vmin=m.min(), vmax=m.max(), linewidth=0.)
|
#if not plot_raw: plots['dataplot'] = ax.scatter(X[which_data_rows, free_dims[0]], X[which_data_rows, free_dims[1]], 40, Y[which_data_rows, d], cmap=plt.cm.jet, vmin=m.min(), vmax=m.max(), linewidth=0.)
|
||||||
|
if not plot_raw and plot_training_data:
|
||||||
|
plots['dataplot'] = ax.scatter(X[which_data_rows, free_dims[0]], X[which_data_rows, free_dims[1]], 40, Y[which_data_rows, d], cmap=plt.cm.jet, vmin=m.min(), vmax=m.max(), linewidth=0.)
|
||||||
|
|
||||||
#set the limits of the plot to some sensible values
|
#set the limits of the plot to some sensible values
|
||||||
ax.set_xlim(xmin[0], xmax[0])
|
ax.set_xlim(xmin[0], xmax[0])
|
||||||
|
|
@ -261,3 +328,83 @@ def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True, X_
|
||||||
return f_inputs
|
return f_inputs
|
||||||
else:
|
else:
|
||||||
return X
|
return X
|
||||||
|
|
||||||
|
|
||||||
|
def plot_fit_errorbars(model, which_data_rows='all',
|
||||||
|
which_data_ycols='all', fixed_inputs=[],
|
||||||
|
fignum=None, ax=None,
|
||||||
|
linecol='red', data_symbol='kx',
|
||||||
|
predict_kw=None, plot_training_data=True):
|
||||||
|
|
||||||
|
"""
|
||||||
|
Plot the posterior error bars corresponding to the training data
|
||||||
|
- For higher dimensions than two, use fixed_inputs to plot the data points with some of the inputs fixed.
|
||||||
|
|
||||||
|
Can plot only part of the data
|
||||||
|
using which_data_rows and which_data_ycols.
|
||||||
|
|
||||||
|
:param which_data_rows: which of the training data to plot (default all)
|
||||||
|
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
|
||||||
|
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
|
||||||
|
:type which_data_rows: 'all' or a list of integers
|
||||||
|
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
|
||||||
|
:type fixed_inputs: a list of tuples
|
||||||
|
:param fignum: figure to plot on.
|
||||||
|
:type fignum: figure number
|
||||||
|
:param ax: axes to plot on.
|
||||||
|
:type ax: axes handle
|
||||||
|
:param plot_training_data: whether or not to plot the training points
|
||||||
|
:type plot_training_data: boolean
|
||||||
|
"""
|
||||||
|
|
||||||
|
#deal with optional arguments
|
||||||
|
if which_data_rows == 'all':
|
||||||
|
which_data_rows = slice(None)
|
||||||
|
if which_data_ycols == 'all':
|
||||||
|
which_data_ycols = np.arange(model.output_dim)
|
||||||
|
|
||||||
|
if ax is None:
|
||||||
|
fig = plt.figure(num=fignum)
|
||||||
|
ax = fig.add_subplot(111)
|
||||||
|
|
||||||
|
X = model.X
|
||||||
|
Y = model.Y
|
||||||
|
|
||||||
|
if predict_kw is None:
|
||||||
|
predict_kw = {}
|
||||||
|
|
||||||
|
|
||||||
|
#work out what the inputs are for plotting (1D or 2D)
|
||||||
|
fixed_dims = np.array([i for i,v in fixed_inputs])
|
||||||
|
free_dims = np.setdiff1d(np.arange(model.input_dim),fixed_dims)
|
||||||
|
plots = {}
|
||||||
|
|
||||||
|
#one dimensional plotting
|
||||||
|
if len(free_dims) == 1:
|
||||||
|
|
||||||
|
m, v = model.predict(X, full_cov=False, Y_metadata=model.Y_metadata, **predict_kw)
|
||||||
|
fmu, fv = model._raw_predict(X, full_cov=False, **predict_kw)
|
||||||
|
lower, upper = model.likelihood.predictive_quantiles(fmu, fv, (2.5, 97.5), Y_metadata=model.Y_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
for d in which_data_ycols:
|
||||||
|
plots['gperrors'] = gperrors(X, m[:, d], lower[:, d], upper[:, d], edgecol=linecol, ax=ax, fignum=fignum )
|
||||||
|
if plot_training_data:
|
||||||
|
plots['dataplot'] = plot_data(model=model, which_data_rows=which_data_rows,
|
||||||
|
visible_dims=free_dims, data_symbol=data_symbol, mew=1.5, ax=ax, fignum=fignum)
|
||||||
|
|
||||||
|
|
||||||
|
#set the limits of the plot to some sensible values
|
||||||
|
ymin, ymax = min(np.append(Y[which_data_rows, which_data_ycols].flatten(), lower)), max(np.append(Y[which_data_rows, which_data_ycols].flatten(), upper))
|
||||||
|
ymin, ymax = ymin - 0.1 * (ymax - ymin), ymax + 0.1 * (ymax - ymin)
|
||||||
|
ax.set_xlim(X[:,free_dims].min(), X[:,free_dims].max())
|
||||||
|
ax.set_ylim(ymin, ymax)
|
||||||
|
|
||||||
|
|
||||||
|
elif len(free_dims) == 2:
|
||||||
|
raise NotImplementedError("Not implemented yet")
|
||||||
|
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Cannot define a frame with more than two input dimensions")
|
||||||
|
return plots
|
||||||
|
|
|
||||||
|
|
@ -7,48 +7,13 @@
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import linalg
|
from scipy import linalg
|
||||||
import types
|
from scipy.linalg import lapack, blas
|
||||||
import ctypes
|
|
||||||
from ctypes import byref, c_char, c_int, c_double # TODO
|
|
||||||
import scipy
|
|
||||||
import warnings
|
|
||||||
import os
|
|
||||||
from .config import config
|
from .config import config
|
||||||
import logging
|
import logging
|
||||||
from . import linalg_cython
|
from . import linalg_cython
|
||||||
|
|
||||||
|
|
||||||
_scipyversion = np.float64((scipy.__version__).split('.')[:2])
|
|
||||||
_fix_dpotri_scipy_bug = True
|
|
||||||
if np.all(_scipyversion >= np.array([0, 14])):
|
|
||||||
from scipy.linalg import lapack
|
|
||||||
_fix_dpotri_scipy_bug = False
|
|
||||||
elif np.all(_scipyversion >= np.array([0, 12])):
|
|
||||||
#import scipy.linalg.lapack.clapack as lapack
|
|
||||||
from scipy.linalg import lapack
|
|
||||||
else:
|
|
||||||
from scipy.linalg.lapack import flapack as lapack
|
|
||||||
|
|
||||||
if config.getboolean('anaconda', 'installed') and config.getboolean('anaconda', 'MKL'):
|
|
||||||
try:
|
|
||||||
anaconda_path = str(config.get('anaconda', 'location'))
|
|
||||||
mkl_rt = ctypes.cdll.LoadLibrary(os.path.join(anaconda_path, 'DLLs', 'mkl_rt.dll'))
|
|
||||||
dsyrk = mkl_rt.dsyrk
|
|
||||||
dsyr = mkl_rt.dsyr
|
|
||||||
_blas_available = True
|
|
||||||
print('anaconda installed and mkl is loaded')
|
|
||||||
except:
|
|
||||||
_blas_available = False
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
_blaslib = ctypes.cdll.LoadLibrary(np.core._dotblas.__file__) # @UndefinedVariable
|
|
||||||
dsyrk = _blaslib.dsyrk_
|
|
||||||
dsyr = _blaslib.dsyr_
|
|
||||||
_blas_available = True
|
|
||||||
except AttributeError as e:
|
|
||||||
_blas_available = False
|
|
||||||
warnings.warn("warning: caught this exception:" + str(e))
|
|
||||||
|
|
||||||
def force_F_ordered_symmetric(A):
|
def force_F_ordered_symmetric(A):
|
||||||
"""
|
"""
|
||||||
return a F ordered version of A, assuming A is symmetric
|
return a F ordered version of A, assuming A is symmetric
|
||||||
|
|
@ -169,9 +134,6 @@ def dpotri(A, lower=1):
|
||||||
:returns: A inverse
|
:returns: A inverse
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if _fix_dpotri_scipy_bug:
|
|
||||||
assert lower==1, "scipy linalg behaviour is very weird. please use lower, fortran ordered arrays"
|
|
||||||
lower = 0
|
|
||||||
|
|
||||||
A = force_F_ordered(A)
|
A = force_F_ordered(A)
|
||||||
R, info = lapack.dpotri(A, lower=lower) #needs to be zero here, seems to be a scipy bug
|
R, info = lapack.dpotri(A, lower=lower) #needs to be zero here, seems to be a scipy bug
|
||||||
|
|
@ -300,8 +262,8 @@ def pca(Y, input_dim):
|
||||||
Z = linalg.svd(Y - Y.mean(axis=0), full_matrices=False)
|
Z = linalg.svd(Y - Y.mean(axis=0), full_matrices=False)
|
||||||
[X, W] = [Z[0][:, 0:input_dim], np.dot(np.diag(Z[1]), Z[2]).T[:, 0:input_dim]]
|
[X, W] = [Z[0][:, 0:input_dim], np.dot(np.diag(Z[1]), Z[2]).T[:, 0:input_dim]]
|
||||||
v = X.std(axis=0)
|
v = X.std(axis=0)
|
||||||
X /= v;
|
X /= v
|
||||||
W *= v;
|
W *= v
|
||||||
return X, W.T
|
return X, W.T
|
||||||
|
|
||||||
def ppca(Y, Q, iterations=100):
|
def ppca(Y, Q, iterations=100):
|
||||||
|
|
@ -347,34 +309,15 @@ def tdot_blas(mat, out=None):
|
||||||
out[:] = 0.0
|
out[:] = 0.0
|
||||||
|
|
||||||
# # Call to DSYRK from BLAS
|
# # Call to DSYRK from BLAS
|
||||||
# If already in Fortran order (rare), and has the right sorts of strides I
|
|
||||||
# could avoid the copy. I also thought swapping to cblas API would allow use
|
|
||||||
# of C order. However, I tried that and had errors with large matrices:
|
|
||||||
# http://homepages.inf.ed.ac.uk/imurray2/code/tdot/tdot_broken.py
|
|
||||||
mat = np.asfortranarray(mat)
|
mat = np.asfortranarray(mat)
|
||||||
TRANS = c_char('n'.encode('ascii'))
|
out = blas.dsyrk(alpha=1.0, a=mat, beta=0.0, c=out, overwrite_c=1,
|
||||||
N = c_int(mat.shape[0])
|
trans=0, lower=0)
|
||||||
K = c_int(mat.shape[1])
|
|
||||||
LDA = c_int(mat.shape[0])
|
|
||||||
UPLO = c_char('l'.encode('ascii'))
|
|
||||||
ALPHA = c_double(1.0)
|
|
||||||
A = mat.ctypes.data_as(ctypes.c_void_p)
|
|
||||||
BETA = c_double(0.0)
|
|
||||||
C = out.ctypes.data_as(ctypes.c_void_p)
|
|
||||||
LDC = c_int(np.max(out.strides) // 8)
|
|
||||||
dsyrk(byref(UPLO), byref(TRANS), byref(N), byref(K),
|
|
||||||
byref(ALPHA), A, byref(LDA), byref(BETA), C, byref(LDC))
|
|
||||||
|
|
||||||
symmetrify(out, upper=True)
|
symmetrify(out, upper=True)
|
||||||
|
|
||||||
|
|
||||||
return np.ascontiguousarray(out)
|
return np.ascontiguousarray(out)
|
||||||
|
|
||||||
def tdot(*args, **kwargs):
|
def tdot(*args, **kwargs):
|
||||||
if _blas_available:
|
return tdot_blas(*args, **kwargs)
|
||||||
return tdot_blas(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
return tdot_numpy(*args, **kwargs)
|
|
||||||
|
|
||||||
def DSYR_blas(A, x, alpha=1.):
|
def DSYR_blas(A, x, alpha=1.):
|
||||||
"""
|
"""
|
||||||
|
|
@ -386,15 +329,7 @@ def DSYR_blas(A, x, alpha=1.):
|
||||||
:param alpha: scalar
|
:param alpha: scalar
|
||||||
|
|
||||||
"""
|
"""
|
||||||
N = c_int(A.shape[0])
|
A = blas.dsyr(lower=0, x=x, a=A, alpha=alpha, overwrite_a=True)
|
||||||
LDA = c_int(A.shape[0])
|
|
||||||
UPLO = c_char('l'.encode('ascii'))
|
|
||||||
ALPHA = c_double(alpha)
|
|
||||||
A_ = A.ctypes.data_as(ctypes.c_void_p)
|
|
||||||
x_ = x.ctypes.data_as(ctypes.c_void_p)
|
|
||||||
INCX = c_int(1)
|
|
||||||
dsyr(byref(UPLO), byref(N), byref(ALPHA),
|
|
||||||
x_, byref(INCX), A_, byref(LDA))
|
|
||||||
symmetrify(A, upper=True)
|
symmetrify(A, upper=True)
|
||||||
|
|
||||||
def DSYR_numpy(A, x, alpha=1.):
|
def DSYR_numpy(A, x, alpha=1.):
|
||||||
|
|
@ -411,10 +346,8 @@ def DSYR_numpy(A, x, alpha=1.):
|
||||||
|
|
||||||
|
|
||||||
def DSYR(*args, **kwargs):
|
def DSYR(*args, **kwargs):
|
||||||
if _blas_available:
|
return DSYR_blas(*args, **kwargs)
|
||||||
return DSYR_blas(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
return DSYR_numpy(*args, **kwargs)
|
|
||||||
|
|
||||||
def symmetrify(A, upper=False):
|
def symmetrify(A, upper=False):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue