mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-18 13:55:14 +02:00
Now checkgrads for gaussian, and ALMOST for student t
This commit is contained in:
parent
e36ffcba6e
commit
b663fff622
2 changed files with 119 additions and 71 deletions
|
|
@ -7,6 +7,7 @@ from likelihood import likelihood
|
|||
from ..util.linalg import pdinv, mdot, jitchol, chol_inv, pddet
|
||||
from scipy.linalg.lapack import dtrtrs
|
||||
import random
|
||||
from functools import partial
|
||||
#import pylab as plt
|
||||
|
||||
class Laplace(likelihood):
|
||||
|
|
@ -87,11 +88,15 @@ class Laplace(likelihood):
|
|||
|
||||
#Implicit
|
||||
impl = mdot(dlp, dL_dfhat, I_KW_i)
|
||||
expl_a = mdot(self.Ki_f, self.Ki_f.T)
|
||||
#expl_a = mdot(self.Ki_f, self.Ki_f.T)
|
||||
expl_a = np.dot(self.Ki_f, self.Ki_f.T)
|
||||
expl_b = self.Wi_K_i
|
||||
#print "expl_a: {}, expl_b: {}".format(expl_a, expl_b)
|
||||
expl = 0.5*expl_a + 0.5*expl_b # Might need to be -?
|
||||
dL_dthetaK_exp = dK_dthetaK(expl, X)
|
||||
#expl = 0.5*expl_a - 0.5*expl_b # Might need to be -?
|
||||
#dL_dthetaK_exp = dK_dthetaK(expl, X)
|
||||
dL_dthetaK_exp_a = dK_dthetaK(expl_a, X)
|
||||
dL_dthetaK_exp_b = dK_dthetaK(expl_b, X)
|
||||
dL_dthetaK_exp = 0.5*dL_dthetaK_exp_a - 0.5*dL_dthetaK_exp_b
|
||||
dL_dthetaK_imp = dK_dthetaK(impl, X)
|
||||
#print "dL_dthetaK_exp: {} dL_dthetaK_implicit: {}".format(dL_dthetaK_exp, dL_dthetaK_imp)
|
||||
#print "expl_a: {}, {} expl_b: {}, {}".format(np.mean(expl_a), np.std(expl_a), np.mean(expl_b), np.std(expl_b))
|
||||
|
|
@ -116,7 +121,13 @@ class Laplace(likelihood):
|
|||
#b = 0.5*np.dot(np.diag(e).T, d)
|
||||
#g = 0.5*(np.diag(self.K) - np.sum(cho_solve((self.B_chol, True), np.dot(np.diagflat(self.W_12),self.K))**2, 1))
|
||||
#dL_dthetaL_exp = np.sum(dlik_dthetaL[thetaL_i]) - np.dot(g.T, dlik_hess_dthetaL[thetaL_i])
|
||||
dL_dthetaL_exp = np.sum(dlik_dthetaL[thetaL_i]) - 0.5*np.dot(np.diag(self.Ki_W_i), dlik_hess_dthetaL[thetaL_i])
|
||||
|
||||
#dL_dthetaL_exp = np.sum(dlik_dthetaL[thetaL_i]) - 0.5*np.dot(np.diag(self.Ki_W_i), dlik_hess_dthetaL[thetaL_i])
|
||||
dL_dthetaL_exp = ( np.sum(dlik_dthetaL[thetaL_i])
|
||||
#- 0.5*np.trace(mdot(self.Ki_W_i, (self.K, np.diagflat(dlik_hess_dthetaL[thetaL_i]))))
|
||||
+ np.dot(0.5*np.diag(self.Ki_W_i)[:,None].T, dlik_hess_dthetaL[thetaL_i])
|
||||
)
|
||||
import ipdb; ipdb.set_trace() # XXX BREAKPOINT
|
||||
|
||||
#Implicit
|
||||
df_hat_dthetaL = mdot(I_KW_i, self.K, dlik_grad_dthetaL[thetaL_i])
|
||||
|
|
@ -168,22 +179,31 @@ class Laplace(likelihood):
|
|||
Y_tilde = Wi*self.Ki_f + self.f_hat
|
||||
|
||||
self.Wi_K_i = self.W_12*self.Bi*self.W_12.T #same as rasms R
|
||||
#self.Wi_K_i, _, _, self.ln_det_Wi_K = pdinv(self.Sigma_tilde + self.K) # TODO: Check if Wi_K_i == R above and same with det below
|
||||
self.ln_det_Wi_K = pddet(self.Sigma_tilde + self.K)
|
||||
|
||||
#self.Wi_K_i[self.Wi_K_i< 1e-6] = 1e-6
|
||||
|
||||
self.ln_det_K_Wi__Bi = self.ln_I_KW_det + pddet(self.Sigma_tilde + self.K)
|
||||
#self.ln_det_K_Wi__Bi = self.ln_I_KW_det + pddet(self.Sigma_tilde + self.K)
|
||||
self.lik = self.likelihood_function.link_function(self.data, self.f_hat, extra_data=self.extra_data)
|
||||
|
||||
self.y_Wi_Ki_i_y = mdot(Y_tilde.T, self.Wi_K_i, Y_tilde)
|
||||
self.aA = 0.5*self.ln_det_K_Wi__Bi
|
||||
self.bB = - 0.5*self.f_Ki_f
|
||||
self.cC = 0.5*self.y_Wi_Ki_i_y
|
||||
#self.aA = 0.5*self.ln_det_K_Wi__Bi
|
||||
#self.bB = - 0.5*self.f_Ki_f
|
||||
#self.cC = 0.5*self.y_Wi_Ki_i_y
|
||||
Z_tilde = (+ self.lik
|
||||
+ 0.5*self.ln_det_K_Wi__Bi
|
||||
#+ 0.5*self.ln_det_K_Wi__Bi
|
||||
- 0.5*self.ln_B_det
|
||||
+ 0.5*self.ln_det_Wi_K
|
||||
- 0.5*self.f_Ki_f
|
||||
+ 0.5*self.y_Wi_Ki_i_y
|
||||
)
|
||||
print "Ztilde: {} lik: {} a: {} b: {} c: {}".format(Z_tilde, self.lik, self.aA, self.bB, self.cC)
|
||||
print self.likelihood_function._get_params()
|
||||
#self.aA = 0.5*self.ln_det_Wi_K
|
||||
#self.bB = - 0.5*self.f_Ki_f
|
||||
#self.cC = 0.5*self.y_Wi_Ki_i_y
|
||||
#self.dD = -0.5*self.ln_B_det
|
||||
#print "Ztilde: {} lik: {} a: {} b: {} c: {} d:".format(Z_tilde, self.lik, self.aA, self.bB, self.cC, self.dD)
|
||||
print "param value: {}".format(self.likelihood_function._get_params())
|
||||
|
||||
#Convert to float as its (1, 1) and Z must be a scalar
|
||||
self.Z = np.float64(Z_tilde)
|
||||
|
|
@ -222,7 +242,7 @@ class Laplace(likelihood):
|
|||
|
||||
#TODO: Could save on computation when using rasm by returning these, means it isn't just a "mode finder" though
|
||||
self.B, self.B_chol, self.W_12 = self._compute_B_statistics(self.K, self.W)
|
||||
self.Bi, _, _, B_det = pdinv(self.B)
|
||||
self.Bi, _, _, self.ln_B_det = pdinv(self.B)
|
||||
|
||||
#Do the computation again at f to get Ki_f which is useful
|
||||
#b = self.W*self.f_hat + self.likelihood_function.dlik_df(self.data, self.f_hat, extra_data=self.extra_data)
|
||||
|
|
@ -234,7 +254,7 @@ class Laplace(likelihood):
|
|||
self.Ki_W_i = self.K - mdot(self.K, self.W_12*self.Bi*self.W_12.T, self.K)
|
||||
|
||||
#For det, |I + KW| == |I + W_12*K*W_12|
|
||||
self.ln_I_KW_det = pddet(np.eye(self.N) + self.W_12*self.K*self.W_12.T)
|
||||
#self.ln_I_KW_det = pddet(np.eye(self.N) + self.W_12*self.K*self.W_12.T)
|
||||
|
||||
#self.ln_I_KW_det = pddet(np.eye(self.N) + np.dot(self.K, self.W))
|
||||
#self.ln_z_hat = (- 0.5*self.f_Ki_f
|
||||
|
|
@ -299,7 +319,7 @@ class Laplace(likelihood):
|
|||
|
||||
def rasm_mode(self, K, MAX_ITER=100, MAX_RESTART=10):
|
||||
"""
|
||||
Rasmussens numerically stable mode finding
|
||||
Rasmussen's numerically stable mode finding
|
||||
For nomenclature see Rasmussen & Williams 2006
|
||||
|
||||
:K: Covariance matrix
|
||||
|
|
@ -308,7 +328,7 @@ class Laplace(likelihood):
|
|||
:returns: f_mode
|
||||
"""
|
||||
self.old_before_s = self.likelihood_function._get_params()
|
||||
print "before: ", self.old_before_s
|
||||
#print "before: ", self.old_before_s
|
||||
#if self.old_before_s < 1e-5:
|
||||
#import ipdb; ipdb.set_trace() ### XXX BREAKPOINT
|
||||
|
||||
|
|
@ -351,42 +371,42 @@ class Laplace(likelihood):
|
|||
full_step_a = b - W_12*solve_L
|
||||
da = full_step_a - old_a
|
||||
|
||||
#f_old = f.copy()
|
||||
#def inner_obj(step_size, old_a, da, K):
|
||||
#a = old_a + step_size*da
|
||||
#f = np.dot(K, a)
|
||||
#self.a = a.copy() # This is nasty, need to set something within an optimization though
|
||||
#self.f = f.copy()
|
||||
#return -obj(a, f)
|
||||
|
||||
#from functools import partial
|
||||
#i_o = partial(inner_obj, old_a=old_a, da=da, K=K)
|
||||
##new_obj = sp.optimize.brent(i_o, tol=1e-4, maxiter=20)
|
||||
#new_obj = sp.optimize.minimize_scalar(i_o, method='brent', tol=1e-4, options={'maxiter':20, 'disp':True}).fun
|
||||
#f = self.f.copy()
|
||||
#import ipdb; ipdb.set_trace() ### XXX BREAKPOINT
|
||||
|
||||
f_old = f.copy()
|
||||
update_passed = False
|
||||
while not update_passed:
|
||||
def inner_obj(step_size, old_a, da, K):
|
||||
a = old_a + step_size*da
|
||||
f = np.dot(K, a)
|
||||
self.a = a.copy() # This is nasty, need to set something within an optimization though
|
||||
self.f = f.copy()
|
||||
return -obj(a, f)
|
||||
|
||||
old_obj = new_obj
|
||||
new_obj = obj(a, f)
|
||||
difference = new_obj - old_obj
|
||||
print "difference: ",difference
|
||||
if difference < 0:
|
||||
#print "Objective function rose", np.float(difference)
|
||||
#If the objective function isn't rising, restart optimization
|
||||
step_size *= 0.8
|
||||
#print "Reducing step-size to {ss:.3} and restarting optimization".format(ss=step_size)
|
||||
#objective function isn't increasing, try reducing step size
|
||||
f = f_old.copy() #it's actually faster not to go back to old location and just zigzag across the mode
|
||||
old_obj = new_obj
|
||||
rs += 1
|
||||
else:
|
||||
update_passed = True
|
||||
i_o = partial(inner_obj, old_a=old_a, da=da, K=K)
|
||||
#new_obj = sp.optimize.brent(i_o, tol=1e-4, maxiter=20)
|
||||
new_obj = sp.optimize.minimize_scalar(i_o, method='brent', tol=1e-4, options={'maxiter':20, 'disp':True}).fun
|
||||
f = self.f.copy()
|
||||
a = self.a.copy()
|
||||
#import ipdb; ipdb.set_trace() ### XXX BREAKPOINT
|
||||
|
||||
#f_old = f.copy()
|
||||
#update_passed = False
|
||||
#while not update_passed:
|
||||
#a = old_a + step_size*da
|
||||
#f = np.dot(K, a)
|
||||
|
||||
#old_obj = new_obj
|
||||
#new_obj = obj(a, f)
|
||||
#difference = new_obj - old_obj
|
||||
##print "difference: ",difference
|
||||
#if difference < 0:
|
||||
##print "Objective function rose", np.float(difference)
|
||||
##If the objective function isn't rising, restart optimization
|
||||
#step_size *= 0.8
|
||||
##print "Reducing step-size to {ss:.3} and restarting optimization".format(ss=step_size)
|
||||
##objective function isn't increasing, try reducing step size
|
||||
#f = f_old.copy() #it's actually faster not to go back to old location and just zigzag across the mode
|
||||
#old_obj = new_obj
|
||||
#rs += 1
|
||||
#else:
|
||||
#update_passed = True
|
||||
|
||||
#difference = abs(new_obj - old_obj)
|
||||
#old_obj = new_obj.copy()
|
||||
|
|
@ -400,10 +420,11 @@ class Laplace(likelihood):
|
|||
self.old_a = old_a.copy()
|
||||
#print "Positive difference obj: ", np.float(difference)
|
||||
#print "Iterations: {}, Step size reductions: {}, Final_difference: {}, step_size: {}".format(i, rs, difference, step_size)
|
||||
print "Iterations: {}, Final_difference: {}".format(i, difference)
|
||||
#print "Iterations: {}, Final_difference: {}".format(i, difference)
|
||||
if difference > 1e-4:
|
||||
print "FAIL FAIL FAIL FAIL FAIL FAIL"
|
||||
if False:
|
||||
#if True:
|
||||
#print "Not perfect f_hat fit difference: {}".format(difference)
|
||||
if True:
|
||||
import ipdb; ipdb.set_trace() ### XXX BREAKPOINT
|
||||
if hasattr(self, 'X'):
|
||||
import pylab as pb
|
||||
|
|
@ -449,7 +470,7 @@ class Laplace(likelihood):
|
|||
self.old_ff = f.copy()
|
||||
self.old_K = self.K.copy()
|
||||
self.old_s = self.likelihood_function._get_params()
|
||||
print "after: ", self.old_s
|
||||
#print "after: ", self.old_s
|
||||
#print "FINAL a max: {} a min: {} a var: {}".format(np.max(self.a), np.min(self.a), np.var(self.a))
|
||||
self.a = a
|
||||
#self.B, self.B_chol, self.W_12 = B, L, W_12
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue