Merge branch 'devel' of github.com:SheffieldML/GPy into devel

This commit is contained in:
James Hensman 2013-07-08 13:06:11 +01:00
commit 9c85911b87
37 changed files with 767 additions and 730 deletions

View file

@ -3,12 +3,12 @@
import numpy as np
import pylab as pb
from ..core.parameterised import Parameterised
from ..core.parameterized import Parameterized
from parts.kernpart import Kernpart
import itertools
from parts.prod import Prod as prod
class kern(Parameterised):
class kern(Parameterized):
def __init__(self, input_dim, parts=[], input_slices=None):
"""
This is the main kernel class for GPy. It handles multiple (additive) kernel functions, and keeps track of variaous things like which parameters live where.
@ -41,26 +41,51 @@ class kern(Parameterised):
self.compute_param_slices()
Parameterised.__init__(self)
Parameterized.__init__(self)
def getstate(self):
"""
Get the current state of the class,
here just all the indices, rest can get recomputed
"""
return Parameterized.getstate(self) + [self.parts,
self.Nparts,
self.num_params,
self.input_dim,
self.input_slices,
self.param_slices
]
def setstate(self, state):
self.param_slices = state.pop()
self.input_slices = state.pop()
self.input_dim = state.pop()
self.num_params = state.pop()
self.Nparts = state.pop()
self.parts = state.pop()
Parameterized.setstate(self, state)
def plot_ARD(self, fignum=None, ax=None):
def plot_ARD(self, fignum=None, ax=None, title=None):
"""If an ARD kernel is present, it bar-plots the ARD parameters"""
if ax is None:
fig = pb.figure(fignum)
ax = fig.add_subplot(111)
for p in self.parts:
if hasattr(p, 'ARD') and p.ARD:
ax.set_title('ARD parameters, %s kernel' % p.name)
if title is None:
ax.set_title('ARD parameters, %s kernel' % p.name)
else:
ax.set_title(title)
if p.name == 'linear':
ard_params = p.variances
else:
ard_params = 1. / p.lengthscale
ax.bar(np.arange(len(ard_params)) - 0.4, ard_params)
ax.set_xticks(np.arange(len(ard_params)))
ax.set_xticklabels([r"${}$".format(i) for i in range(len(ard_params))])
x = np.arange(len(ard_params))
ax.bar(x - 0.4, ard_params)
ax.set_xticks(x)
ax.set_xticklabels([r"${}$".format(i) for i in x])
return ax
def _transform_gradients(self, g):