mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 12:32:40 +02:00
very weird merge conflict, including in files that I did not change
This commit is contained in:
commit
601175de2d
73 changed files with 2234 additions and 1567 deletions
|
|
@ -5,6 +5,7 @@ import numpy as np
|
|||
from ..util.univariate_Gaussian import std_norm_pdf, std_norm_cdf
|
||||
import link_functions
|
||||
from likelihood import Likelihood
|
||||
from scipy import stats
|
||||
|
||||
class Bernoulli(Likelihood):
|
||||
"""
|
||||
|
|
@ -43,7 +44,7 @@ class Bernoulli(Likelihood):
|
|||
Y_prep[Y.flatten() == 0] = -1
|
||||
return Y_prep
|
||||
|
||||
def moments_match_ep(self, data_i, tau_i, v_i):
|
||||
def moments_match_ep(self, Y_i, tau_i, v_i):
|
||||
"""
|
||||
Moments match of the marginal approximation in EP algorithm
|
||||
|
||||
|
|
@ -51,9 +52,9 @@ class Bernoulli(Likelihood):
|
|||
:param tau_i: precision of the cavity distribution (float)
|
||||
:param v_i: mean/variance of the cavity distribution (float)
|
||||
"""
|
||||
if data_i == 1:
|
||||
if Y_i == 1:
|
||||
sign = 1.
|
||||
elif data_i == 0:
|
||||
elif Y_i == 0:
|
||||
sign = -1
|
||||
else:
|
||||
raise ValueError("bad value for Bernouilli observation (0, 1)")
|
||||
|
|
@ -76,7 +77,7 @@ class Bernoulli(Likelihood):
|
|||
|
||||
return Z_hat, mu_hat, sigma2_hat
|
||||
|
||||
def predictive_mean(self, mu, variance):
|
||||
def predictive_mean(self, mu, variance, Y_metadata=None):
|
||||
|
||||
if isinstance(self.gp_link, link_functions.Probit):
|
||||
return stats.norm.cdf(mu/np.sqrt(1+variance))
|
||||
|
|
@ -87,13 +88,12 @@ class Bernoulli(Likelihood):
|
|||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def predictive_variance(self, mu, variance, pred_mean):
|
||||
def predictive_variance(self, mu, variance, pred_mean, Y_metadata=None):
|
||||
|
||||
if isinstance(self.gp_link, link_functions.Heaviside):
|
||||
return 0.
|
||||
else:
|
||||
return np.nan
|
||||
#raise NotImplementedError
|
||||
|
||||
def pdf_link(self, link_f, y, Y_metadata=None):
|
||||
"""
|
||||
|
|
@ -212,7 +212,7 @@ class Bernoulli(Likelihood):
|
|||
np.seterr(**state)
|
||||
return d3logpdf_dlink3
|
||||
|
||||
def samples(self, gp):
|
||||
def samples(self, gp, Y_metadata=None):
|
||||
"""
|
||||
Returns a set of samples of observations based on a given value of the latent variable.
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue