mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-04-30 23:36:23 +02:00
changes Nparts for num_parts in kern
This commit is contained in:
parent
958e9f7c7a
commit
39eb0368d8
2 changed files with 16 additions and 8 deletions
|
|
@ -31,7 +31,7 @@ class kern(Parameterized):
|
|||
|
||||
"""
|
||||
self.parts = parts
|
||||
self.Nparts = len(parts)
|
||||
self.num_parts = len(parts)
|
||||
self.num_params = sum([p.num_params for p in self.parts])
|
||||
|
||||
self.input_dim = input_dim
|
||||
|
|
@ -61,7 +61,7 @@ class kern(Parameterized):
|
|||
here just all the indices, rest can get recomputed
|
||||
"""
|
||||
return Parameterized.getstate(self) + [self.parts,
|
||||
self.Nparts,
|
||||
self.num_parts,
|
||||
self.num_params,
|
||||
self.input_dim,
|
||||
self.input_slices,
|
||||
|
|
@ -73,7 +73,7 @@ class kern(Parameterized):
|
|||
self.input_slices = state.pop()
|
||||
self.input_dim = state.pop()
|
||||
self.num_params = state.pop()
|
||||
self.Nparts = state.pop()
|
||||
self.num_parts = state.pop()
|
||||
self.parts = state.pop()
|
||||
Parameterized.setstate(self, state)
|
||||
|
||||
|
|
@ -308,7 +308,7 @@ class kern(Parameterized):
|
|||
|
||||
def K(self, X, X2=None, which_parts='all'):
|
||||
if which_parts == 'all':
|
||||
which_parts = [True] * self.Nparts
|
||||
which_parts = [True] * self.num_parts
|
||||
assert X.shape[1] == self.input_dim
|
||||
if X2 is None:
|
||||
target = np.zeros((X.shape[0], X.shape[0]))
|
||||
|
|
@ -359,7 +359,7 @@ class kern(Parameterized):
|
|||
def Kdiag(self, X, which_parts='all'):
|
||||
"""Compute the diagonal of the covariance function for inputs X."""
|
||||
if which_parts == 'all':
|
||||
which_parts = [True] * self.Nparts
|
||||
which_parts = [True] * self.num_parts
|
||||
assert X.shape[1] == self.input_dim
|
||||
target = np.zeros(X.shape[0])
|
||||
[p.Kdiag(X[:, i_s], target=target) for p, i_s, part_on in zip(self.parts, self.input_slices, which_parts) if part_on]
|
||||
|
|
@ -497,7 +497,7 @@ class kern(Parameterized):
|
|||
|
||||
def plot(self, x=None, plot_limits=None, which_parts='all', resolution=None, *args, **kwargs):
|
||||
if which_parts == 'all':
|
||||
which_parts = [True] * self.Nparts
|
||||
which_parts = [True] * self.num_parts
|
||||
if self.input_dim == 1:
|
||||
if x is None:
|
||||
x = np.zeros((1, 1))
|
||||
|
|
|
|||
|
|
@ -7,6 +7,13 @@ import GPy
|
|||
|
||||
verbose = False
|
||||
|
||||
try:
|
||||
import sympy
|
||||
SYMPY_AVAILABLE=True
|
||||
except ImportError:
|
||||
SYMPY_AVAILABLE=False
|
||||
|
||||
|
||||
class KernelTests(unittest.TestCase):
|
||||
def test_kerneltie(self):
|
||||
K = GPy.kern.rbf(5, ARD=True)
|
||||
|
|
@ -22,8 +29,9 @@ class KernelTests(unittest.TestCase):
|
|||
self.assertTrue(GPy.kern.kern_test(kern, verbose=verbose))
|
||||
|
||||
def test_rbf_sympykernel(self):
|
||||
kern = GPy.kern.rbf_sympy(5)
|
||||
self.assertTrue(GPy.kern.kern_test(kern, verbose=verbose))
|
||||
if SYMPY_AVAILABLE:
|
||||
kern = GPy.kern.rbf_sympy(5)
|
||||
self.assertTrue(GPy.kern.kern_test(kern, verbose=verbose))
|
||||
|
||||
def test_rbf_invkernel(self):
|
||||
kern = GPy.kern.rbf_inv(5)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue