diff --git a/GPy/core/model.py b/GPy/core/model.py index 2b150d37..0db18061 100644 --- a/GPy/core/model.py +++ b/GPy/core/model.py @@ -189,8 +189,8 @@ class model(parameterised): start = self.extract_param() optimizer = optimization.get_optimizer(optimizer) - opt = optimizer(start, f_fp=f_fp,f=f,fp=fp,model = self,**kwargs) - opt.run() + opt = optimizer(start, model = self, **kwargs) + opt.run(f_fp=f_fp, f=f, fp=fp) self.optimization_runs.append(opt) self.expand_param(opt.x_opt) diff --git a/GPy/inference/optimization.py b/GPy/inference/optimization.py index 19d627bc..557f8823 100644 --- a/GPy/inference/optimization.py +++ b/GPy/inference/optimization.py @@ -1,7 +1,9 @@ # Copyright (c) 2012, GPy authors (see AUTHORS.txt). # Licensed under the BSD 3-clause license (see LICENSE.txt) + from scipy import optimize +# import rasmussens_minimize as rasm import pdb import pylab as pb import datetime as dt @@ -21,11 +23,8 @@ class Optimizer(): :rtype: optimizer object. """ - def __init__(self, x_init, f_fp, f, fp , messages=False, max_f_eval=1e4, ftol=None, gtol=None, xtol=None, **kwargs): + def __init__(self, x_init, messages=False, model = None, max_f_eval=1e4, ftol=None, gtol=None, xtol=None): self.opt_name = None - self.f_fp = f_fp - self.f = f - self.fp = fp self.x_init = x_init self.messages = messages self.f_opt = None @@ -38,14 +37,15 @@ class Optimizer(): self.xtol = xtol self.gtol = gtol self.ftol = ftol - - def run(self): + self.model = model + + def run(self, **kwargs): start = dt.datetime.now() - self.opt() + self.opt(**kwargs) end = dt.datetime.now() self.time = str(end-start) - def opt(self): + def opt(self, f_fp = None, f = None, fp = None): raise NotImplementedError, "this needs to be implemented to use the optimizer class" def plot(self): @@ -69,7 +69,7 @@ class opt_tnc(Optimizer): Optimizer.__init__(self, *args, **kwargs) self.opt_name = "TNC (Scipy implementation)" - def opt(self): + def opt(self, f_fp = None, f = None, fp = None): """ Run the TNC optimizer @@ -77,7 +77,7 @@ class opt_tnc(Optimizer): tnc_rcstrings = ['Local minimum', 'Converged', 'XConverged', 'Maximum number of f evaluations reached', 'Line search failed', 'Function is constant'] - assert self.f_fp != None, "TNC requires f_fp" + assert f_fp != None, "TNC requires f_fp" opt_dict = {} if self.xtol is not None: @@ -87,10 +87,10 @@ class opt_tnc(Optimizer): if self.gtol is not None: opt_dict['pgtol'] = self.gtol - opt_result = optimize.fmin_tnc(self.f_fp, self.x_init, messages = self.messages, + opt_result = optimize.fmin_tnc(f_fp, self.x_init, messages = self.messages, maxfun = self.max_f_eval, **opt_dict) self.x_opt = opt_result[0] - self.f_opt = self.f_fp(self.x_opt)[0] + self.f_opt = f_fp(self.x_opt)[0] self.funct_eval = opt_result[1] self.status = tnc_rcstrings[opt_result[2]] @@ -99,14 +99,14 @@ class opt_lbfgsb(Optimizer): Optimizer.__init__(self, *args, **kwargs) self.opt_name = "L-BFGS-B (Scipy implementation)" - def opt(self): + def opt(self, f_fp = None, f = None, fp = None): """ Run the optimizer """ rcstrings = ['Converged', 'Maximum number of f evaluations reached', 'Error'] - assert self.f_fp != None, "BFGS requires f_fp" + assert f_fp != None, "BFGS requires f_fp" if self.messages: iprint = 1 @@ -121,10 +121,10 @@ class opt_lbfgsb(Optimizer): if self.gtol is not None: opt_dict['pgtol'] = self.gtol - opt_result = optimize.fmin_l_bfgs_b(self.f_fp, self.x_init, iprint = iprint, + opt_result = optimize.fmin_l_bfgs_b(f_fp, self.x_init, iprint = iprint, maxfun = self.max_f_eval, **opt_dict) self.x_opt = opt_result[0] - self.f_opt = self.f_fp(self.x_opt)[0] + self.f_opt = f_fp(self.x_opt)[0] self.funct_eval = opt_result[2]['funcalls'] self.status = rcstrings[opt_result[2]['warnflag']] @@ -133,7 +133,7 @@ class opt_simplex(Optimizer): Optimizer.__init__(self, *args, **kwargs) self.opt_name = "Nelder-Mead simplex routine (via Scipy)" - def opt(self): + def opt(self, f_fp = None, f = None, fp = None): """ The simplex optimizer does not require gradients. """ @@ -148,7 +148,7 @@ class opt_simplex(Optimizer): if self.gtol is not None: print "WARNING: simplex doesn't have an gtol arg, so I'm going to ignore it" - opt_result = optimize.fmin(self.f, self.x_init, (), disp = self.messages, + opt_result = optimize.fmin(f, self.x_init, (), disp = self.messages, maxfun = self.max_f_eval, full_output=True, **opt_dict) self.x_opt = opt_result[0]