diff --git a/GPy/examples/poisson.py b/GPy/examples/poisson.py index 934637f1..ce68e921 100644 --- a/GPy/examples/poisson.py +++ b/GPy/examples/poisson.py @@ -3,46 +3,45 @@ """ -Simple Gaussian Processes classification +Gaussian Processes + Expectation Propagation - Poisson Likelihood """ import pylab as pb import numpy as np import GPy -pb.ion() -pb.close('all') default_seed=10000 -model_type='Full' -inducing=4 -seed=default_seed -"""Simple 1D classification example. -:param model_type: type of model to fit ['Full', 'FITC', 'DTC']. -:param seed : seed value for data generation (default is 4). -:type seed: int -:param inducing : number of inducing variables (only used for 'FITC' or 'DTC'). -:type inducing: int -""" +def toy_1d(seed=default_seed): + """ + Simple 1D classification example + :param seed : seed value for data generation (default is 4). + :type seed: int + """ -X = np.arange(0,100,5)[:,None] -F = np.round(np.sin(X/18.) + .1*X) + np.arange(5,25)[:,None] -E = np.random.randint(-5,5,20)[:,None] -Y = F + E -pb.figure() -likelihood = GPy.inference.likelihoods.poisson(Y,scale=1.) + X = np.arange(0,100,5)[:,None] + F = np.round(np.sin(X/18.) + .1*X) + np.arange(5,25)[:,None] + E = np.random.randint(-5,5,20)[:,None] + Y = F + E -m = GPy.models.GP(X,likelihood=likelihood) -#m = GPy.models.GP(X,Y=likelihood.Y) + kernel = GPy.kern.rbf(1) + distribution = GPy.likelihoods.likelihood_functions.Poisson() + likelihood = GPy.likelihoods.EP(Y,distribution) -m.constrain_positive('var') -m.constrain_positive('len') -m.tie_param('lengthscale') -if not isinstance(m.likelihood,GPy.inference.likelihoods.gaussian): - m.approximate_likelihood() -print m.checkgrad() -# Optimize and plot -m.optimize() -#m.em(plot_all=False) # EM algorithm -m.plot(samples=4) + m = GPy.models.GP(X,likelihood,kernel) + m.ensure_default_constraints() -print(m) + # Approximate likelihood + m.update_likelihood_approximation() + + # Optimize and plot + m.optimize() + #m.EPEM FIXME + print m + + # Plot + pb.subplot(211) + m.plot_f() #GP plot + pb.subplot(212) + m.plot() #Output plot + + return m