new update

This commit is contained in:
beiwang 2017-03-02 20:27:21 +00:00
parent bc30ae968b
commit 408e83c967
2 changed files with 44 additions and 19 deletions

View file

@ -40,12 +40,12 @@ class NormalPrior(VariationalPrior):
# print variational_posterior.variance.gradient
class GmmNormalPrior(VariationalPrior):
def __init__(self, px_mu, px_var, pi, wi, n_component, variational_wi, name="GMMNormalPrior", **kw):
def __init__(self, px_mu, px_lmatrix, pi, wi, n_component, variational_wi, name="GMMNormalPrior", **kw):
super(GmmNormalPrior, self).__init__(name=name, **kw)
self.n_component = n_component
self.px_mu = Param('mu_k', px_mu)
self.px_var = Param('var_k', px_var)
self.px_lmatrix = Param('lmatrix_k', px_lmatrix)
# Make sure they sum to one
# variational_pi = variational_pi / np.sum(variational_pi)
@ -56,7 +56,7 @@ class GmmNormalPrior(VariationalPrior):
self.check_all_weights()
self.link_parameter(self.px_mu)
self.link_parameter(self.px_var)
self.link_parameter(self.px_lmatrix)
self.link_parameter(self.variational_wi)
# self.variational_wi = self.variational_wi/ self.variational_wi.sum(axis=0)
# self.variational_pi.constrain_bounded(0.0, 1.0)
@ -72,6 +72,13 @@ class GmmNormalPrior(VariationalPrior):
mu = variational_posterior.mean
S = variational_posterior.variance
cov_inv = np.zeros((self.px_lmatrix.shape))
cov_k = np.zeros((self.px_lmatrix.shape))
######################################################
for k in range(self.px_lmatrix.shape[0]):
cov_inv[k,:,:] = np.linalg.inv(self.px_lmatrix[k,:,:]).T.dot(np.linalg.inv(self.px_lmatrix[k,:,:]))
cov_k[k,:,:] = np.dot(self.px_lmatrix[k,:,:], self.px_lmatrix[k,:,:].T)
#######################################################
# variational_wets = self.variational_wi
# wets = self.wi
@ -79,18 +86,19 @@ class GmmNormalPrior(VariationalPrior):
wets = np.exp(self.wi)/ np.exp(self.wi).sum(axis = 0)
total_n = variational_posterior.input_dim * variational_posterior.num_data
cov_diag = np.diagonal(np.linalg.inv(self.px_var).T)
# cov_diag = np.diagonal(np.linalg.inv(self.px_var).T)
cov_diag = np.diagonal(cov_inv.T)
mu_minus = self.px_mu[:, np.newaxis, :] - mu[np.newaxis, :, :]
term_1 = (variational_wets * np.log(np.linalg.det(self.px_var))[:, np.newaxis]).sum()- np.log(S).sum()
term_1 = (variational_wets * np.log(np.linalg.det(cov_k))[:, np.newaxis]).sum()- np.log(S).sum()
term_2 = (variational_wets * np.log(wets/variational_wets)).sum()
term_3 = (variational_wets[:,:,np.newaxis] * cov_diag[:, np.newaxis, :]*S[np.newaxis, :, :]).sum()
term_4 = np.zeros((mu_minus.shape[0], mu_minus.shape[1], mu_minus.shape[2], mu_minus.shape[2]))
term_4 = np.zeros((mu_minus.shape[0], mu_minus.shape[1]))
for k in range(mu_minus.shape[0]):
for i in range(mu_minus.shape[1]):
term_4[k,i,:,:] = np.dot(np.linalg.inv(self.px_var)[k, :,:], np.dot(mu_minus[k,i,:][:,None], mu_minus[k,i,:][None,:]))
term_4[k,i] = variational_wets[k, i]*np.trace(np.dot(cov_inv[k, :,:], np.dot(mu_minus[k,i,:][:,None], mu_minus[k,i,:][None,:])).T)
return 0.5 *(term_1-total_n + term_3 + (np.trace((variational_wets[:,:,None,None] *term_4).T)).sum())- term_2
return 0.5 *(term_1-total_n + term_3 + term_4.sum())- term_2
def update_gradients_KL(self, variational_posterior):
# import pdb; pdb.set_trace() # breakpoint 1
@ -103,7 +111,7 @@ class GmmNormalPrior(VariationalPrior):
# variational_posterior.mean.gradient -= variational_posterior.mean
# variational_posterior.variance.gradient -= (1. - (1. / (variational_posterior.variance))) * 0.5
self.px_mu.gradient = 0
self.px_var.gradient = 0
self.px_lmatrix.gradient = 0
self.variational_wi.gradient = 0
# print self.variational_wi
#self.variational_wi -= self.variational_wi.max(axis = 0)[None,:]
@ -113,6 +121,14 @@ class GmmNormalPrior(VariationalPrior):
mu = variational_posterior.mean
S = variational_posterior.variance
cov_inv = np.zeros((self.px_lmatrix.shape))
cov_k = np.zeros((self.px_lmatrix.shape))
######################################################
for k in range(self.px_lmatrix.shape[0]):
cov_inv[k,:,:] = np.linalg.inv(self.px_lmatrix[k,:,:]).T.dot(np.linalg.inv(self.px_lmatrix[k,:,:]))
cov_k[k,:,:] = np.dot(self.px_lmatrix[k,:,:], self.px_lmatrix[k,:,:].T)
#######################################################
# variational_wets = self.variational_wi
# wets = self.wi
@ -127,15 +143,15 @@ class GmmNormalPrior(VariationalPrior):
sigma_S = np.zeros((sigma2_S.shape))
sigma_S_sigma = np.zeros((sigma2_S.shape))
mu_sigma_mu = np.zeros((mu_minus.shape[0],mu_minus.shape[1]))
sigma_diag = np.diagonal(np.linalg.inv(self.px_var).T)
sigma_inv1 = np.linalg.inv(self.px_var)
sigma_diag = np.diagonal(cov_inv.T)
# sigma_inv1 = np.linalg.inv(self.px_var) #equal to cov_inv
for k in range(mu_minus.shape[0]):
for i in range(mu_minus.shape[1]):
sigma_mu[k,i,:] = np.dot(np.linalg.inv(self.px_var)[k,:,:], mu_minus[k,i,:])
sigma_S[k,i,:,:] = np.dot(np.linalg.inv(self.px_var)[k,:,:], np.diag(S[i,:]))
sigma2_S[k,i,:,:] = np.dot(np.diag(S[i,:]), np.matrix(sigma_inv1[k,:,:])**2)
sigma_mu[k,i,:] = np.dot(cov_inv[k,:,:], mu_minus[k,i,:])
sigma_S[k,i,:,:] = np.dot(cov_inv[k,:,:], np.diag(S[i,:]))
sigma2_S[k,i,:,:] = np.dot(np.diag(S[i,:]), np.matrix(cov_inv[k,:,:])**2)
sigma_S_sigma[k,i,:,:] = np.dot(sigma_mu[k,i,:][:,None],sigma_mu[k,i,:][None,:] )
mu_sigma_mu[k,i] = np.dot(mu_minus[k,i,:][None,:], sigma_mu[k,i,:][:,None])
@ -143,14 +159,23 @@ class GmmNormalPrior(VariationalPrior):
variational_posterior.mean.gradient += (variational_wets[:,:,np.newaxis] * sigma_mu).sum(axis = 0)
variational_posterior.variance.gradient += 0.5 * (1. /S - (variational_wets[:, :, np.newaxis] * sigma_diag[:,np.newaxis,:]).sum(axis=0))
self.px_mu.gradient -= (variational_wets[:,:,np.newaxis] * sigma_mu).sum(axis=1)
self.px_var.gradient -= 0.5 * (variational_wets[:,:,np.newaxis, np.newaxis] * ((np.linalg.inv(self.px_var))[:,np.newaxis, :,:] - sigma2_S
# self.px_var.gradient -= 0.5 * (variational_wets[:,:,np.newaxis, np.newaxis] * ((np.linalg.inv(self.px_var))[:,np.newaxis, :,:] - sigma2_S
# - sigma_S_sigma) ).sum(axis=1)
dL_dcov = 0.5 * (variational_wets[:,:,np.newaxis, np.newaxis] * (cov_inv[:,np.newaxis, :,:] - sigma2_S
- sigma_S_sigma) ).sum(axis=1)
dL_dlmatrix = np.zeros((dL_dcov.shape))
for k in range(mu_minus.shape[0]):
dL_dlmatrix[k,:,:] = 2 * np.dot(dL_dcov[k,:,:], self.px_lmatrix[k,:,:])
self.px_lmatrix.gradient -= dL_dlmatrix
# print self.px_lmatrix
# print 'test'
# print dL_dlmatrix
dL_dw = np.zeros((self.variational_wi.shape))
ew = np.exp(self.variational_wi)
# ew = np.exp(wi_max)
sumew = ew.sum(axis=0)
dL_dq = ((0.5*(np.log(np.linalg.det(self.px_var))[:,np.newaxis] + (sigma_S).sum(axis=2).sum(axis=2) + mu_sigma_mu) - (np.log(wets/variational_wets) - 1)))
dL_dq = ((0.5*(np.log(np.linalg.det(cov_k))[:,np.newaxis] + (sigma_S).sum(axis=2).sum(axis=2) + mu_sigma_mu) - (np.log(wets/variational_wets) - 1)))
dq_dwi = ((sumew - ew) * ew ) / (sumew**2)
for i in range(mu_minus.shape[1]):
dq_dw = np.diag(dq_dwi[:,i])

View file

@ -72,9 +72,9 @@ class GmmBayesianGPLVM(SparseGP_MPI):
px_mu = (np.ones((X_variance.shape[1], n_component )) * (range(n_component))).T + np.random.randn(n_component, X_variance.shape[1]) # initialization can be changed
# print px_mu
# px_mu = np.zeros(( n_component, X_variance.shape[1]))
px_var = np.zeros(( n_component, X_variance.shape[1], X_variance.shape[1] ))+ np.eye(X_variance.shape[1])[np.newaxis, :,:]
px_lmatrix = np.zeros(( n_component, X_variance.shape[1], X_variance.shape[1] ))+ np.eye(X_variance.shape[1])[np.newaxis, :,:]
self.variational_prior = GmmNormalPrior(px_mu=px_mu, px_var=px_var, pi = pi, wi=wi,
self.variational_prior = GmmNormalPrior(px_mu=px_mu, px_lmatrix=px_lmatrix, pi = pi, wi=wi,
n_component=n_component, variational_wi=variational_wi)
X = NormalPosterior(X, X_variance)