mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-11 15:15:15 +02:00
derivatives working in DTC
This commit is contained in:
parent
32deb2085d
commit
34dec0ade3
1 changed files with 7 additions and 14 deletions
|
|
@ -40,9 +40,9 @@ class DTC(object):
|
|||
Kmmi, L, Li, _ = pdinv(Kmm)
|
||||
|
||||
# Compute A
|
||||
LiUT, _ = dtrtrs(L, U.T*np.sqrt(beta), lower=1)
|
||||
A_I = tdot(LiUT)
|
||||
A = A_I + np.eye(num_inducing)
|
||||
#LiUT, _ = dtrtrs(L, U.T*np.sqrt(beta), lower=1)
|
||||
LiUT = np.dot(Li, U.T)*np.sqrt(beta)
|
||||
A = tdot(LiUT) + np.eye(num_inducing)
|
||||
|
||||
# factor A
|
||||
LA = jitchol(A)
|
||||
|
|
@ -63,24 +63,17 @@ class DTC(object):
|
|||
0.5*np.sum(np.square(b))
|
||||
|
||||
# Compute dL_dKmm
|
||||
tmp, _ = dtrtrs(L, A_I, lower=1, trans=1)
|
||||
dL_dK, _ = dtrtrs(L, tmp.T, lower=1, trans=0)
|
||||
tmp, _ = dtrtrs(LA, tmp.T, lower=1, trans=1)
|
||||
dL_dK -= tdot(tmp.T)
|
||||
dL_dK *= output_dim
|
||||
dL_dK -= tdot(v)
|
||||
dL_dK /=2.
|
||||
vvT_P = tdot(v.reshape(-1,1)) + P
|
||||
dL_dK = 0.5*(Kmmi - vvT_P)
|
||||
|
||||
# Compute dL_dU
|
||||
vvT_P = tdot(v.reshape(-1,1)) + P
|
||||
vY = np.dot(v.reshape(-1,1),Y.T)
|
||||
dL_dU = vY + np.dot(vvT_P, U.T)
|
||||
dL_dU = vY - np.dot(vvT_P, U.T)
|
||||
dL_dU *= beta
|
||||
|
||||
#compute dL_dR
|
||||
Uv = np.dot(U, v)
|
||||
dL_dR = 0.5*(np.sum(U*np.dot(U,P), 1) - beta * np.sum(np.square(Y), 1) - 2.*np.sum(Uv*Y, 1) + np.sum(np.square(Uv), 1)
|
||||
)*beta**2
|
||||
dL_dR = 0.5*(np.sum(U*np.dot(U,P), 1) - 1./beta + np.sum(np.square(Y), 1) - 2.*np.sum(Uv*Y, 1) + np.sum(np.square(Uv), 1))*beta**2
|
||||
|
||||
grad_dict = {'dL_dKmm': dL_dK, 'dL_dKdiag':np.zeros_like(Knn), 'dL_dKnm':dL_dU.T}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue