[ard] enhanced ard handling and plotting

Conflicts:
	GPy/kern/_src/linear.py
	GPy/models/ss_gplvm.py
This commit is contained in:
mzwiessele 2014-08-25 09:46:20 -07:00
parent 3972b4bd9a
commit d000893878
8 changed files with 323 additions and 118 deletions

View file

@ -14,6 +14,13 @@ class Add(CombinationKernel):
This kernel will take over the active dims of it's subkernels passed in.
"""
def __init__(self, subkerns, name='add'):
for i, kern in enumerate(subkerns[:]):
if isinstance(kern, Add):
del subkerns[i]
for part in kern.parts[::-1]:
kern.remove_parameter(part)
subkerns.insert(i, part)
super(Add, self).__init__(subkerns, name)
@Cache_this(limit=2, force_kwargs=['which_parts'])
@ -160,7 +167,7 @@ class Add(CombinationKernel):
[np.add(target_grads[i],grads[i],target_grads[i]) for i in xrange(len(grads))]
return target_grads
def add(self, other, name='sum'):
def add(self, other):
if isinstance(other, Add):
other_params = other.parameters[:]
for p in other_params:
@ -171,5 +178,11 @@ class Add(CombinationKernel):
self.input_dim, self.active_dims = self.get_input_dim_active_dims(self.parts)
return self
def input_sensitivity(self):
return reduce(np.add, [k.input_sensitivity() for k in self.parts])
def input_sensitivity(self, summarize=True):
if summarize:
return reduce(np.add, [k.input_sensitivity(summarize) for k in self.parts])
else:
i_s = np.zeros((len(self.parts), self.input_dim))
from operator import setitem
[setitem(i_s, (i, Ellipsis), k.input_sensitivity(summarize)) for i, k in enumerate(self.parts)]
return i_s