Changes in EP/EPDTC to fix numerical issues and increase the flexibility of the inference.

Changes to avoid numerical issues and improve the performance:
    - Keep value of the EP parameters between calls
    - Enforce positivity of tau_tilde
    - Stable computation of the EP moments for the Bernoulli likelihood
    - Compute marginal in the GP model without directly inverting tau_tilde

    Changes to improve the flexibility:
    - Add parameter for maximum number of iterations
    - Distinguish between alternated/nested mode
    - Distinguish between sequential/parallel updates in EP
This commit is contained in:
Moreno 2017-02-17 11:35:30 +00:00 committed by Akash Kumar Dhaka
parent 113f84d6a7
commit 63751de912
7 changed files with 522 additions and 117 deletions

View file

@ -1,15 +1,16 @@
# Copyright (c) 2012-2014, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt)
import numpy as np
from ...util.linalg import jitchol, DSYR, dtrtrs, dtrtri
from ...util.linalg import jitchol, DSYR, dtrtrs, dtrtri, pdinv, dpotrs, tdot, symmetrify
from paramz import ObsAr
from . import ExactGaussianInference, VarDTC
from ...util import diag
from .posterior import PosteriorEP as Posterior
log_2_pi = np.log(2*np.pi)
class EPBase(object):
def __init__(self, epsilon=1e-6, eta=1., delta=1., always_reset=False):
def __init__(self, epsilon=1e-6, eta=1., delta=1., always_reset=False, max_iters=np.inf, ep_mode="alternated", parallel_updates=False):
"""
The expectation-propagation algorithm.
For nomenclature see Rasmussen & Williams 2006.
@ -22,11 +23,15 @@ class EPBase(object):
:type delta: float64
:param always_reset: setting to always reset the approximation at the beginning of every inference call.
:type always_reset: boolean
:max_iters: int
:ep_mode: string. It can be "nested" (EP is run every time the Hyperparameters change) or "alternated" (It runs EP at the beginning and then optimize the Hyperparameters).
:parallel_updates: boolean. If true, updates of the parameters of the sites in parallel
"""
super(EPBase, self).__init__()
self.always_reset = always_reset
self.epsilon, self.eta, self.delta = epsilon, eta, delta
self.epsilon, self.eta, self.delta, self.max_iters = epsilon, eta, delta, max_iters
self.ep_mode = ep_mode
self.parallel_updates = parallel_updates
self.reset()
def reset(self):
@ -59,25 +64,28 @@ class EP(EPBase, ExactGaussianInference):
if K is None:
K = kern.K(X)
if getattr(self, '_ep_approximation', None) is None:
#if we don't yet have the results of running EP, run EP and store the computed factors in self._ep_approximation
mu, Sigma, mu_tilde, tau_tilde, Z_tilde = self._ep_approximation = self.expectation_propagation(K, Y, likelihood, Y_metadata)
if self.ep_mode=="nested":
#Force EP at each step of the optimization
self._ep_approximation = None
mu, Sigma, mu_tilde, tau_tilde, log_Z_tilde = self._ep_approximation = self.expectation_propagation(K, Y, likelihood, Y_metadata)
elif self.ep_mode=="alternated":
if getattr(self, '_ep_approximation', None) is None:
#if we don't yet have the results of running EP, run EP and store the computed factors in self._ep_approximation
mu, Sigma, mu_tilde, tau_tilde, log_Z_tilde = self._ep_approximation = self.expectation_propagation(K, Y, likelihood, Y_metadata)
else:
#if we've already run EP, just use the existing approximation stored in self._ep_approximation
mu, Sigma, mu_tilde, tau_tilde, log_Z_tilde = self._ep_approximation
else:
#if we've already run EP, just use the existing approximation stored in self._ep_approximation
mu, Sigma, mu_tilde, tau_tilde, Z_tilde = self._ep_approximation
raise ValueError("ep_mode value not valid")
return super(EP, self).inference(kern, X, likelihood, mu_tilde[:,None], mean_function=mean_function, Y_metadata=Y_metadata, variance=1./tau_tilde, K=K, Z_tilde=np.log(Z_tilde).sum())
v_tilde = mu_tilde * tau_tilde
return self._inference(K, tau_tilde, v_tilde, likelihood, Y_metadata=Y_metadata, Z_tilde=log_Z_tilde.sum())
def expectation_propagation(self, K, Y, likelihood, Y_metadata):
num_data, data_dim = Y.shape
assert data_dim == 1, "This EP methods only works for 1D outputs"
#Initial values - Posterior distribution parameters: q(f|X,Y) = N(f|mu,Sigma)
mu = np.zeros(num_data)
Sigma = K.copy()
diag.add(Sigma, 1e-7)
# Makes computing the sign quicker if we work with numpy arrays rather
# than ObsArrays
Y = Y.values.copy()
@ -91,12 +99,19 @@ class EP(EPBase, ExactGaussianInference):
v_cav = np.empty(num_data,dtype=np.float64)
#initial values - Gaussian factors
#Initial values - Posterior distribution parameters: q(f|X,Y) = N(f|mu,Sigma)
if self.old_mutilde is None:
tau_tilde, mu_tilde, v_tilde = np.zeros((3, num_data))
Sigma = K.copy()
diag.add(Sigma, 1e-7)
mu = np.zeros(num_data)
else:
assert self.old_mutilde.size == num_data, "data size mis-match: did you change the data? try resetting!"
mu_tilde, v_tilde = self.old_mutilde, self.old_vtilde
tau_tilde = v_tilde/mu_tilde
mu, Sigma, _ = self._ep_compute_posterior(K, tau_tilde, v_tilde)
diag.add(Sigma, 1e-7)
# TODO: Check the log-marginal under both conditions and choose the best one
#Approximation
tau_diff = self.epsilon + 1.
@ -104,7 +119,7 @@ class EP(EPBase, ExactGaussianInference):
tau_tilde_old = np.nan
v_tilde_old = np.nan
iterations = 0
while (tau_diff > self.epsilon) or (v_diff > self.epsilon):
while ((tau_diff > self.epsilon) or (v_diff > self.epsilon)) and (iterations < self.max_iters):
update_order = np.random.permutation(num_data)
for i in update_order:
#Cavity distribution parameters
@ -122,21 +137,25 @@ class EP(EPBase, ExactGaussianInference):
#Site parameters update
delta_tau = self.delta/self.eta*(1./sigma2_hat[i] - 1./Sigma[i,i])
delta_v = self.delta/self.eta*(mu_hat[i]/sigma2_hat[i] - mu[i]/Sigma[i,i])
tau_tilde_prev = tau_tilde[i]
tau_tilde[i] += delta_tau
# Enforce positivity of tau_tilde. Even though this is guaranteed for logconcave sites, it is still possible
# to get negative values due to numerical errors. Moreover, the value of tau_tilde should be positive in order to
# update the marginal likelihood without instability issues.
if tau_tilde[i] < np.finfo(float).eps:
tau_tilde[i] = np.finfo(float).eps
delta_tau = tau_tilde[i] - tau_tilde_prev
v_tilde[i] += delta_v
#Posterior distribution parameters update
ci = delta_tau/(1.+ delta_tau*Sigma[i,i])
DSYR(Sigma, Sigma[:,i].copy(), -ci)
mu = np.dot(Sigma, v_tilde)
if self.parallel_updates == False:
#Posterior distribution parameters update
ci = delta_tau/(1.+ delta_tau*Sigma[i,i])
DSYR(Sigma, Sigma[:,i].copy(), -ci)
mu = np.dot(Sigma, v_tilde)
#(re) compute Sigma and mu using full Cholesky decomposition
tau_tilde_root = np.sqrt(tau_tilde)
Sroot_tilde_K = tau_tilde_root[:,None] * K
B = np.eye(num_data) + Sroot_tilde_K * tau_tilde_root[None,:]
L = jitchol(B)
V, _ = dtrtrs(L, Sroot_tilde_K, lower=1)
Sigma = K - np.dot(V.T,V)
mu = np.dot(Sigma,v_tilde)
mu, Sigma, _ = self._ep_compute_posterior(K, tau_tilde, v_tilde)
#monitor convergence
if iterations > 0:
@ -150,15 +169,70 @@ class EP(EPBase, ExactGaussianInference):
mu_tilde = v_tilde/tau_tilde
mu_cav = v_cav/tau_cav
sigma2_sigma2tilde = 1./tau_cav + 1./tau_tilde
Z_tilde = np.exp(np.log(Z_hat) + 0.5*np.log(2*np.pi) + 0.5*np.log(sigma2_sigma2tilde)
+ 0.5*((mu_cav - mu_tilde)**2) / (sigma2_sigma2tilde))
return mu, Sigma, mu_tilde, tau_tilde, Z_tilde
# Z_tilde after removing the terms that can lead to infinite terms due to tau_tilde close to zero.
# These terms cancel with the corresponding terms in the marginal log-likelihood
log_Z_tilde = (np.log(Z_hat) + 0.5*np.log(2*np.pi) + 0.5*np.log(1+tau_tilde/tau_cav)
- 0.5 * ((v_tilde)**2 * 1./(tau_cav + tau_tilde)) + 0.5*(v_cav * ( ( (tau_tilde/tau_cav) * v_cav - 2.0 * v_tilde ) * 1./(tau_cav + tau_tilde))))
# - 0.5*np.log(tau_tilde) + 0.5*(v_tilde*v_tilde*1./tau_tilde)
self.old_mutilde = mu_tilde
self.old_vtilde = v_tilde
return mu, Sigma, mu_tilde, tau_tilde, log_Z_tilde
def _ep_compute_posterior(self, K, tau_tilde, v_tilde):
    """Recompute the EP posterior q(f) = N(mu, Sigma) from the site parameters.

    Uses the numerically stable form of Rasmussen & Williams (2006), Sec. 3.6:
    with S = diag(tau_tilde), factor B = I + S^(1/2) K S^(1/2) and obtain
    Sigma = K - K S^(1/2) B^(-1) S^(1/2) K, i.e. (K^(-1) + S)^(-1) without
    ever inverting tau_tilde.

    :param K: prior covariance matrix (n x n)
    :param tau_tilde: site precisions (n,)
    :param v_tilde: site precision-scaled means (n,)
    :returns: tuple (mu, Sigma, L) with L the lower Cholesky factor of B
    """
    n = tau_tilde.shape[0]
    sqrt_tau = np.sqrt(tau_tilde)
    scaled_K = sqrt_tau[:, None] * K  # S^(1/2) K
    B = np.eye(n) + scaled_K * sqrt_tau[None, :]
    L = jitchol(B)
    # V = L^(-1) S^(1/2) K, so V^T V = K S^(1/2) B^(-1) S^(1/2) K
    V, _ = dtrtrs(L, scaled_K, lower=1)
    Sigma = K - V.T.dot(V)  # (K^(-1) + S)^(-1)
    mu = Sigma.dot(v_tilde)
    return mu, Sigma, L
