From e7a9a6a2fa87f956577a5e2b60341fefad9ea274 Mon Sep 17 00:00:00 2001 From: Nicolo Fusi Date: Thu, 29 Nov 2012 16:28:49 +0000 Subject: [PATCH] inference --- GPy/inference/Expectation_Propagation.py | 224 +++++++++++++++++++++++ GPy/inference/__init__.py | 0 GPy/inference/likelihoods.py | 100 ++++++++++ GPy/inference/optimization.py | 201 ++++++++++++++++++++ GPy/inference/samplers.py | 81 ++++++++ 5 files changed, 606 insertions(+) create mode 100644 GPy/inference/Expectation_Propagation.py create mode 100644 GPy/inference/__init__.py create mode 100644 GPy/inference/likelihoods.py create mode 100644 GPy/inference/optimization.py create mode 100644 GPy/inference/samplers.py diff --git a/GPy/inference/Expectation_Propagation.py b/GPy/inference/Expectation_Propagation.py new file mode 100644 index 00000000..bb1bb959 --- /dev/null +++ b/GPy/inference/Expectation_Propagation.py @@ -0,0 +1,224 @@ +import numpy as np +import random +from scipy import stats, linalg +from .likelihoods import likelihood +from ..core import model +from ..util.linalg import pdinv,mdot,jitchol +from ..util.plot import gpplot +from .. import kern + +class EP: + def __init__(self,covariance,likelihood,Kmn=None,Knn_diag=None,epsilon=1e-3,powerep=[1.,1.]): + """ + Expectation Propagation + + Arguments + --------- + X : input observations + likelihood : Output's likelihood (likelihood class) + kernel : a GPy kernel (kern class) + inducing : Either an array specifying the inducing points location or a sacalar defining their number. None value for using a non-sparse model is used. + powerep : Power-EP parameters (eta,delta) - 2x1 numpy array (floats) + epsilon : Convergence criterion, maximum squared difference allowed between mean updates to stop iterations (float) + """ + self.likelihood = likelihood + assert covariance.shape[0] == covariance.shape[1] + if Kmn is not None: + self.Kmm = covariance + self.Kmn = Kmn + self.M = self.Kmn.shape[0] + self.N = self.Kmn.shape[1] + assert self.M < self.N, 'The number of inducing inputs must be smaller than the number of observations' + else: + self.K = covariance + self.N = self.K.shape[0] + if Knn_diag is not None: + self.Knn_diag = Knn_diag + assert len(Knn_diag) == self.N, 'Knn_diagonal has size different from N' + + self.epsilon = epsilon + self.eta, self.delta = powerep + self.jitter = 1e-12 + + """ + Initial values - Likelihood approximation parameters: + p(y|f) = t(f|tau_tilde,v_tilde) + """ + self.tau_tilde = np.zeros(self.N) + self.v_tilde = np.zeros(self.N) + + def restart_EP(self): + """ + Set the EP approximation to initial state + """ + self.tau_tilde = np.zeros(self.N) + self.v_tilde = np.zeros(self.N) + self.mu = np.zeros(self.N) + +class Full(EP): + def fit_EP(self): + """ + The expectation-propagation algorithm. + For nomenclature see Rasmussen & Williams 2006 (pag. 52-60) + """ + #Prior distribution parameters: p(f|X) = N(f|0,K) + #self.K = self.kernel.K(self.X,self.X) + + #Initial values - Posterior distribution parameters: q(f|X,Y) = N(f|mu,Sigma) + self.mu=np.zeros(self.N) + self.Sigma=self.K.copy() + + """ + Initial values - Cavity distribution parameters: + q_(f|mu_,sigma2_) = Product{q_i(f|mu_i,sigma2_i)} + sigma_ = 1./tau_ + mu_ = v_/tau_ + """ + self.tau_ = np.empty(self.N,dtype=float) + self.v_ = np.empty(self.N,dtype=float) + + #Initial values - Marginal moments + z = np.empty(self.N,dtype=float) + self.Z_hat = np.empty(self.N,dtype=float) + phi = np.empty(self.N,dtype=float) + mu_hat = np.empty(self.N,dtype=float) + sigma2_hat = np.empty(self.N,dtype=float) + + #Approximation + epsilon_np1 = self.epsilon + 1. + epsilon_np2 = self.epsilon + 1. + self.iterations = 0 + self.np1 = [self.tau_tilde.copy()] + self.np2 = [self.v_tilde.copy()] + while epsilon_np1 > self.epsilon or epsilon_np2 > self.epsilon: + update_order = np.arange(self.N) + random.shuffle(update_order) + for i in update_order: + #Cavity distribution parameters + self.tau_[i] = 1./self.Sigma[i,i] - self.eta*self.tau_tilde[i] + self.v_[i] = self.mu[i]/self.Sigma[i,i] - self.eta*self.v_tilde[i] + #Marginal moments + self.Z_hat[i], mu_hat[i], sigma2_hat[i] = self.likelihood.moments_match(i,self.tau_[i],self.v_[i]) + #Site parameters update + Delta_tau = self.delta/self.eta*(1./sigma2_hat[i] - 1./self.Sigma[i,i]) + Delta_v = self.delta/self.eta*(mu_hat[i]/sigma2_hat[i] - self.mu[i]/self.Sigma[i,i]) + self.tau_tilde[i] = self.tau_tilde[i] + Delta_tau + self.v_tilde[i] = self.v_tilde[i] + Delta_v + #Posterior distribution parameters update + si=self.Sigma[:,i].reshape(self.N,1) + self.Sigma = self.Sigma - Delta_tau/(1.+ Delta_tau*self.Sigma[i,i])*np.dot(si,si.T) + self.mu = np.dot(self.Sigma,self.v_tilde) + self.iterations += 1 + #Sigma recomptutation with Cholesky decompositon + Sroot_tilde_K = np.sqrt(self.tau_tilde)[:,None]*(self.K) + B = np.eye(self.N) + np.sqrt(self.tau_tilde)[None,:]*Sroot_tilde_K + L = jitchol(B) + V,info = linalg.flapack.dtrtrs(L,Sroot_tilde_K,lower=1) + self.Sigma = self.K - np.dot(V.T,V) + self.mu = np.dot(self.Sigma,self.v_tilde) + epsilon_np1 = sum((self.tau_tilde-self.np1[-1])**2)/self.N + epsilon_np2 = sum((self.v_tilde-self.np2[-1])**2)/self.N + self.np1.append(self.tau_tilde.copy()) + self.np2.append(self.v_tilde.copy()) + + self.np2.append(self.v_tilde.copy()) + +class FITC(EP): + def fit_EP(self): + """ + The expectation-propagation algorithm with sparse pseudo-input. + For nomenclature see Naish-Guzman and Holden, 2008. + """ + + """ + Prior approximation parameters: + q(f|X) = int_{df}{N(f|KfuKuu_invu,diag(Kff-Qff)*N(u|0,Kuu)} = N(f|0,Sigma0) + Sigma0 = diag(Knn-Qnn) + Qnn, Qnn = Knm*Kmmi*Kmn + """ + self.Kmmi, self.Kmm_hld = pdinv(self.Kmm) + self.P0 = self.Kmn.T + self.KmnKnm = np.dot(self.P0.T, self.P0) + self.KmmiKmn = np.dot(self.Kmmi,self.P0.T) + self.Qnn_diag = np.sum(self.P0.T*self.KmmiKmn,-2) + self.Diag0 = self.Knn_diag - self.Qnn_diag + self.R0 = jitchol(self.Kmmi).T + + """ + Posterior approximation: q(f|y) = N(f| mu, Sigma) + Sigma = Diag + P*R.T*R*P.T + K + mu = w + P*gamma + """ + self.w = np.zeros(self.N) + self.gamma = np.zeros(self.M) + self.mu = np.zeros(self.N) + self.P = self.P0.copy() + self.R = self.R0.copy() + self.Diag = self.Diag0.copy() + self.Sigma_diag = self.Knn_diag + + """ + Initial values - Cavity distribution parameters: + q_(g|mu_,sigma2_) = Product{q_i(g|mu_i,sigma2_i)} + sigma_ = 1./tau_ + mu_ = v_/tau_ + """ + self.tau_ = np.empty(self.N,dtype=float) + self.v_ = np.empty(self.N,dtype=float) + + #Initial values - Marginal moments + z = np.empty(self.N,dtype=float) + self.Z_hat = np.empty(self.N,dtype=float) + phi = np.empty(self.N,dtype=float) + mu_hat = np.empty(self.N,dtype=float) + sigma2_hat = np.empty(self.N,dtype=float) + + #Approximation + epsilon_np1 = 1 + epsilon_np2 = 1 + self.iterations = 0 + self.np1 = [self.tau_tilde.copy()] + self.np2 = [self.v_tilde.copy()] + while epsilon_np1 > self.epsilon or epsilon_np2 > self.epsilon: + update_order = np.arange(self.N) + random.shuffle(update_order) + for i in update_order: + #Cavity distribution parameters + self.tau_[i] = 1./self.Sigma_diag[i] - self.eta*self.tau_tilde[i] + self.v_[i] = self.mu[i]/self.Sigma_diag[i] - self.eta*self.v_tilde[i] + #Marginal moments + self.Z_hat[i], mu_hat[i], sigma2_hat[i] = self.likelihood.moments_match(i,self.tau_[i],self.v_[i]) + #Site parameters update + Delta_tau = self.delta/self.eta*(1./sigma2_hat[i] - 1./self.Sigma_diag[i]) + Delta_v = self.delta/self.eta*(mu_hat[i]/sigma2_hat[i] - self.mu[i]/self.Sigma_diag[i]) + self.tau_tilde[i] = self.tau_tilde[i] + Delta_tau + self.v_tilde[i] = self.v_tilde[i] + Delta_v + #Posterior distribution parameters update + dtd1 = Delta_tau*self.Diag[i] + 1. + dii = self.Diag[i] + self.Diag[i] = dii - (Delta_tau * dii**2.)/dtd1 + pi_ = self.P[i,:].reshape(1,self.M) + self.P[i,:] = pi_ - (Delta_tau*dii)/dtd1 * pi_ + Rp_i = np.dot(self.R,pi_.T) + RTR = np.dot(self.R.T,np.dot(np.eye(self.M) - Delta_tau/(1.+Delta_tau*self.Sigma_diag[i]) * np.dot(Rp_i,Rp_i.T),self.R)) + self.R = jitchol(RTR).T + self.w[i] = self.w[i] + (Delta_v - Delta_tau*self.w[i])*dii/dtd1 + self.gamma = self.gamma + (Delta_v - Delta_tau*self.mu[i])*np.dot(RTR,self.P[i,:].T) + self.RPT = np.dot(self.R,self.P.T) + self.Sigma_diag = self.Diag + np.sum(self.RPT.T*self.RPT.T,-1) + self.mu = self.w + np.dot(self.P,self.gamma) + self.iterations += 1 + #Sigma recomptutation with Cholesky decompositon + self.Diag = self.Diag0/(1.+ self.Diag0 * self.tau_tilde) + self.P = (self.Diag / self.Diag0)[:,None] * self.P0 + self.RPT0 = np.dot(self.R0,self.P0.T) + L = jitchol(np.eye(self.M) + np.dot(self.RPT0,(1./self.Diag0 - self.Diag/(self.Diag0**2))[:,None]*self.RPT0.T)) + self.R,info = linalg.flapack.dtrtrs(L,self.R0,lower=1) + self.RPT = np.dot(self.R,self.P.T) + self.Sigma_diag = self.Diag + np.sum(self.RPT.T*self.RPT.T,-1) + self.w = self.Diag * self.v_tilde + self.gamma = np.dot(self.R.T, np.dot(self.RPT,self.v_tilde)) + self.mu = self.w + np.dot(self.P,self.gamma) + epsilon_np1 = sum((self.tau_tilde-self.np1[-1])**2)/self.N + epsilon_np2 = sum((self.v_tilde-self.np2[-1])**2)/self.N + self.np1.append(self.tau_tilde.copy()) + self.np2.append(self.v_tilde.copy()) diff --git a/GPy/inference/__init__.py b/GPy/inference/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/GPy/inference/likelihoods.py b/GPy/inference/likelihoods.py new file mode 100644 index 00000000..f4e47c10 --- /dev/null +++ b/GPy/inference/likelihoods.py @@ -0,0 +1,100 @@ +import numpy as np +from scipy import stats +import scipy as sp +import pylab as pb +from ..util.plot import gpplot + +class likelihood: + def __init__(self,Y): + """ + Likelihood class for doing Expectation propagation + + :param Y: observed output (Nx1 numpy.darray) + ..Note:: Y values allowed depend on the likelihood used + """ + self.Y = Y + self.N = self.Y.shape[0] + + def plot1Da(self,X_new,Mean_new,Var_new,X_u,Mean_u,Var_u): + """ + Plot the predictive distribution of the GP model for 1-dimensional inputs + + :param X_new: The points at which to make a prediction + :param Mean_new: mean values at X_new + :param Var_new: variance values at X_new + :param X_u: input (inducing) points used to train the model + :param Mean_u: mean values at X_u + :param Var_new: variance values at X_u + """ + assert X_new.shape[1] == 1, 'Number of dimensions must be 1' + gpplot(X_new,Mean_new,Var_new) + pb.errorbar(X_u,Mean_u,2*np.sqrt(Var_u),fmt='r+') + pb.plot(X_u,Mean_u,'ro') + + def plot2D(self,X,X_new,F_new,U=None): + """ + Predictive distribution of the fitted GP model for 2-dimensional inputs + + :param X_new: The points at which to make a prediction + :param Mean_new: mean values at X_new + :param Var_new: variance values at X_new + :param X_u: input points used to train the model + :param Mean_u: mean values at X_u + :param Var_new: variance values at X_u + """ + N,D = X_new.shape + assert D == 2, 'Number of dimensions must be 2' + n = np.sqrt(N) + x1min = X_new[:,0].min() + x1max = X_new[:,0].max() + x2min = X_new[:,1].min() + x2max = X_new[:,1].max() + pb.imshow(F_new.reshape(n,n),extent=(x1min,x1max,x2max,x2min),vmin=0,vmax=1) + pb.colorbar() + C1 = np.arange(self.N)[self.Y.flatten()==1] + C2 = np.arange(self.N)[self.Y.flatten()==-1] + [pb.plot(X[i,0],X[i,1],'ro') for i in C1] + [pb.plot(X[i,0],X[i,1],'bo') for i in C2] + pb.xlim(x1min,x1max) + pb.ylim(x2min,x2max) + if U is not None: + [pb.plot(a,b,'wo') for a,b in U] + +class probit(likelihood): + """ + Probit likelihood + Y is expected to take values in {-1,1} + ----- + $$ + L(x) = \\Phi (Y_i*f_i) + $$ + """ + def moments_match(self,i,tau_i,v_i): + """ + Moments match of the marginal approximation in EP algorithm + + :param i: number of observation (int) + :param tau_i: precision of the cavity distribution (float) + :param v_i: mean/variance of the cavity distribution (float) + """ + z = self.Y[i]*v_i/np.sqrt(tau_i**2 + tau_i) + Z_hat = stats.norm.cdf(z) + phi = stats.norm.pdf(z) + mu_hat = v_i/tau_i + self.Y[i]*phi/(Z_hat*np.sqrt(tau_i**2 + tau_i)) + sigma2_hat = 1./tau_i - (phi/((tau_i**2+tau_i)*Z_hat))*(z+phi/Z_hat) + return Z_hat, mu_hat, sigma2_hat + + def plot1Db(self,X,X_new,F_new,U=None): + assert X.shape[1] == 1, 'Number of dimensions must be 1' + gpplot(X_new,F_new,np.zeros(X_new.shape[0])) + pb.plot(X,(self.Y+1)/2,'kx',mew=1.5) + pb.ylim(-0.2,1.2) + if U is not None: + pb.plot(U,U*0+.5,'r|',mew=1.5,markersize=12) + + def predictive_mean(self,mu,variance): + return stats.norm.cdf(mu/np.sqrt(1+variance)) + + def log_likelihood_gradients(): + raise NotImplementedError + diff --git a/GPy/inference/optimization.py b/GPy/inference/optimization.py new file mode 100644 index 00000000..6a28db44 --- /dev/null +++ b/GPy/inference/optimization.py @@ -0,0 +1,201 @@ +from scipy import optimize +# import rasmussens_minimize as rasm +import pdb +import pylab as pb +import datetime as dt + +class Optimizer(): + def __init__(self, x_init, f_fp, f, fp , messages = False, max_f_eval = 1e4, ftol = None, gtol = None, xtol = None): + """ + Superclass for all the optimizers. + + Arguments: + + x_init: initial set of parameters + f_fp: function that returns the function AND the gradients at the same time + f: function to optimize + fp: gradients + messages: print messages from the optimizer? (True | False) + max_f_eval: maximum number of function evaluations + + """ + self.opt_name = None + self.f_fp = f_fp + self.f = f + self.fp = fp + self.x_init = x_init + self.messages = messages + self.f_opt = None + self.x_opt = None + self.funct_eval = None + self.status = None + self.max_f_eval = int(max_f_eval) + self.trace = None + self.time = "Not available" + self.xtol = xtol + self.gtol = gtol + self.ftol = ftol + + def run(self): + start = dt.datetime.now() + self.opt() + end = dt.datetime.now() + self.time = str(end-start) + + def opt(self): + raise NotImplementedError, "this needs to be implemented to utilise the optimizer class" + + def plot(self): + if self.trace == None: + print "No trace present so I can't plot it. Please check that the optimizer actually supplies a trace." + else: + pb.figure() + pb.plot(self.trace) + pb.xlabel('Iteration') + pb.ylabel('f(x)') + + def diagnostics(self): + print "Optimizer: \t\t\t\t %s" % self.opt_name + print "f(x_opt): \t\t\t\t %.3f" % self.f_opt + print "Number of function evaluations: \t %d" % self.funct_eval + print "Optimization status: \t\t\t %s" % self.status + print "Time elapsed: \t\t\t\t %s" % self.time + +class opt_tnc(Optimizer): + def __init__(self, *args, **kwargs): + Optimizer.__init__(self, *args, **kwargs) + self.opt_name = "TNC (Scipy implementation)" + + def opt(self): + """ + Run the TNC optimizer + + """ + tnc_rcstrings = ['Local minimum', 'Converged', 'XConverged', 'Maximum number of f evaluations reached', + 'Line search failed', 'Function is constant'] + + assert self.f_fp != None, "TNC requires f_fp" + + opt_dict = {} + if self.xtol is not None: + opt_dict['xtol'] = self.xtol + if self.ftol is not None: + opt_dict['ftol'] = self.ftol + if self.gtol is not None: + opt_dict['pgtol'] = self.gtol + + opt_result = optimize.fmin_tnc(self.f_fp, self.x_init, messages = self.messages, + maxfun = self.max_f_eval, **opt_dict) + self.x_opt = opt_result[0] + self.f_opt = self.f_fp(self.x_opt)[0] + self.funct_eval = opt_result[1] + self.status = tnc_rcstrings[opt_result[2]] + +class opt_lbfgsb(Optimizer): + def __init__(self, *args, **kwargs): + Optimizer.__init__(self, *args, **kwargs) + self.opt_name = "L-BFGS-B (Scipy implementation)" + + def opt(self): + """ + Run the optimizer + + """ + rcstrings = ['Converged', 'Maximum number of f evaluations reached', 'Error'] + + assert self.f_fp != None, "BFGS requires f_fp" + + if self.messages: + iprint = 1 + else: + iprint = -1 + + opt_dict = {} + if self.xtol is not None: + print "WARNING: l-bfgs-b doesn't have an xtol arg, so I'm going to ignore it" + if self.ftol is not None: + print "WARNING: l-bfgs-b doesn't have an ftol arg, so I'm going to ignore it" + if self.gtol is not None: + opt_dict['pgtol'] = self.gtol + + opt_result = optimize.fmin_l_bfgs_b(self.f_fp, self.x_init, iprint = iprint, + maxfun = self.max_f_eval, **opt_dict) + self.x_opt = opt_result[0] + self.f_opt = self.f_fp(self.x_opt)[0] + self.funct_eval = opt_result[2]['funcalls'] + self.status = rcstrings[opt_result[2]['warnflag']] + +class opt_simplex(Optimizer): + def __init__(self, *args, **kwargs): + Optimizer.__init__(self, *args, **kwargs) + self.opt_name = "Nelder-Mead simplex routine (via Scipy)" + + def opt(self): + """ + The simplex optimizer does not require gradients, which + is great during development. Otherwise it's a bit slow. + """ + + statuses = ['Converged', 'Maximum number of function evaluations made','Maximum number of iterations reached'] + + opt_dict = {} + if self.xtol is not None: + opt_dict['xtol'] = self.xtol + if self.ftol is not None: + opt_dict['ftol'] = self.ftol + if self.gtol is not None: + print "WARNING: simplex doesn't have an gtol arg, so I'm going to ignore it" + + opt_result = optimize.fmin(self.f, self.x_init, (), disp = self.messages, + maxfun = self.max_f_eval, full_output=True, **opt_dict) + + self.x_opt = opt_result[0] + self.f_opt = opt_result[1] + self.funct_eval = opt_result[3] + self.status = statuses[opt_result[4]] + + self.trace = None + + +# class opt_rasm(Optimizer): +# def __init__(self, *args, **kwargs): +# Optimizer.__init__(self, *args, **kwargs) +# self.opt_name = "Rasmussen's SCG" + +# def opt(self): +# """ +# Run Rasmussen's SCG optimizer +# """ + +# assert self.f_fp != None, "Rasmussen's minimizer requires f_fp" +# statuses = ['Converged', 'Line search failed', 'Maximum number of f evaluations reached', +# 'NaNs in optimization'] + +# opt_dict = {} +# if self.xtol is not None: +# print "WARNING: minimize doesn't have an xtol arg, so I'm going to ignore it" +# if self.ftol is not None: +# print "WARNING: minimize doesn't have an ftol arg, so I'm going to ignore it" +# if self.gtol is not None: +# print "WARNING: minimize doesn't have an gtol arg, so I'm going to ignore it" + +# opt_result = rasm.minimize(self.x_init, self.f_fp, (), messages = self.messages, +# maxnumfuneval = self.max_f_eval) +# self.x_opt = opt_result[0] +# self.f_opt = opt_result[1][-1] +# self.funct_eval = opt_result[2] +# self.status = statuses[opt_result[3]] + +# self.trace = opt_result[1] + +def get_optimizer(f_min): + optimizers = {'fmin_tnc': opt_tnc, + # 'rasmussen': opt_rasm, + 'simplex': opt_simplex, + 'lbfgsb': opt_lbfgsb} + + for opt_name in optimizers.keys(): + if opt_name.lower().find(f_min.lower()) != -1: + return optimizers[opt_name] + + raise KeyError('No optimizer was found matching the name: %s' % f_min) diff --git a/GPy/inference/samplers.py b/GPy/inference/samplers.py new file mode 100644 index 00000000..866baa63 --- /dev/null +++ b/GPy/inference/samplers.py @@ -0,0 +1,81 @@ +import numpy as np +from scipy import linalg, optimize +import pylab as pb +import Tango +import sys +import re +import numdifftools as ndt +import pdb +import cPickle + + +class Metropolis_Hastings: + def __init__(self,model,cov=None): + """Metropolis Hastings, with tunings according to Gelman et al. """ + self.model = model + current = self.model.extract_param() + self.D = current.size + self.chains = [] + if cov is None: + self.cov = model.Laplace_covariance() + else: + self.cov = cov + self.scale = 2.4/np.sqrt(self.D) + self.new_chain(current) + + def new_chain(self, start=None): + self.chains.append([]) + if start is None: + self.model.randomize() + else: + self.model.expand_param(start) + + + + def sample(self, Ntotal, Nburn, Nthin, tune=True, tune_throughout=False, tune_interval=400): + current = self.model.extract_param() + fcurrent = self.model.log_likelihood() + self.model.log_prior() + accepted = np.zeros(Ntotal,dtype=np.bool) + for it in range(Ntotal): + print "sample %d of %d\r"%(it,Ntotal), + sys.stdout.flush() + prop = np.random.multivariate_normal(current, self.cov*self.scale*self.scale) + self.model.expand_param(prop) + fprop = self.model.log_likelihood() + self.model.log_prior() + + if fprop>fcurrent:#sample accepted, going 'uphill' + accepted[it] = True + current = prop + fcurrent = fprop + else: + u = np.random.rand() + if np.exp(fprop-fcurrent)>u:#sample accepted downhill + accepted[it] = True + current = prop + fcurrent = fprop + + #store current value + if (it > Nburn) & ((it%Nthin)==0): + self.chains[-1].append(current) + + #tuning! + if it & ((it%tune_interval)==0) & tune & ((it .25: + self.scale *= 1.1 + if pc < .15: + self.scale /= 1.1 + + def predict(self,function,args): + """Make a prediction for the function, to which we will pass the additional arguments""" + param = self.model.get_param() + fs = [] + for p in self.chain: + self.model.set_param(p) + fs.append(function(*args)) + self.model.set_param(param)# reset model to starting state + return fs + + +