Added initial version of the refactored EP

This commit is contained in:
esiivola 2017-06-01 02:19:58 +03:00 committed by Akash Kumar Dhaka
parent e849a4c62d
commit dfc5bd42dc

View file

@ -1,11 +1,13 @@
# Copyright (c) 2012-2014, GPy authors (see AUTHORS.txt). # Copyright (c) 2012-2014, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt) # Licensed under the BSD 3-clause license (see LICENSE.txt)
import numpy as np import numpy as np
import itertools
from ...util.linalg import jitchol, DSYR, dtrtrs, dtrtri, pdinv, dpotrs, tdot, symmetrify from ...util.linalg import jitchol, DSYR, dtrtrs, dtrtri, pdinv, dpotrs, tdot, symmetrify
from paramz import ObsAr from paramz import ObsAr
from . import ExactGaussianInference, VarDTC from . import ExactGaussianInference, VarDTC
from ...util import diag from ...util import diag
from .posterior import PosteriorEP as Posterior from .posterior import PosteriorEP as Posterior
from .posterior import MultioutputPosteriorEP as MultioutputPosterior
log_2_pi = np.log(2*np.pi) log_2_pi = np.log(2*np.pi)
@ -36,6 +38,7 @@ class EPBase(object):
def reset(self): def reset(self):
self.old_mutilde, self.old_vtilde = None, None self.old_mutilde, self.old_vtilde = None, None
self.ga_approx_old = None
self._ep_approximation = None self._ep_approximation = None
def on_optimization_start(self): def on_optimization_start(self):
@ -90,41 +93,49 @@ class EP(EPBase, ExactGaussianInference):
# than ObsArrays # than ObsArrays
Y = Y.values.copy() Y = Y.values.copy()
#Initial values - Marginal moments #Initial values - Marginal moments, cavity params, gaussian approximation params and posterior params
Z_hat = np.empty(num_data,dtype=np.float64) marg_moments = marginalMoments(num_data)
mu_hat = np.empty(num_data,dtype=np.float64) cav_params = cavityParams(num_data)
sigma2_hat = np.empty(num_data,dtype=np.float64) ga_approx, post_params = self._init_approximations(K, num_data)
tau_cav = np.empty(num_data,dtype=np.float64)
v_cav = np.empty(num_data,dtype=np.float64)
#initial values - Gaussian factors
#Initial values - Posterior distribution parameters: q(f|X,Y) = N(f|mu,Sigma)
if self.old_mutilde is None:
tau_tilde, mu_tilde, v_tilde = np.zeros((3, num_data))
Sigma = K.copy()
diag.add(Sigma, 1e-7)
mu = np.zeros(num_data)
else:
assert self.old_mutilde.size == num_data, "data size mis-match: did you change the data? try resetting!"
mu_tilde, v_tilde = self.old_mutilde, self.old_vtilde
tau_tilde = v_tilde/mu_tilde
mu, Sigma, _ = self._ep_compute_posterior(K, tau_tilde, v_tilde)
diag.add(Sigma, 1e-7)
# TODO: Check the log-marginal under both conditions and choose the best one
#Approximation #Approximation
tau_diff = self.epsilon + 1. tau_diff = self.epsilon + 1.
v_diff = self.epsilon + 1. v_diff = self.epsilon + 1.
tau_tilde_old = np.nan
v_tilde_old = np.nan
iterations = 0 iterations = 0
while ((tau_diff > self.epsilon) or (v_diff > self.epsilon)) and (iterations < self.max_iters): while ((tau_diff > self.epsilon) or (v_diff > self.epsilon)) and (iterations < self.max_iters):
update_order = np.random.permutation(num_data) self._update_cavity_params(num_data, cav_params, post_params, marg_moments, ga_approx, likelihood, Y, Y_metadata)
#(re) compute Sigma and mu using full Cholesky decompy
post_params = self._ep_compute_posterior(K, ga_approx.tau, ga_approx.v)
#monitor convergence
if iterations > 0:
tau_diff = np.mean(np.square(ga_approx.tau-self.ga_approx_old.tau))
v_diff = np.mean(np.square(ga_approx.v-self.ga_approx_old.v))
self.ga_approx_old = gaussianApproximation(ga_approx.mu.copy(), ga_approx.v.copy(), ga_approx.tau.copy())
iterations += 1
ga_approx.mu = ga_approx.v/ga_approx.tau
# Z_tilde after removing the terms that can lead to infinite terms due to tau_tilde close to zero.
# This terms cancel with the coreresponding terms in the marginal loglikelihood
log_Z_tilde = self._log_Z_tilde(marg_moments, ga_approx, cav_params)
# - 0.5*np.log(tau_tilde) + 0.5*(v_tilde*v_tilde*1./tau_tilde)
return post_params.mu, post_params.Sigma, ga_approx.mu, ga_approx.tau, log_Z_tilde
def _log_Z_tilde(self, marg_moments, ga_approx, cav_params):
return (np.log(marg_moments.Z_hat) + 0.5*np.log(2*np.pi) + 0.5*np.log(1+ga_approx.tau/cav_params.tau) - 0.5 * ((ga_approx.v)**2 * 1./(cav_params.tau + ga_approx.tau))
+ 0.5*(cav_params.v * ( ( (ga_approx.tau/cav_params.tau) * cav_params.v - 2.0 * ga_approx.v ) * 1./(cav_params.tau + ga_approx.tau))))
def _update_cavity_params(self, num_data, cav_params, post_params, marg_moments, ga_approx, likelihood, Y, Y_metadata, update_order=None):
if update_order is None:
update_order = np.random.permutation(num_data)
for i in update_order: for i in update_order:
#Cavity distribution parameters #Cavity distribution parameters
tau_cav[i] = 1./Sigma[i,i] - self.eta*tau_tilde[i] cav_params.tau[i] = 1./post_params.Sigma[i,i] - self.eta*ga_approx.tau[i]
v_cav[i] = mu[i]/Sigma[i,i] - self.eta*v_tilde[i] cav_params.v[i] = post_params.mu[i]/post_params.Sigma[i,i] - self.eta*ga_approx.v[i]
if Y_metadata is not None: if Y_metadata is not None:
# Pick out the relavent metadata for Yi # Pick out the relavent metadata for Yi
Y_metadata_i = {} Y_metadata_i = {}
@ -133,54 +144,46 @@ class EP(EPBase, ExactGaussianInference):
else: else:
Y_metadata_i = None Y_metadata_i = None
#Marginal moments #Marginal moments
Z_hat[i], mu_hat[i], sigma2_hat[i] = likelihood.moments_match_ep(Y[i], tau_cav[i], v_cav[i], Y_metadata_i=Y_metadata_i) marg_moments.Z_hat[i], marg_moments.mu_hat[i], marg_moments.sigma2_hat[i] = likelihood.moments_match_ep(Y[i], cav_params.tau[i], cav_params.v[i], Y_metadata_i=Y_metadata_i)
#Site parameters update #Site parameters update
delta_tau = self.delta/self.eta*(1./sigma2_hat[i] - 1./Sigma[i,i]) delta_tau = self.delta/self.eta*(1./marg_moments.sigma2_hat[i] - 1./post_params.Sigma[i,i])
delta_v = self.delta/self.eta*(mu_hat[i]/sigma2_hat[i] - mu[i]/Sigma[i,i]) delta_v = self.delta/self.eta*(marg_moments.mu_hat[i]/marg_moments.sigma2_hat[i] - post_params.mu[i]/post_params.Sigma[i,i])
tau_tilde_prev = tau_tilde[i] tau_tilde_prev = ga_approx.tau[i]
tau_tilde[i] += delta_tau ga_approx.tau[i] += delta_tau
# Enforce positivity of tau_tilde. Even though this is guaranteed for logconcave sites, it is still possible # Enforce positivity of tau_tilde. Even though this is guaranteed for logconcave sites, it is still possible
# to get negative values due to numerical errors. Moreover, the value of tau_tilde should be positive in order to # to get negative values due to numerical errors. Moreover, the value of tau_tilde should be positive in order to
# update the marginal likelihood without inestability issues. # update the marginal likelihood without inestability issues.
if tau_tilde[i] < np.finfo(float).eps: if ga_approx.tau[i] < np.finfo(float).eps:
tau_tilde[i] = np.finfo(float).eps ga_approx.tau[i] = np.finfo(float).eps
delta_tau = tau_tilde[i] - tau_tilde_prev delta_tau = ga_approx.tau[i] - tau_tilde_prev
v_tilde[i] += delta_v ga_approx.v[i] += delta_v
if self.parallel_updates == False: if self.parallel_updates == False:
#Posterior distribution parameters update #Posterior distribution parameters update
ci = delta_tau/(1.+ delta_tau*Sigma[i,i]) ci = delta_tau/(1.+ delta_tau*post_params.Sigma[i,i])
DSYR(Sigma, Sigma[:,i].copy(), -ci) DSYR(post_params.Sigma, post_params.Sigma[:,i].copy(), -ci)
mu = np.dot(Sigma, v_tilde) post_params.mu = np.dot(post_params.Sigma, ga_approx.v)
#(re) compute Sigma and mu using full Cholesky decompy def _init_approximations(self, K, num_data):
mu, Sigma, _ = self._ep_compute_posterior(K, tau_tilde, v_tilde) #initial values - Gaussian factors
#Initial values - Posterior distribution parameters: q(f|X,Y) = N(f|mu,Sigma)
#monitor convergence if self.ga_approx_old is None:
if iterations > 0: mu_tilde, v_tilde, tau_tilde = np.zeros((3, num_data))
tau_diff = np.mean(np.square(tau_tilde-tau_tilde_old)) ga_approx = gaussianApproximation(mu_tilde, v_tilde, tau_tilde)
v_diff = np.mean(np.square(v_tilde-v_tilde_old)) Sigma = K.copy()
tau_tilde_old = tau_tilde.copy() diag.add(Sigma, 1e-7)
v_tilde_old = v_tilde.copy() mu = np.zeros(num_data)
post_params = posteriorParams(mu, Sigma)
iterations += 1 else:
assert self.ga_approx_old.mu.size == num_data, "data size mis-match: did you change the data? try resetting!"
mu_tilde = v_tilde/tau_tilde ga_approx = gaussianApproximation(self.ga_approx_old.mu, self.ga_approx_old.v)
mu_cav = v_cav/tau_cav post_params = self._ep_compute_posterior(K, ga_approx.tau, ga_approx.v)
sigma2_sigma2tilde = 1./tau_cav + 1./tau_tilde diag.add(post_params.Sigma, 1e-7)
# TODO: Check the log-marginal under both conditions and choose the best one
# Z_tilde after removing the terms that can lead to infinite terms due to tau_tilde close to zero. return (ga_approx, post_params)
# This terms cancel with the coreresponding terms in the marginal loglikelihood
log_Z_tilde = (np.log(Z_hat) + 0.5*np.log(2*np.pi) + 0.5*np.log(1+tau_tilde/tau_cav)
- 0.5 * ((v_tilde)**2 * 1./(tau_cav + tau_tilde)) + 0.5*(v_cav * ( ( (tau_tilde/tau_cav) * v_cav - 2.0 * v_tilde ) * 1./(tau_cav + tau_tilde))))
# - 0.5*np.log(tau_tilde) + 0.5*(v_tilde*v_tilde*1./tau_tilde)
self.old_mutilde = mu_tilde
self.old_vtilde = v_tilde
return mu, Sigma, mu_tilde, tau_tilde, log_Z_tilde
def _ep_compute_posterior(self, K, tau_tilde, v_tilde): def _ep_compute_posterior(self, K, tau_tilde, v_tilde):
num_data = len(tau_tilde) num_data = len(tau_tilde)
tau_tilde_root = np.sqrt(tau_tilde) tau_tilde_root = np.sqrt(tau_tilde)
@ -190,18 +193,18 @@ class EP(EPBase, ExactGaussianInference):
V, _ = dtrtrs(L, Sroot_tilde_K, lower=1) V, _ = dtrtrs(L, Sroot_tilde_K, lower=1)
Sigma = K - np.dot(V.T,V) #K - KS^(1/2)BS^(1/2)K = (K^(-1) + \Sigma^(-1))^(-1) Sigma = K - np.dot(V.T,V) #K - KS^(1/2)BS^(1/2)K = (K^(-1) + \Sigma^(-1))^(-1)
mu = np.dot(Sigma,v_tilde) mu = np.dot(Sigma,v_tilde)
return (mu, Sigma, L) return posteriorParams(mu, Sigma, L)
def _ep_marginal(self, K, tau_tilde, v_tilde, Z_tilde): def _ep_marginal(self, K, tau_tilde, v_tilde, Z_tilde):
mu, Sigma, L = self._ep_compute_posterior(K, tau_tilde, v_tilde) post_params = self._ep_compute_posterior(K, tau_tilde, v_tilde)
# Gaussian log marginal excluding terms that can go to infinity due to arbitrarily small tau_tilde. # Gaussian log marginal excluding terms that can go to infinity due to arbitrarily small tau_tilde.
# These terms cancel out with the terms excluded from Z_tilde # These terms cancel out with the terms excluded from Z_tilde
B_logdet = np.sum(2.0*np.log(np.diag(L))) B_logdet = np.sum(2.0*np.log(np.diag(post_params.L)))
log_marginal = 0.5*(-len(tau_tilde) * log_2_pi - B_logdet + np.sum(v_tilde * np.dot(Sigma,v_tilde))) log_marginal = 0.5*(-len(tau_tilde) * log_2_pi - B_logdet + np.sum(v_tilde * np.dot(post_params.Sigma,v_tilde)))
log_marginal += Z_tilde log_marginal += Z_tilde
return log_marginal, mu, Sigma, L return log_marginal, post_params.mu, post_params.Sigma, post_params.L
def _inference(self, K, tau_tilde, v_tilde, likelihood, Z_tilde, Y_metadata=None): def _inference(self, K, tau_tilde, v_tilde, likelihood, Z_tilde, Y_metadata=None):
log_marginal, mu, Sigma, L = self._ep_marginal(K, tau_tilde, v_tilde, Z_tilde) log_marginal, mu, Sigma, L = self._ep_marginal(K, tau_tilde, v_tilde, Z_tilde)
@ -277,8 +280,8 @@ class EPDTC(EPBase, VarDTC):
mu_hat = np.zeros(num_data,dtype=np.float64) mu_hat = np.zeros(num_data,dtype=np.float64)
sigma2_hat = np.zeros(num_data,dtype=np.float64) sigma2_hat = np.zeros(num_data,dtype=np.float64)
tau_cav = np.empty(num_data,dtype=np.float64) tau = np.empty(num_data,dtype=np.float64)
v_cav = np.empty(num_data,dtype=np.float64) v = np.empty(num_data,dtype=np.float64)
#initial values - Gaussian factors #initial values - Gaussian factors
#Initial values - Posterior distribution parameters: q(f|X,Y) = N(f|mu,Sigma) #Initial values - Posterior distribution parameters: q(f|X,Y) = N(f|mu,Sigma)
@ -315,8 +318,8 @@ class EPDTC(EPBase, VarDTC):
update_order = np.random.permutation(num_data) update_order = np.random.permutation(num_data)
for i in update_order: for i in update_order:
#Cavity distribution parameters #Cavity distribution parameters
tau_cav[i] = 1./Sigma_diag[i] - self.eta*tau_tilde[i] tau[i] = 1./Sigma_diag[i] - self.eta*tau_tilde[i]
v_cav[i] = mu[i]/Sigma_diag[i] - self.eta*v_tilde[i] v[i] = mu[i]/Sigma_diag[i] - self.eta*v_tilde[i]
if Y_metadata is not None: if Y_metadata is not None:
# Pick out the relavent metadata for Yi # Pick out the relavent metadata for Yi
Y_metadata_i = {} Y_metadata_i = {}
@ -326,7 +329,7 @@ class EPDTC(EPBase, VarDTC):
Y_metadata_i = None Y_metadata_i = None
#Marginal moments #Marginal moments
Z_hat[i], mu_hat[i], sigma2_hat[i] = likelihood.moments_match_ep(Y[i], tau_cav[i], v_cav[i], Y_metadata_i=Y_metadata_i) Z_hat[i], mu_hat[i], sigma2_hat[i] = likelihood.moments_match_ep(Y[i], tau[i], v[i], Y_metadata_i=Y_metadata_i)
#Site parameters update #Site parameters update
delta_tau = self.delta/self.eta*(1./sigma2_hat[i] - 1./Sigma_diag[i]) delta_tau = self.delta/self.eta*(1./sigma2_hat[i] - 1./Sigma_diag[i])
delta_v = self.delta/self.eta*(mu_hat[i]/sigma2_hat[i] - mu[i]/Sigma_diag[i]) delta_v = self.delta/self.eta*(mu_hat[i]/sigma2_hat[i] - mu[i]/Sigma_diag[i])
@ -365,8 +368,8 @@ class EPDTC(EPBase, VarDTC):
iterations += 1 iterations += 1
mu_tilde = v_tilde/tau_tilde mu_tilde = v_tilde/tau_tilde
mu_cav = v_cav/tau_cav mu_cav = v/tau
sigma2_sigma2tilde = 1./tau_cav + 1./tau_tilde sigma2_sigma2tilde = 1./tau + 1./tau_tilde
log_Z_tilde = (np.log(Z_hat) + 0.5*np.log(2*np.pi) + 0.5*np.log(sigma2_sigma2tilde) log_Z_tilde = (np.log(Z_hat) + 0.5*np.log(2*np.pi) + 0.5*np.log(sigma2_sigma2tilde)
+ 0.5*((mu_cav - mu_tilde)**2) / (sigma2_sigma2tilde)) + 0.5*((mu_cav - mu_tilde)**2) / (sigma2_sigma2tilde))
@ -388,3 +391,28 @@ class EPDTC(EPBase, VarDTC):
Sigma_diag = np.diag(Sigma).copy() Sigma_diag = np.diag(Sigma).copy()
return (mu, Sigma_diag, LLT) return (mu, Sigma_diag, LLT)
#Four wrapper classes to help modularisation of different EP versions
class marginalMoments(object):
def __init__(self, num_data):
#Initial values - Marginal moments
self.Z_hat = np.empty(num_data,dtype=np.float64)
self.mu_hat = np.empty(num_data,dtype=np.float64)
self.sigma2_hat = np.empty(num_data,dtype=np.float64)
class cavityParams(object):
def __init__(self, num_data):
self.tau = np.empty(num_data,dtype=np.float64)
self.v = np.empty(num_data,dtype=np.float64)
class gaussianApproximation(object):
def __init__(self, mu, v, tau=None):
self.mu = mu
self.v = v
self.tau = mu / v if tau is None else tau
class posteriorParams(object):
def __init__(self, mu=None, Sigma=None, L=None):
self.mu = mu
self.Sigma = Sigma
self.L = L