mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-06 02:24:17 +02:00
New functions for EP-matching moments
This commit is contained in:
parent
c521c243e1
commit
68f493b86c
2 changed files with 39 additions and 3 deletions
|
|
@ -7,6 +7,7 @@ from scipy import stats
|
|||
import scipy as sp
|
||||
import pylab as pb
|
||||
from ..util.plot import gpplot
|
||||
from ..util.univariate_Gaussian import std_norm_pdf,std_norm_cdf
|
||||
|
||||
class likelihood_function:
|
||||
"""
|
||||
|
|
@ -37,11 +38,11 @@ class probit(likelihood_function):
|
|||
:param tau_i: precision of the cavity distribution (float)
|
||||
:param v_i: mean/variance of the cavity distribution (float)
|
||||
"""
|
||||
if data_i == 0: data_i = -1 #NOTE Binary classification algorithm works better with classes {-1,1}, 1D-plotting works better with classes {0,1}.
|
||||
#if data_i == 0: data_i = -1 #NOTE Binary classification algorithm works better with classes {-1,1}, 1D-plotting works better with classes {0,1}.
|
||||
# TODO: some version of assert
|
||||
z = data_i*v_i/np.sqrt(tau_i**2 + tau_i)
|
||||
Z_hat = stats.norm.cdf(z)
|
||||
phi = stats.norm.pdf(z)
|
||||
Z_hat = std_norm_cdf(z)
|
||||
phi = std_norm_pdf(z)
|
||||
mu_hat = v_i/tau_i + data_i*phi/(Z_hat*np.sqrt(tau_i**2 + tau_i))
|
||||
sigma2_hat = 1./tau_i - (phi/((tau_i**2+tau_i)*Z_hat))*(z+phi/Z_hat)
|
||||
return Z_hat, mu_hat, sigma2_hat
|
||||
|
|
|
|||
35
GPy/util/univariate_Gaussian.py
Normal file
35
GPy/util/univariate_Gaussian.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) 2012, 2013 Ricardo Andrade
|
||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||
|
||||
import numpy as np
|
||||
from scipy import weave
|
||||
|
||||
def std_norm_pdf(x):
|
||||
"""Standard Gaussian density function"""
|
||||
return 1./np.sqrt(2.*np.pi)*np.exp(-.5*x**2)
|
||||
|
||||
def std_norm_cdf(x):
|
||||
"""
|
||||
Cumulative standard Gaussian distribution
|
||||
Based on Abramowitz, M. and Stegun, I. (1970)
|
||||
"""
|
||||
support_code = "#include <math.h>"
|
||||
code = """
|
||||
|
||||
double sign = 1.0;
|
||||
if (x < 0.0){
|
||||
sign = -1.0;
|
||||
x = -x;
|
||||
}
|
||||
x = x/sqrt(2.0);
|
||||
|
||||
double t = 1.0/(1.0 + 0.3275911*x);
|
||||
|
||||
double erf = 1. - exp(-x*x)*t*(0.254829592 + t*(-0.284496736 + t*(1.421413741 + t*(-1.453152027 + t*(1.061405429)))));
|
||||
|
||||
return_val = 0.5*(1.0 + sign*erf);
|
||||
"""
|
||||
x = float(x)
|
||||
return weave.inline(code,arg_names=['x'],support_code=support_code)
|
||||
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue