added slicing to kern.py

This commit is contained in:
Max Zwiessele 2013-09-03 10:04:33 +01:00
parent 917a4ebed3
commit a3d43553df

View file

@ -416,19 +416,19 @@ class kern(Parameterized):
# TODO: input_slices needed # TODO: input_slices needed
crossterms = 0 crossterms = 0
for p1, p2 in itertools.combinations(self.parts, 2): for [p1, i_s1], [p2, i_s2] in itertools.combinations(zip(self.parts, self.input_slices), 2):
if i_s1 == i_s2:
# TODO psi1 this must be faster/better/precached/more nice
tmp1 = np.zeros((mu.shape[0], Z.shape[0]))
p1.psi1(Z[:, i_s1], mu[:, i_s1], S[:, i_s1], tmp1)
tmp2 = np.zeros((mu.shape[0], Z.shape[0]))
p2.psi1(Z[:, i_s2], mu[:, i_s2], S[:, i_s2], tmp2)
prod = np.multiply(tmp1, tmp2)
crossterms += prod[:, :, None] + prod[:, None, :]
# TODO psi1 this must be faster/better/precached/more nice # target += crossterms
tmp1 = np.zeros((mu.shape[0], Z.shape[0])) return target + crossterms
p1.psi1(Z, mu, S, tmp1)
tmp2 = np.zeros((mu.shape[0], Z.shape[0]))
p2.psi1(Z, mu, S, tmp2)
prod = np.multiply(tmp1, tmp2)
crossterms += prod[:, :, None] + prod[:, None, :]
target += crossterms
return target
def dpsi2_dtheta(self, dL_dpsi2, Z, mu, S): def dpsi2_dtheta(self, dL_dpsi2, Z, mu, S):
"""Gradient of the psi2 statistics with respect to the parameters.""" """Gradient of the psi2 statistics with respect to the parameters."""