From e2770b01bb98f42346fbaa9b9414e30998a500b4 Mon Sep 17 00:00:00 2001
From: Nicolo Fusi
Date: Mon, 10 Dec 2012 17:20:45 +0000
Subject: [PATCH] models are now pickleable

---
 GPy/core/model.py             |  4 ++--
 GPy/inference/optimization.py | 40 +++++++++++++++++------------------
 2 files changed, 21 insertions(+), 23 deletions(-)

diff --git a/GPy/core/model.py b/GPy/core/model.py
index 2b150d37..0db18061 100644
--- a/GPy/core/model.py
+++ b/GPy/core/model.py
@@ -189,8 +189,8 @@ class model(parameterised):
         start = self.extract_param()
 
         optimizer = optimization.get_optimizer(optimizer)
-        opt = optimizer(start, f_fp=f_fp,f=f,fp=fp,model = self,**kwargs)
-        opt.run()
+        opt = optimizer(start, model = self, **kwargs)
+        opt.run(f_fp=f_fp, f=f, fp=fp)
 
         self.optimization_runs.append(opt)
         self.expand_param(opt.x_opt)
diff --git a/GPy/inference/optimization.py b/GPy/inference/optimization.py
index c96d2a3b..4cf56b69 100644
--- a/GPy/inference/optimization.py
+++ b/GPy/inference/optimization.py
@@ -19,15 +19,12 @@ class Optimizer():
     :param messages: print messages from the optimizer?
     :type messages: (True | False)
     :param max_f_eval: maximum number of function evaluations
-
+
     :rtype: optimizer object.
-
-    """
-    def __init__(self, x_init, f_fp, f, fp , messages=False, max_f_eval=1e4, ftol=None, gtol=None, xtol=None):
+
+    """
+    def __init__(self, x_init, messages=False, model = None, max_f_eval=1e4, ftol=None, gtol=None, xtol=None):
         self.opt_name = None
-        self.f_fp = f_fp
-        self.f = f
-        self.fp = fp
         self.x_init = x_init
         self.messages = messages
         self.f_opt = None
@@ -40,14 +37,15 @@ class Optimizer():
         self.xtol = xtol
         self.gtol = gtol
         self.ftol = ftol
-
-    def run(self):
+        self.model = model
+
+    def run(self, **kwargs):
         start = dt.datetime.now()
-        self.opt()
+        self.opt(**kwargs)
         end = dt.datetime.now()
         self.time = str(end-start)
 
-    def opt(self):
+    def opt(self, f_fp = None, f = None, fp = None):
         raise NotImplementedError, "this needs to be implemented to use the optimizer class"
 
     def plot(self):
@@ -71,7 +69,7 @@ class opt_tnc(Optimizer):
         Optimizer.__init__(self, *args, **kwargs)
         self.opt_name = "TNC (Scipy implementation)"
 
-    def opt(self):
+    def opt(self, f_fp = None, f = None, fp = None):
         """
         Run the TNC optimizer
 
@@ -79,7 +77,7 @@
         """
         tnc_rcstrings = ['Local minimum', 'Converged', 'XConverged', 'Maximum number of f evaluations reached',
                          'Line search failed', 'Function is constant']
 
-        assert self.f_fp != None, "TNC requires f_fp"
+        assert f_fp != None, "TNC requires f_fp"
 
         opt_dict = {}
         if self.xtol is not None:
             opt_dict['xtol'] = self.xtol
         if self.ftol is not None:
             opt_dict['ftol'] = self.ftol
         if self.gtol is not None:
             opt_dict['pgtol'] = self.gtol
 
-        opt_result = optimize.fmin_tnc(self.f_fp, self.x_init, messages = self.messages,
+        opt_result = optimize.fmin_tnc(f_fp, self.x_init, messages = self.messages,
                                        maxfun = self.max_f_eval, **opt_dict)
         self.x_opt = opt_result[0]
-        self.f_opt = self.f_fp(self.x_opt)[0]
+        self.f_opt = f_fp(self.x_opt)[0]
         self.funct_eval = opt_result[1]
         self.status = tnc_rcstrings[opt_result[2]]
@@ -101,14 +99,14 @@ class opt_lbfgsb(Optimizer):
         Optimizer.__init__(self, *args, **kwargs)
         self.opt_name = "L-BFGS-B (Scipy implementation)"
 
-    def opt(self):
+    def opt(self, f_fp = None, f = None, fp = None):
         """
         Run the optimizer
 
         """
         rcstrings = ['Converged', 'Maximum number of f evaluations reached', 'Error']
 
-        assert self.f_fp != None, "BFGS requires f_fp"
+        assert f_fp != None, "BFGS requires f_fp"
 
         if self.messages:
             iprint = 1
         else:
             iprint = -1
 
         opt_dict = {}
         if self.xtol is not None:
             opt_dict['xtol'] = self.xtol
         if self.ftol is not None:
             opt_dict['ftol'] = self.ftol
         if self.gtol is not None:
             opt_dict['pgtol'] = self.gtol
 
-        opt_result = optimize.fmin_l_bfgs_b(self.f_fp, self.x_init, iprint = iprint,
+        opt_result = optimize.fmin_l_bfgs_b(f_fp, self.x_init, iprint = iprint,
                                             maxfun = self.max_f_eval, **opt_dict)
         self.x_opt = opt_result[0]
-        self.f_opt = self.f_fp(self.x_opt)[0]
+        self.f_opt = f_fp(self.x_opt)[0]
         self.funct_eval = opt_result[2]['funcalls']
         self.status = rcstrings[opt_result[2]['warnflag']]
@@ -135,7 +133,7 @@ class opt_simplex(Optimizer):
         Optimizer.__init__(self, *args, **kwargs)
         self.opt_name = "Nelder-Mead simplex routine (via Scipy)"
 
-    def opt(self):
+    def opt(self, f_fp = None, f = None, fp = None):
         """
         The simplex optimizer does not require gradients.
         """
@@ -150,7 +148,7 @@
         if self.gtol is not None:
             print "WARNING: simplex doesn't have an gtol arg, so I'm going to ignore it"
 
-        opt_result = optimize.fmin(self.f, self.x_init, (), disp = self.messages,
+        opt_result = optimize.fmin(f, self.x_init, (), disp = self.messages,
                                    maxfun = self.max_f_eval, full_output=True, **opt_dict)
         self.x_opt = opt_result[0]
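
Why this makes models pickleable: the objective callables f, fp and f_fp are built
around the model instance -- closures over it or bound methods of it -- and Python's
pickle cannot serialize closures (nor, under the Python 2 that GPy targeted in 2012,
bound methods). Because every finished run is appended to model.optimization_runs,
an optimizer that keeps those callables as attributes drags them into the model's
pickle and breaks it. Passing them to run()/opt() instead leaves only plain data
(x_init, x_opt, f_opt, status, ...) on the stored object. Below is a minimal
illustrative sketch of the failure and the fix; StoredCallableOptimizer,
ArgumentCallableOptimizer and make_objective are hypothetical names, not GPy API:

    import pickle

    class StoredCallableOptimizer(object):
        # Pre-patch pattern: the objective lives on the instance.
        def __init__(self, x_init, f):
            self.x_init = x_init
            self.f = f  # a closure: pickle cannot serialize it

    class ArgumentCallableOptimizer(object):
        # Post-patch pattern: the objective is passed at call time.
        def __init__(self, x_init):
            self.x_init = x_init
            self.f_opt = None

        def run(self, f=None):
            self.f_opt = f(self.x_init)  # only the float result is kept

    def make_objective(offset):
        # Locally defined function: unpicklable in Python 2 and 3 alike.
        return lambda x: (x - offset) ** 2

    old = StoredCallableOptimizer(0.0, make_objective(1.0))
    try:
        pickle.dumps(old)
    except (pickle.PicklingError, AttributeError, TypeError):
        print("storing the callable breaks pickling")

    new = ArgumentCallableOptimizer(0.0)
    new.run(f=make_objective(1.0))
    print(len(pickle.dumps(new)) > 0)  # True: only plain data is stored

The model reference added here (self.model = model) is consistent with this:
once the model itself no longer carries raw callables, a reference to it can be
pickled like any other attribute.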