From 9867330861c6b515636c9c732771ed34c3ade677 Mon Sep 17 00:00:00 2001 From: James Hensman Date: Wed, 26 Feb 2014 14:30:28 +0000 Subject: [PATCH] moved plot functionality from add to kern --- GPy/kern/_src/add.py | 10 +--------- GPy/kern/_src/coregionalize.py | 11 ----------- GPy/kern/_src/kern.py | 8 ++++++++ GPy/kern/_src/rbf.py | 12 ++++++++---- 4 files changed, 17 insertions(+), 24 deletions(-) diff --git a/GPy/kern/_src/add.py b/GPy/kern/_src/add.py index 1466800c..77fe057d 100644 --- a/GPy/kern/_src/add.py +++ b/GPy/kern/_src/add.py @@ -184,20 +184,12 @@ class Add(Kern): target_S += b return target_mu, target_S - def plot(self, *args, **kwargs): - """ - See GPy.plotting.matplot_dep.plot - """ - assert "matplotlib" in sys.modules, "matplotlib package has not been imported." - from ..plotting.matplot_dep import kernel_plots - kernel_plots.plot(self,*args) - def input_sensitivity(self): in_sen = np.zeros((self.num_params, self.input_dim)) for i, [p, i_s] in enumerate(zip(self._parameters_, self.input_slices)): in_sen[i, i_s] = p.input_sensitivity() return in_sen - + def _getstate(self): """ Get the current state of the class, diff --git a/GPy/kern/_src/coregionalize.py b/GPy/kern/_src/coregionalize.py index cafdd5ee..6679eba4 100644 --- a/GPy/kern/_src/coregionalize.py +++ b/GPy/kern/_src/coregionalize.py @@ -60,17 +60,6 @@ class Coregionalize(Kern): def K(self, X, X2=None): index = np.asarray(X, dtype=np.int) - #here's the old code (numpy) - #if index2 is None: - #index2 = index - #else: - #index2 = np.asarray(index2, dtype=np.int) - #false_target = target.copy() - #ii, jj = np.meshgrid(index, index2) - #ii, jj = ii.T, jj.T - #false_target += self.B[ii, jj] - - if X2 is None: target = np.empty((X.shape[0], X.shape[0]), dtype=np.float64) code=""" diff --git a/GPy/kern/_src/kern.py b/GPy/kern/_src/kern.py index d7d8f9ca..eb3291e0 100644 --- a/GPy/kern/_src/kern.py +++ b/GPy/kern/_src/kern.py @@ -68,6 +68,14 @@ class Kern(Parameterized): """ raise NotImplementedError + def plot(self, *args, **kwargs): + """ + See GPy.plotting.matplot_dep.plot + """ + assert "matplotlib" in sys.modules, "matplotlib package has not been imported." + from ..plotting.matplot_dep import kernel_plots + kernel_plots.plot(self,*args) + def plot_ARD(self, *args, **kw): """ See :class:`~GPy.plotting.matplot_dep.kernel_plots` diff --git a/GPy/kern/_src/rbf.py b/GPy/kern/_src/rbf.py index 666a79f4..f5bafb48 100644 --- a/GPy/kern/_src/rbf.py +++ b/GPy/kern/_src/rbf.py @@ -178,6 +178,11 @@ class RBF(Stationary): return denom, dist, dist_sq, psi1 + #@cache_this(ignore_args=(1,)) + def _Z_distances(self, Z): + Zhat = 0.5 * (Z[:, None, :] + Z[None, :, :]) # M,M,Q + Zdist = 0.5 * (Z[:, None, :] - Z[None, :, :]) # M,M,Q + return Zhat, Zdist #@cache_this TODO def _psi2computations(self, Z, vp): @@ -187,8 +192,7 @@ class RBF(Stationary): M = Z.shape[0] #compute required distances - Zhat = 0.5 * (Z[:, None, :] + Z[None, :, :]) # M,M,Q - Zdist = 0.5 * (Z[:, None, :] - Z[None, :, :]) # M,M,Q + Zhat, Zdist = self._Z_distances(Z) Zdist_sq = np.square(Zdist / self.lengthscale) # M,M,Q #allocate memory for the things we want to compute @@ -201,7 +205,7 @@ class RBF(Stationary): denom = (2.*S[:,None,None,:] / l2) + 1. # N,Q half_log_denom = 0.5 * np.log(denom[:,0,0,:]) denom_l2 = denom[:,0,0,:]*l2 - + variance_sq = float(np.square(self.variance)) code = """ double tmp, exponent_tmp; @@ -237,7 +241,7 @@ class RBF(Stationary): } } """ - + support_code = """ #include #include