def _ep_marginal(self, K, tau_tilde, v_tilde, Z_tilde):
    """EP approximation to the log marginal likelihood.

    Excludes the Gaussian terms that can diverge for arbitrarily small
    tau_tilde; those cancel with the terms already excluded from Z_tilde.

    :param K: prior covariance matrix (n x n)
    :param tau_tilde: site precisions (n,)
    :param v_tilde: site precision-scaled means (n,)
    :param Z_tilde: summed log normalisation terms of the sites (scalar)
    :returns: tuple (log_marginal, mu, Sigma, L)
    """
    mu, Sigma, L = self._ep_compute_posterior(K, tau_tilde, v_tilde)
    # log|B| from the Cholesky factor: 2 * sum(log diag(L))
    half_logdet_B = np.log(np.diag(L)).sum()
    n = tau_tilde.shape[0]
    quad_term = v_tilde.dot(Sigma.dot(v_tilde))
    log_marginal = Z_tilde + 0.5 * (quad_term - n * log_2_pi) - half_logdet_B
    return log_marginal, mu, Sigma, L
def _inference(self, K, tau_tilde, v_tilde, likelihood, Z_tilde, Y_metadata=None):
    """Build the EP posterior object, log marginal and gradients.

    :param K: prior covariance matrix (n x n)
    :param tau_tilde: site precisions (n,)
    :param v_tilde: site precision-scaled means (n,)
    :param likelihood: likelihood object providing exact_inference_gradients
    :param Z_tilde: summed log normalisation terms of the sites (scalar)
    :param Y_metadata: optional metadata forwarded to the likelihood
    :returns: (Posterior, log_marginal, gradient dict)
    """
    log_marginal, mu, Sigma, L = self._ep_marginal(K, tau_tilde, v_tilde, Z_tilde)
    sqrt_tau = np.sqrt(tau_tilde)
    scaled_K = sqrt_tau[:, None] * K  # S^(1/2) K
    # alpha = (K + S^(-1))^(-1) mu_tilde, computed without inverting tau_tilde
    aux_alpha, _ = dpotrs(L, scaled_K.dot(v_tilde), lower=1)
    alpha = (v_tilde - sqrt_tau * aux_alpha)[:, None]
    # Wi = S^(1/2) B^(-1) S^(1/2) = (K + S^(-1))^(-1)
    LWi, _ = dtrtrs(L, np.diag(sqrt_tau), lower=1)
    Wi = LWi.T.dot(LWi)
    symmetrify(Wi)
    dL_dK = 0.5 * (tdot(alpha) - Wi)
    dL_dthetaL = likelihood.exact_inference_gradients(np.diag(dL_dK), Y_metadata)
    post = Posterior(woodbury_inv=Wi, woodbury_vector=alpha, K=K)
    grads = {'dL_dK': dL_dK, 'dL_dthetaL': dL_dthetaL, 'dL_dm': alpha}
    return post, log_marginal, grads
