more massive and destructive changes everywhere

This commit is contained in:
James Hensman 2013-12-05 15:09:31 -05:00
parent 881800126f
commit 5ec64d2279
18 changed files with 202 additions and 166 deletions

View file

@ -2,10 +2,10 @@
# Licensed under the BSD 3-clause license (see LICENSE.txt)
import numpy as np
from scipy import stats,special
from scipy import stats, special
import scipy as sp
from GPy.util.univariate_Gaussian import std_norm_pdf,std_norm_cdf
import gp_transformations
from GPy.util.univariate_Gaussian import std_norm_pdf, std_norm_cdf
import link_functions
from likelihood import Likelihood
class Bernoulli(Likelihood):
@ -16,29 +16,29 @@ class Bernoulli(Likelihood):
p(y_{i}|\\lambda(f_{i})) = \\lambda(f_{i})^{y_{i}}(1-f_{i})^{1-y_{i}}
.. Note::
Y is expected to take values in {-1,1}
Y is expected to take values in {-1, 1}
Probit likelihood usually used
"""
def __init__(self,gp_link=None,analytical_mean=False,analytical_variance=False):
super(Bernoulli, self).__init__(gp_link,analytical_mean,analytical_variance)
if isinstance(gp_link , (gp_transformations.Heaviside, gp_transformations.Probit)):
def __init__(self, gp_link=None, analytical_mean=False, analytical_variance=False):
super(Bernoulli, self).__init__(gp_link, analytical_mean, analytical_variance)
if isinstance(gp_link , (link_functions.Heaviside, link_functions.Probit)):
self.log_concave = True
def _preprocess_values(self,Y):
def _preprocess_values(self, Y):
"""
Check if the values of the observations correspond to the values
assumed by the likelihood function.
..Note:: Binary classification algorithm works better with classes {-1,1}
..Note:: Binary classification algorithm works better with classes {-1, 1}
"""
Y_prep = Y.copy()
Y1 = Y[Y.flatten()==1].size
Y2 = Y[Y.flatten()==0].size
assert Y1 + Y2 == Y.size, 'Bernoulli likelihood is meant to be used only with outputs in {0,1}.'
assert Y1 + Y2 == Y.size, 'Bernoulli likelihood is meant to be used only with outputs in {0, 1}.'
Y_prep[Y.flatten() == 0] = -1
return Y_prep
def _moments_match_analytical(self,data_i,tau_i,v_i):
def _moments_match_analytical(self, data_i, tau_i, v_i):
"""
Moments match of the marginal approximation in EP algorithm
@ -51,15 +51,15 @@ class Bernoulli(Likelihood):
elif data_i == 0:
sign = -1
else:
raise ValueError("bad value for Bernouilli observation (0,1)")
if isinstance(self.gp_link,gp_transformations.Probit):
raise ValueError("bad value for Bernouilli observation (0, 1)")
if isinstance(self.gp_link, link_functions.Probit):
z = sign*v_i/np.sqrt(tau_i**2 + tau_i)
Z_hat = std_norm_cdf(z)
phi = std_norm_pdf(z)
mu_hat = v_i/tau_i + sign*phi/(Z_hat*np.sqrt(tau_i**2 + tau_i))
sigma2_hat = 1./tau_i - (phi/((tau_i**2+tau_i)*Z_hat))*(z+phi/Z_hat)
elif isinstance(self.gp_link,gp_transformations.Heaviside):
elif isinstance(self.gp_link, link_functions.Heaviside):
a = sign*v_i/np.sqrt(tau_i)
Z_hat = std_norm_cdf(a)
N = std_norm_pdf(a)
@ -68,24 +68,24 @@ class Bernoulli(Likelihood):
if np.any(np.isnan([Z_hat, mu_hat, sigma2_hat])):
stop
else:
raise ValueError("Exact moment matching not available for link {}".format(self.gp_link.gp_transformations.__name__))
raise ValueError("Exact moment matching not available for link {}".format(self.gp_link.__name__))
return Z_hat, mu_hat, sigma2_hat
def _predictive_mean_analytical(self,mu,variance):
def _predictive_mean_analytical(self, mu, variance):
if isinstance(self.gp_link,gp_transformations.Probit):
if isinstance(self.gp_link, link_functions.Probit):
return stats.norm.cdf(mu/np.sqrt(1+variance))
elif isinstance(self.gp_link,gp_transformations.Heaviside):
elif isinstance(self.gp_link, link_functions.Heaviside):
return stats.norm.cdf(mu/np.sqrt(variance))
else:
raise NotImplementedError
def _predictive_variance_analytical(self,mu,variance, pred_mean):
def _predictive_variance_analytical(self, mu, variance, pred_mean):
if isinstance(self.gp_link,gp_transformations.Heaviside):
if isinstance(self.gp_link, link_functions.Heaviside):
return 0.
else:
raise NotImplementedError
@ -106,7 +106,7 @@ class Bernoulli(Likelihood):
:rtype: float
.. Note:
Each y_i must be in {0,1}
Each y_i must be in {0, 1}
"""
assert np.atleast_1d(link_f).shape == np.atleast_1d(y).shape
objective = (link_f**y) * ((1.-link_f)**(1.-y))
@ -195,13 +195,13 @@ class Bernoulli(Likelihood):
d3logpdf_dlink3 = 2*(y/(link_f**3) - (1-y)/((1-link_f)**3))
return d3logpdf_dlink3
def _mean(self,gp):
def _mean(self, gp):
"""
Mass (or density) function
"""
return self.gp_link.transf(gp)
def _variance(self,gp):
def _variance(self, gp):
"""
Mass (or density) function
"""