This commit is contained in:
beiwang 2017-01-16 18:12:53 +00:00
parent b9d0ae06b3
commit bc30ae968b
2 changed files with 144 additions and 80 deletions

View file

@ -33,104 +33,164 @@ class NormalPrior(VariationalPrior):
def update_gradients_KL(self, variational_posterior):
# dL:
# print (1. - (1. / (variational_posterior.variance))) * 0.5
variational_posterior.mean.gradient -= variational_posterior.mean
variational_posterior.variance.gradient -= (1. - (1. / (variational_posterior.variance))) * 0.5
# print variational_posterior.mean
# print variational_posterior.variance.gradient
class GmmNormalPrior(VariationalPrior):
def __init__(self, px_mu, px_var, pi, n_component, variational_pi, name="GMMNormalPrior", **kw):
def __init__(self, px_mu, px_var, pi, wi, n_component, variational_wi, name="GMMNormalPrior", **kw):
super(GmmNormalPrior, self).__init__(name=name, **kw)
self.n_component = n_component
self.px_mu = Param('mu_k', px_mu)
self.px_var = Param('var_k', px_var)
# Make sure they sum to one
variational_pi = variational_pi / np.sum(variational_pi)
pi = pi / np.sum(pi)
self.pi = pi # p(x) mixing coeffients
self.variational_pi = Param('variational_pi', variational_pi) # variational mixing coefficients
# variational_pi = variational_pi / np.sum(variational_pi)
# variational_wi = variational_wi /variational_wi.sum(axis=0)
self.pi = pi
self.wi = wi # p(x) mixing coeffients
self.variational_wi = Param('variational_wi', variational_wi) # variational mixing coefficients
self.check_all_weights()
self.link_parameter(self.px_mu)
self.link_parameter(self.px_var)
self.link_parameter(self.variational_pi)
self.variational_pi.constrain_bounded(0.0, 1.0)
self.link_parameter(self.px_var)
self.link_parameter(self.variational_wi)
# self.variational_wi = self.variational_wi/ self.variational_wi.sum(axis=0)
# self.variational_pi.constrain_bounded(0.0, 1.0)
#self.variational_wi.constrain_positive()
self.stop = 5
# self.stop = 5
def KL_divergence(self, variational_posterior):
# Lagrange multiplier maybe also needed here
# var_mean = np.square(variational_posterior.mean).sum()
# var_S = (variational_posterior.variance - np.log(variational_posterior.variance)).sum()
# return 0.5 * (var_mean + var_S) - 0.5 * variational_posterior.input_dim * variational_posterior.num_data
self.pi = np.exp(self.variational_wi)/np.exp(self.variational_wi).sum(axis = 0)
# self.variational_wi -= self.variational_wi.max(axis=0)
mu = variational_posterior.mean
S = variational_posterior.variance
pi = self.variational_pi
# variational_wets = self.variational_wi
# wets = self.wi
variational_wets = np.exp(self.variational_wi)/ np.exp(self.variational_wi).sum(axis = 0)
wets = np.exp(self.wi)/ np.exp(self.wi).sum(axis = 0)
total_n = variational_posterior.input_dim * variational_posterior.num_data
cov_diag = np.diagonal(np.linalg.inv(self.px_var).T)
mu_minus = self.px_mu[:, np.newaxis, :] - mu[np.newaxis, :, :]
cita = np.zeros(4)
for i in range(self.n_component):
cita[0] += (pi[i] * np.log(self.px_var[i])).sum()
cita[1] += (pi[i] * S / self.px_var[i]).sum()
cita[2] += (pi[i] * np.square(mu - self.px_mu[i]) / self.px_var[i]).sum()
cita[3] += (pi[i] * np.log(self.pi / pi[i])).sum()
return 0.5 * (cita[0] - (np.log(S)).sum() + cita[1]) + 0.5 * (cita[2] - total_n) + cita[3]
term_1 = (variational_wets * np.log(np.linalg.det(self.px_var))[:, np.newaxis]).sum()- np.log(S).sum()
term_2 = (variational_wets * np.log(wets/variational_wets)).sum()
term_3 = (variational_wets[:,:,np.newaxis] * cov_diag[:, np.newaxis, :]*S[np.newaxis, :, :]).sum()
term_4 = np.zeros((mu_minus.shape[0], mu_minus.shape[1], mu_minus.shape[2], mu_minus.shape[2]))
for k in range(mu_minus.shape[0]):
for i in range(mu_minus.shape[1]):
term_4[k,i,:,:] = np.dot(np.linalg.inv(self.px_var)[k, :,:], np.dot(mu_minus[k,i,:][:,None], mu_minus[k,i,:][None,:]))
return 0.5 *(term_1-total_n + term_3 + (np.trace((variational_wets[:,:,None,None] *term_4).T)).sum())- term_2
def update_gradients_KL(self, variational_posterior):
import pdb; pdb.set_trace() # breakpoint 1
print("Updating Gradients")
if self.stop<1:
return
self.stop-=1
#dL:
#variational_posterior.mean.gradient -= variational_posterior.mean
#variational_posterior.variance.gradient -= (1. - (1. / (variational_posterior.variance))) * 0.5
# import pdb; pdb.set_trace() # breakpoint 1
# print("Updating Gradients")
# print (self.variational_wi)
# if self.stop<1:
# return
# self.stop-=1
# dL:
# variational_posterior.mean.gradient -= variational_posterior.mean
# variational_posterior.variance.gradient -= (1. - (1. / (variational_posterior.variance))) * 0.5
self.px_mu.gradient = 0
self.px_var.gradient = 0
self.variational_wi.gradient = 0
# print self.variational_wi
#self.variational_wi -= self.variational_wi.max(axis = 0)[None,:]
# self.variational_wi = self.variational_wi/(self.variational_wi).sum(axis=0)
mu = variational_posterior.mean
S = variational_posterior.variance
pi = self.variational_pi
cita_0 = np.zeros_like(mu)
cita_1 = np.zeros_like(mu)
cita_2 = np.zeros_like(mu)
cita_3 = np.zeros_like(pi)
for i in range(self.n_component):
# variational_wets = self.variational_wi
# wets = self.wi
print("About to change the gradient")
print pi.values[i]
print mu
print self.px_mu.values[i]
print self.px_var.values[i]
# wi_max = self.variational_wi - self.variational_wi.max(axis = 0)#
# variational_wets = np.exp(wi_max)/ np.exp(wi_max).sum(axis = 0)
variational_wets = np.exp(self.variational_wi)/ np.exp(self.variational_wi).sum(axis = 0)
wets = np.exp(self.wi)/ np.exp(self.wi).sum(axis = 0)
cita_0 += pi.values[i] * (mu - self.px_mu.values[i]) / self.px_var.values[i]
print "Has this helped?"
self.px_mu[i].gradient += pi[i] * (mu - self.px_mu[i]) / self.px_var[i]
cita_1 += (pi[i] / self.px_var[i])
cita_2 += pi[i] * (S + np.square(mu - self.px_mu[i])) / np.square(self.px_var[i])
self.px_var[i].gradient += (pi[i] * (S + np.square(mu - self.px_mu[i])) / np.square(self.px_var[i]) - (pi[i] / self.px_var[i])) * 0.5
cita_3[i] = (np.log(self.px_var[i]).sum()
+ (S / self.px_var[i]).sum()
+ (np.square(mu - self.px_mu[i]) / self.px_var[i]).sum() )* (-0.5) + np.log(self.pi[i] / pi[i]) - pi[i] * np.log(self.pi[i] / np.square(pi[i]))
self.variational_pi[i].gradient += cita_3[i]
mu_minus = self.px_mu[:, np.newaxis, :] - mu[np.newaxis, :, :]
sigma_mu = np.zeros((mu_minus.shape))
sigma2_S = np.zeros((mu_minus.shape[0],mu_minus.shape[1],mu_minus.shape[2],mu_minus.shape[2]))
sigma_S = np.zeros((sigma2_S.shape))
sigma_S_sigma = np.zeros((sigma2_S.shape))
mu_sigma_mu = np.zeros((mu_minus.shape[0],mu_minus.shape[1]))
sigma_diag = np.diagonal(np.linalg.inv(self.px_var).T)
sigma_inv1 = np.linalg.inv(self.px_var)
for k in range(mu_minus.shape[0]):
for i in range(mu_minus.shape[1]):
sigma_mu[k,i,:] = np.dot(np.linalg.inv(self.px_var)[k,:,:], mu_minus[k,i,:])
sigma_S[k,i,:,:] = np.dot(np.linalg.inv(self.px_var)[k,:,:], np.diag(S[i,:]))
sigma2_S[k,i,:,:] = np.dot(np.diag(S[i,:]), np.matrix(sigma_inv1[k,:,:])**2)
sigma_S_sigma[k,i,:,:] = np.dot(sigma_mu[k,i,:][:,None],sigma_mu[k,i,:][None,:] )
mu_sigma_mu[k,i] = np.dot(mu_minus[k,i,:][None,:], sigma_mu[k,i,:][:,None])
variational_posterior.mean.gradient += (variational_wets[:,:,np.newaxis] * sigma_mu).sum(axis = 0)
variational_posterior.variance.gradient += 0.5 * (1. /S - (variational_wets[:, :, np.newaxis] * sigma_diag[:,np.newaxis,:]).sum(axis=0))
self.px_mu.gradient -= (variational_wets[:,:,np.newaxis] * sigma_mu).sum(axis=1)
self.px_var.gradient -= 0.5 * (variational_wets[:,:,np.newaxis, np.newaxis] * ((np.linalg.inv(self.px_var))[:,np.newaxis, :,:] - sigma2_S
- sigma_S_sigma) ).sum(axis=1)
dL_dw = np.zeros((self.variational_wi.shape))
ew = np.exp(self.variational_wi)
# ew = np.exp(wi_max)
sumew = ew.sum(axis=0)
dL_dq = ((0.5*(np.log(np.linalg.det(self.px_var))[:,np.newaxis] + (sigma_S).sum(axis=2).sum(axis=2) + mu_sigma_mu) - (np.log(wets/variational_wets) - 1)))
dq_dwi = ((sumew - ew) * ew ) / (sumew**2)
for i in range(mu_minus.shape[1]):
dq_dw = np.diag(dq_dwi[:,i])
for j in range(mu_minus.shape[0]):
for k in range(mu_minus.shape[0]):
if j != k:
dq_dw[j, k] = -ew[j, i] * ew[k,i] / (sumew[i]**2)
dL_dw[:,i] = np.dot(dq_dw, dL_dq[:,i])
self.variational_wi.gradient -= dL_dw
# print dL_dw
# for k in range(mu_minus.shape[1]):
# dq_dw_ij = np.zeros((mu_minus.shape[0],mu_minus.shape[0]))
# # print k
# for i in range(mu_minus.shape[0]):
# for j in range(mu_minus.shape[0]):
# if i == j:
# dq_dw_ij[i,j] = ew[i,j]/sumew[k] - (ew[i,j]/sumew[k])**2
# else :
# dq_dw_ij[i,j] = - ew[i,k] * ew[j,k] / sumew[k]**2
# # print dq_dw_ij
# dL_dw[:, k] = np.dot(dq_dw_ij, dL_dq[:,k])
# print (dL_dw)
# self.variational_wi.gradient -= dq_dw
# self.variational_wi.gradient -= (0.5*(np.log(np.linalg.det(self.px_var))[:,np.newaxis] + (sigma_S).sum(axis=2).sum(axis=2) + mu_sigma_mu) - (np.log)(wets/variational_wets) - 1)*(
# (np.exp(self.variational_wi).sum(axis = 0) - np.exp(self.variational_wi)) * np.exp(self.variational_wi)/ (self.variational_wi.sum(axis=0))**2)
# print (np.exp(self.variational_wi).sum(axis = 0) - np.exp(self.variational_wi)) * np.exp(self.variational_wi)/ (np.exp(self.variational_wi).sum(axis=0))**2
# print self.variational_wi.gradient
variational_posterior.mean.gradient -= cita_0
variational_posterior.variance.gradient += (1. / (S) - cita_1) * 0.5
def check_weights(self, weights):
assert weights.min() >= 0.0
assert weights.max() <= 1.0
assert weights.sum() == 1.0
assert weights.min() >= -64.0
assert weights.max() <= 64.0
# assert weights.sum() == 1.0
def check_all_weights(self):
self.check_weights(self.variational_pi)
self.check_weights(self.pi)
self.check_weights(self.variational_wi)
# self.check_weights(self.pi)
class SpikeAndSlabPrior(VariationalPrior):