class EPDTC(EPBase, VarDTC):
def inference(self, kern, X, Z, likelihood, Y, mean_function=None, Y_metadata=None, Lm=None, dL_dKmm=None, psi0=None, psi1=None, psi2=None):
assert Y.shape[1]==1, "ep in 1D only (for now!)"
if self.always_reset:
self.reset()
num_data, output_dim = Y.shape
assert output_dim == 1, "ep in 1D only (for now!)"
if Lm is None:
Kmm = kern.K(Z)
Lm = jitchol(Kmm)
Kmm = kern.K(Z)
if psi1 is None:
try:
Kmn = kern.K(Z, X)
@ -167,35 +241,36 @@ class EPDTC(EPBase, VarDTC):
else:
Kmn = psi1.T
if getattr(self, '_ep_approximation', None) is None:
mu, Sigma, mu_tilde, tau_tilde, Z_tilde = self._ep_approximation = self.expectation_propagation(Kmm, Kmn, Y, likelihood, Y_metadata)
if self.ep_mode=="nested":
#Force EP at each step of the optimization
self._ep_approximation = None
mu, Sigma_diag, mu_tilde, tau_tilde, log_Z_tilde = self._ep_approximation = self.expectation_propagation(Kmm, Kmn, Y, likelihood, Y_metadata)
elif self.ep_mode=="alternated":
if getattr(self, '_ep_approximation', None) is None:
#if we don't yet have the results of running EP, run EP and store the computed factors in self._ep_approximation
mu, Sigma_diag, mu_tilde, tau_tilde, log_Z_tilde = self._ep_approximation = self.expectation_propagation(Kmm, Kmn, Y, likelihood, Y_metadata)
else:
#if we've already run EP, just use the existing approximation stored in self._ep_approximation
mu, Sigma_diag, mu_tilde, tau_tilde, log_Z_tilde = self._ep_approximation
else:
mu, Sigma, mu_tilde, tau_tilde, Z_tilde = self._ep_approximation
raise ValueError("ep_mode value not valid")
return super(EPDTC, self).inference(kern, X, Z, likelihood, mu_tilde,
mean_function=mean_function,
Y_metadata=Y_metadata,
precision=tau_tilde,
Lm=Lm, dL_dKmm=dL_dKmm,
psi0=psi0, psi1=psi1, psi2=psi2, Z_tilde=np.log(Z_tilde).sum())
psi0=psi0, psi1=psi1, psi2=psi2, Z_tilde=log_Z_tilde.sum())
def expectation_propagation(self, Kmm, Kmn, Y, likelihood, Y_metadata):
num_data, output_dim = Y.shape
assert output_dim == 1, "This EP methods only works for 1D outputs"
LLT0 = Kmm.copy()
#diag.add(LLT0, 1e-8)
Lm = jitchol(LLT0)
Lmi = dtrtri(Lm)
Kmmi = np.dot(Lmi.T,Lmi)
KmmiKmn = np.dot(Kmmi,Kmn)
Qnn_diag = np.sum(Kmn*KmmiKmn,-2)
#Initial values - Posterior distribution parameters: q(f|X,Y) = N(f|mu,Sigma)
mu = np.zeros(num_data)
LLT = Kmm.copy() #Sigma = K.copy()
Sigma_diag = Qnn_diag.copy() + 1e-8
# Makes computing the sign quicker if we work with numpy arrays rather
# than ObsArrays
Y = Y.values.copy()
#Initial values - Marginal moments
Z_hat = np.zeros(num_data,dtype=np.float64)
@ -206,73 +281,110 @@ class EPDTC(EPBase, VarDTC):
v_cav = np.empty(num_data,dtype=np.float64)
#initial values - Gaussian factors
#Initial values - Posterior distribution parameters: q(f|X,Y) = N(f|mu,Sigma)
LLT0 = Kmm.copy()
Lm = jitchol(LLT0) #K_m = L_m L_m^\top
Vm,info = dtrtrs(Lm,Kmn,lower=1)
# Lmi = dtrtri(Lm)
# Kmmi = np.dot(Lmi.T,Lmi)
# KmmiKmn = np.dot(Kmmi,Kmn)
# Qnn_diag = np.sum(Kmn*KmmiKmn,-2)
Qnn_diag = np.sum(Vm*Vm,-2) #diag(Knm Kmm^(-1) Kmn)
#diag.add(LLT0, 1e-8)
if self.old_mutilde is None:
#Initial values - Posterior distribution parameters: q(f|X,Y) = N(f|mu,Sigma)
LLT = LLT0.copy() #Sigma = K.copy()
mu = np.zeros(num_data)
Sigma_diag = Qnn_diag.copy() + 1e-8
tau_tilde, mu_tilde, v_tilde = np.zeros((3, num_data))
else:
assert self.old_mutilde.size == num_data, "data size mis-match: did you change the data? try resetting!"
mu_tilde, v_tilde = self.old_mutilde, self.old_vtilde
tau_tilde = v_tilde/mu_tilde
mu, Sigma_diag, LLT = self._ep_compute_posterior(LLT0, Kmn, tau_tilde, v_tilde)
Sigma_diag += 1e-8
# TODO: Check the log-marginal under both conditions and choose the best one
#Approximation
tau_diff = self.epsilon + 1.
v_diff = self.epsilon + 1.
tau_tilde_old = np.nan
v_tilde_old = np.nan
iterations = 0
tau_tilde_old = 0.
v_tilde_old = 0.
update_order = np.random.permutation(num_data)
while (tau_diff > self.epsilon) or (v_diff > self.epsilon):
while ((tau_diff > self.epsilon) or (v_diff > self.epsilon)) and (iterations < self.max_iters):
update_order = np.random.permutation(num_data)
for i in update_order:
#Cavity distribution parameters
tau_cav[i] = 1./Sigma_diag[i] - self.eta*tau_tilde[i]
v_cav[i] = mu[i]/Sigma_diag[i] - self.eta*v_tilde[i]
if Y_metadata is not None:
# Pick out the relevant metadata for Yi
Y_metadata_i = {}
for key in Y_metadata.keys():
Y_metadata_i[key] = Y_metadata[key][i, :]
else:
Y_metadata_i = None
#Marginal moments
Z_hat[i], mu_hat[i], sigma2_hat[i] = likelihood.moments_match_ep(Y[i], tau_cav[i], v_cav[i])#, Y_metadata=None)#=(None if Y_metadata is None else Y_metadata[i]))
Z_hat[i], mu_hat[i], sigma2_hat[i] = likelihood.moments_match_ep(Y[i], tau_cav[i], v_cav[i], Y_metadata_i=Y_metadata_i)
#Site parameters update
delta_tau = self.delta/self.eta*(1./sigma2_hat[i] - 1./Sigma_diag[i])
delta_v = self.delta/self.eta*(mu_hat[i]/sigma2_hat[i] - mu[i]/Sigma_diag[i])
tau_tilde_prev = tau_tilde[i]
tau_tilde[i] += delta_tau
# Enforce positivity of tau_tilde. Even though this is guaranteed for logconcave sites, it is still possible
# to get negative values due to numerical errors. Moreover, the value of tau_tilde should be positive in order to
# update the marginal likelihood without instability issues.
if tau_tilde[i] < np.finfo(float).eps:
tau_tilde[i] = np.finfo(float).eps
delta_tau = tau_tilde[i] - tau_tilde_prev
v_tilde[i] += delta_v
#Posterior distribution parameters update
if self.parallel_updates == False:
#DSYR(Sigma, Sigma[:,i].copy(), -delta_tau/(1.+ delta_tau*Sigma[i,i]))
DSYR(LLT,Kmn[:,i].copy(),delta_tau)
L = jitchol(LLT)
V,info = dtrtrs(L,Kmn,lower=1)
Sigma_diag = np.maximum(np.sum(V*V,-2), np.finfo(float).eps) #diag(K_nm (L L^\top)^(-1)) K_mn
si = np.sum(V.T*V[:,i],-1) #(V V^\top)[:,i]
mu += (delta_v-delta_tau*mu[i])*si
#mu = np.dot(Sigma, v_tilde)
#DSYR(Sigma, Sigma[:,i].copy(), -delta_tau/(1.+ delta_tau*Sigma[i,i]))
DSYR(LLT,Kmn[:,i].copy(),delta_tau)
L = jitchol(LLT+np.eye(LLT.shape[0])*1e-7)
V,info = dtrtrs(L,Kmn,lower=1)
Sigma_diag = np.sum(V*V,-2)
si = np.sum(V.T*V[:,i],-1)
mu += (delta_v-delta_tau*mu[i])*si
#mu = np.dot(Sigma, v_tilde)
#(re) compute Sigma and mu using full Cholesky decomposition
LLT = LLT0 + np.dot(Kmn*tau_tilde[None,:],Kmn.T)
#diag.add(LLT, 1e-8)
L = jitchol(LLT)
V, _ = dtrtrs(L,Kmn,lower=1)
V2, _ = dtrtrs(L.T,V,lower=0)
#Sigma_diag = np.sum(V*V,-2)
#Knmv_tilde = np.dot(Kmn,v_tilde)
#mu = np.dot(V2.T,Knmv_tilde)
Sigma = np.dot(V2.T,V2)
mu = np.dot(Sigma,v_tilde)
#(re) compute Sigma, Sigma_diag and mu using full Cholesky decomposition
mu, Sigma_diag, LLT = self._ep_compute_posterior(LLT0, Kmn, tau_tilde, v_tilde)
Sigma_diag = np.maximum(Sigma_diag, np.finfo(float).eps)
#monitor convergence
#if iterations>0:
tau_diff = np.mean(np.square(tau_tilde-tau_tilde_old))
v_diff = np.mean(np.square(v_tilde-v_tilde_old))
if iterations>0:
tau_diff = np.mean(np.square(tau_tilde-tau_tilde_old))
v_diff = np.mean(np.square(v_tilde-v_tilde_old))
tau_tilde_old = tau_tilde.copy()
v_tilde_old = v_tilde.copy()
# Only to while loop once:?
tau_diff = 0
v_diff = 0
iterations += 1
mu_tilde = v_tilde/tau_tilde
mu_cav = v_cav/tau_cav
sigma2_sigma2tilde = 1./tau_cav + 1./tau_tilde
Z_tilde = np.exp(np.log(Z_hat) + 0.5*np.log(2*np.pi) + 0.5*np.log(sigma2_sigma2tilde)
log_Z_tilde = (np.log(Z_hat) + 0.5*np.log(2*np.pi) + 0.5*np.log(sigma2_sigma2tilde)
+ 0.5*((mu_cav - mu_tilde)**2) / (sigma2_sigma2tilde))
return mu, Sigma, ObsAr(mu_tilde[:,None]), tau_tilde, Z_tilde
self.old_mutilde = mu_tilde
self.old_vtilde = v_tilde
return mu, Sigma_diag, ObsAr(mu_tilde[:,None]), tau_tilde, log_Z_tilde
def _ep_compute_posterior(self, LLT0, Kmn, tau_tilde, v_tilde):
    """Recompute the sparse (DTC) EP posterior from the site parameters.

    With S = diag(tau_tilde), forms LLT = Kmm + Kmn S Knm; the implicit
    posterior covariance is Sigma = Knm LLT^(-1) Kmn, of which only the
    diagonal is returned.

    :param LLT0: prior inducing-point covariance Kmm (m x m)
    :param Kmn: cross-covariance inducing points vs. data (m x n)
    :param tau_tilde: site precisions (n,)
    :param v_tilde: site precision-scaled means (n,)
    :returns: tuple (mu, Sigma_diag, LLT)
    """
    LLT = LLT0 + np.dot(Kmn * tau_tilde[None, :], Kmn.T)
    L = jitchol(LLT)
    V, _ = dtrtrs(L, Kmn, lower=1)  # V = L^(-1) Kmn, so V^T V = Knm LLT^(-1) Kmn
    # Work with V directly instead of materialising the full n x n
    # Sigma = V^T V: keeps the cost O(n m) as intended for the sparse model.
    Sigma_diag = np.sum(V * V, axis=0)
    mu = np.dot(V.T, np.dot(V, v_tilde))  # Sigma v_tilde without forming Sigma
    return (mu, Sigma_diag, LLT)