working on cross terms

This commit is contained in:
Nicolo Fusi 2013-01-30 11:18:22 +00:00
parent 6959a905dc
commit ece4e2442c
2 changed files with 6 additions and 5 deletions

View file

@ -273,7 +273,7 @@ class kern(parameterised):
return target
def dpsi2_dtheta(self,partial,Z,mu,S,slices1=None,slices2=None):
def dpsi2_dtheta(self,partial,partial1,Z,mu,S,slices1=None,slices2=None):
"""Returns shape (N,M,M,Ntheta)"""
slices1, slices2 = self._process_slices(slices1,slices2)
target = np.zeros(self.Nparam)
@ -286,12 +286,13 @@ class kern(parameterised):
[p.psi1(Z[s2],mu[s1],S[s1],psi1_target[s1,s2]) for p,s1,s2,psi1_target in zip(self.parts,slices1,slices2, psi1_matrices)]
# 2. get all the dpsi1/dtheta gradients
psi1_gradients = [np.zeros(self.Nparam) for p in self.parts]
[p.dpsi1_dtheta(partial[s2,s1],Z[s2,i_s],mu[s1,i_s],S[s1,i_s],target[ps]) for p,ps,s1,s2,i_s in zip(self.parts, self.param_slices,slices1,slices2,self.input_slices)]
[p.dpsi1_dtheta(partial1[s2,s1],Z[s2,i_s],mu[s1,i_s],S[s1,i_s],psi1g_target[ps]) for p,ps,s1,s2,i_s,psi1g_target in zip(self.parts, self.param_slices,slices1,slices2,self.input_slices,psi1_gradients)]
# 3. multiply them somehow
for a,b in itertools.combinations(range(len(psi1_matrices)), 2):
psi2_cross = np.multiply(psi1, psi1_grad) # some newaxis of this
target += psi2_cross[:,None,:] + psi2_cross[:, :,None]
gne = (psi1_gradients[a][None]*psi1_matrices[b].sum(0)[:,None]).sum(0)
target += 0#(gne[None] + gne[:, None]).sum(0)
return target
def dpsi2_dZ(self,partial,Z,mu,S,slices1=None,slices2=None):