View file

@ -53,25 +53,29 @@ class GmmBayesianGPLVM(SparseGP_MPI):
likelihood = Gaussian()
# Need to define what the model is initialised like
pi = np.ones(n_component) / float(n_component) # p(k)
variational_pi = pi.copy()
# px_mu = np.zeros(n_component)
# px_var = np.ones(n_component)
px_mu = [[]] * n_component
px_var = [[]] * n_component
for i in range(n_component):
px_mu[i] = np.zeros_like(X_variance)
px_var[i] = np.ones_like(X_variance)
# pi = np.ones(n_component) / float(n_component) # p(k)
# pi = (np.array(range(3),dtype = float)+1) / (np.array(range(3),dtype = float)+1).sum()
# wi = (np.array(range(3),dtype = float)+1)
wi = np.ones((n_component, X_variance.shape[0]))
# wi = (np.ones((X_variance.shape[0], n_component)) * (range(1, n_component+1))).T
variational_wi = wi.copy()
pi = np.exp(wi)/np.exp(wi).sum(axis = 0)
# wi = wi / wi.sum(axis=0)
# wi = np.zeros((n_component, X_variance.shape[0]))
# pi = np.log(1 + np.exp(wi)) / np.log(1 + np.exp(wi)).sum(axis = 0)
# px_mu = np.zeros((n_component, X_variance.shape[0], X_variance.shape[1]))
# px_var = np.ones((n_component, X_variance.shape[0], X_variance.shape[1]))
# print("Should print")
# print(pi)
# print(px_mu)
# print(px_var)
# print(variational_pi)
# print("Didnt print")
self.variational_prior = GmmNormalPrior(px_mu=px_mu, px_var=px_var, pi=pi,
n_component=n_component, variational_pi=variational_pi)
px_mu = (np.ones((X_variance.shape[1], n_component )) * (range(n_component))).T + np.random.randn(n_component, X_variance.shape[1]) # initialization can be changed
# print px_mu
# px_mu = np.zeros(( n_component, X_variance.shape[1]))
px_var = np.zeros(( n_component, X_variance.shape[1], X_variance.shape[1] ))+ np.eye(X_variance.shape[1])[np.newaxis, :,:]
self.variational_prior = GmmNormalPrior(px_mu=px_mu, px_var=px_var, pi = pi, wi=wi,
n_component=n_component, variational_wi=variational_wi)
X = NormalPosterior(X, X_variance)