mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-11 15:15:15 +02:00
added (optional) iter param dump
This commit is contained in:
parent
c9709cf4da
commit
db895209ca
1 changed files with 13 additions and 3 deletions
|
|
@ -4,8 +4,7 @@ import scipy.sparse
|
|||
from optimization import Optimizer
|
||||
from scipy import linalg, optimize
|
||||
import pylab as plt
|
||||
import copy
|
||||
import sys
|
||||
import copy, sys, pickle
|
||||
|
||||
class opt_SGD(Optimizer):
|
||||
"""
|
||||
|
|
@ -19,7 +18,7 @@ class opt_SGD(Optimizer):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, start, iterations = 10, learning_rate = 1e-4, momentum = 0.9, model = None, messages = False, batch_size = 1, self_paced = False, center = True, **kwargs):
|
||||
def __init__(self, start, iterations = 10, learning_rate = 1e-4, momentum = 0.9, model = None, messages = False, batch_size = 1, self_paced = False, center = True, iteration_file = None, **kwargs):
|
||||
self.opt_name = "Stochastic Gradient Descent"
|
||||
|
||||
self.model = model
|
||||
|
|
@ -33,6 +32,7 @@ class opt_SGD(Optimizer):
|
|||
self.self_paced = self_paced
|
||||
self.center = center
|
||||
self.param_traces = [('noise',[])]
|
||||
self.iteration_file = iteration_file
|
||||
# if len([p for p in self.model.kern.parts if p.name == 'bias']) == 1:
|
||||
# self.param_traces.append(('bias',[]))
|
||||
# if len([p for p in self.model.kern.parts if p.name == 'linear']) == 1:
|
||||
|
|
@ -271,8 +271,18 @@ class opt_SGD(Optimizer):
|
|||
|
||||
# self.model.Youter = np.dot(Y, Y.T)
|
||||
self.trace.append(self.f_opt)
|
||||
if self.iteration_file is not None:
|
||||
f = open(self.iteration_file + "iteration%d.pickle" % it, 'w')
|
||||
data = [self.x_opt, self.fopt_trace, self.param_traces]
|
||||
pickle.dump(data, f)
|
||||
f.close()
|
||||
|
||||
if self.messages != 0:
|
||||
sys.stdout.write('\r' + ' '*len(status)*2 + ' \r')
|
||||
status = "SGD Iteration: {0: 3d}/{1: 3d} f: {2: 2.3f}\n".format(it+1, self.iterations, self.f_opt)
|
||||
sys.stdout.write(status)
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue