This commit is contained in:
mu 2014-01-14 13:19:22 +00:00
parent 4fe0583c7d
commit de39a22a4b

View file

@ -47,8 +47,8 @@ class StateSpace_1(Model):
# Default kernel
if kernel is None:
self.kern = kern.Matern32(1,lengthscale=0.3)
self.spacekern = kern.rbf(1,lengthscale=0.3)
self.kern = kern.Matern32(1,lengthscale=0.5)
self.spacekern = kern.rbf(1,lengthscale=0.5)
else:
self.kern = kernel
@ -350,11 +350,16 @@ class StateSpace_1(Model):
if Y.shape[0]==1:
K = PF[:,:,k].dot(H.T)/(H.dot(PF[:,:,k]).dot(H.T) + R)
else:
LL = linalg.cho_factor(H.dot(PF[:,:,k]).dot(H.T) + R*np.eye(Y.shape[0]))
# LL = linalg.cho_factor(H.dot(PF[:,:,k]).dot(H.T) + R*np.eye(Y.shape[0]))
# K = linalg.cho_solve(LL, H.dot(PF[:,:,k].T)).T
S = H.dot(PF[:,:,k]).dot(H.T) + R*np.eye(Y.shape[0])
LL = linalg.cho_factor(S)
K = linalg.cho_solve(LL, H.dot(PF[:,:,k].T)).T
MF[:,k] += K.dot(Y[:,k]-H.dot(MF[:,k]))
PF[:,:,k] -= K.dot(H).dot(PF[:,:,k])
# PF[:,:,k] -= K.dot(H).dot(PF[:,:,k])
PF[:,:,k] -= K.dot(S).dot(K.T)
PF[:,:,k] = 0.5 * (PF[:,:,k] + PF[:,:,k].T)
# LL = linalg.cho_factor(H.dot(PF[:,:,k]).dot(H.T) + R*np.eye(Y.shape[1]))
# K = linalg.cho_solve(LL, H.dot(PF[:,:,k].T)).T
@ -393,7 +398,7 @@ class StateSpace_1(Model):
PS[:,:,-k] += G.dot(PS[:,:,1-k]-A.dot(PS[:,:,-k]).dot(A.T)-Q).dot(G.T)
except linalg.LinAlgError:
"""numerical"""
stop
# Return
return (MS, PS)
@ -437,13 +442,18 @@ class StateSpace_1(Model):
lik -= 0.5*(v*v/S)[0,0] # !!!
else:
v = Y[:,k][None].T-H.dot(m)
LL, isupper = linalg.cho_factor(H.dot(P).dot(H.T) + R*np.eye(Y.shape[1]))
K = linalg.cho_solve((LL, isupper), H.dot(P)).T
S = H.dot(P).dot(H.T) + R*np.eye(Y.shape[0])
LL = linalg.cho_factor(S)
K = linalg.cho_solve(LL, H.dot(P)).T
# LL, isupper = linalg.cho_factor(H.dot(P).dot(H.T) + R*np.eye(Y.shape[1]))
# K = linalg.cho_solve((LL, isupper), H.dot(P)).T
lik -= np.sum(np.log(np.diag(LL)))
lik -= 0.5*v.shape[0]*np.log(2*np.pi)
lik -= 0.5*linalg.cho_solve((LL, isupper),v).T.dot(v)[0,0]
m += K.dot(v)
P -= K.dot(H).dot(P)
# P -= K.dot(H).dot(P)
P -= K.dot(S).dot(K.T)
P = 0.5 * (P + P.T)
#stop
# v = Y[:,k][None].T-H.dot(m)