mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-08 15:05:15 +02:00
added centering and fixed serious bug
This commit is contained in:
parent
baa06181e8
commit
7b1e1eb18b
1 changed file with 20 additions and 10 deletions
|
|
@@ -18,7 +18,7 @@ class opt_SGD(Optimizer):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, start, iterations = 10, learning_rate = 1e-4, momentum = 0.9, model = None, messages = False, batch_size = 1, self_paced = False, **kwargs):
|
||||
def __init__(self, start, iterations = 10, learning_rate = 1e-4, momentum = 0.9, model = None, messages = False, batch_size = 1, self_paced = False, center = True, **kwargs):
|
||||
self.opt_name = "Stochastic Gradient Descent"
|
||||
|
||||
self.model = model
|
||||
|
|
@@ -30,6 +30,7 @@ class opt_SGD(Optimizer):
|
|||
self.messages = messages
|
||||
self.batch_size = batch_size
|
||||
self.self_paced = self_paced
|
||||
self.center = center
|
||||
|
||||
num_params = len(self.model._get_params())
|
||||
if isinstance(self.learning_rate, float):
|
||||
|
|
@@ -48,7 +49,7 @@ class opt_SGD(Optimizer):
|
|||
return status
|
||||
|
||||
def non_null_samples(self, data):
|
||||
return np.isnan(data).sum(axis=1) == 0
|
||||
return (np.isnan(data).sum(axis=1) == 0)
|
||||
|
||||
def check_for_missing(self, data):
|
||||
return np.isnan(data).sum() > 0
|
||||
|
|
@@ -138,8 +139,9 @@ class opt_SGD(Optimizer):
|
|||
|
||||
j = self.subset_parameter_vector(self.x_opt, samples, shapes)
|
||||
self.model.X = X[samples]
|
||||
# self.model.Y -= self.model.Y.mean() # <----------------- WARNING!!!!
|
||||
# self.model.Y /= self.model.Y.std()
|
||||
if self.center:
|
||||
self.model.Y -= self.model.Y.mean()
|
||||
self.model.Y /= self.model.Y.std()
|
||||
model_name = self.model.__class__.__name__
|
||||
|
||||
if model_name == 'Bayesian_GPLVM':
|
||||
|
|
@@ -161,7 +163,7 @@ class opt_SGD(Optimizer):
|
|||
return f, step, self.model.N
|
||||
|
||||
def opt(self, f_fp=None, f=None, fp=None):
|
||||
self.x_opt = self.model._get_params()
|
||||
self.x_opt = self.model._get_params_transformed()
|
||||
X, Y = self.model.X.copy(), self.model.Y.copy()
|
||||
N, Q = self.model.X.shape
|
||||
D = self.model.Y.shape[1]
|
||||
|
|
@@ -172,9 +174,14 @@ class opt_SGD(Optimizer):
|
|||
missing_data = self.check_for_missing(self.model.Y)
|
||||
self.model.Youter = None # this is probably not very efficient
|
||||
self.model.YYT = None
|
||||
step = np.zeros_like(self.model._get_params())
|
||||
num_params = self.model._get_params()
|
||||
|
||||
if self.center:
|
||||
print "WARNING: centering the data"
|
||||
|
||||
for it in range(self.iterations):
|
||||
step = np.zeros_like(num_params)
|
||||
|
||||
if it == 0 or self.self_paced is False:
|
||||
features = np.random.permutation(Y.shape[1])
|
||||
else:
|
||||
|
|
@@ -189,12 +196,9 @@ class opt_SGD(Optimizer):
|
|||
for j in features:
|
||||
count += 1
|
||||
self.model.D = len(j)
|
||||
self.model.Y = Y[:, j:j+1]
|
||||
|
||||
self.model.Y = Y[:, j]
|
||||
# self.model.trYYT = np.sum(np.square(self.model.Y))
|
||||
if missing_data or sparse_matrix:
|
||||
# if self.model.Y.std() == 0.0 or self.model.Y.shape[0] == 0: <--- not sure about this
|
||||
# continue
|
||||
shapes = self.get_param_shapes(N, Q)
|
||||
f, step, Nj = self.step_with_missing_data(f_fp, X, step, shapes, sparse_matrix)
|
||||
else:
|
||||
|
|
@@ -206,6 +210,12 @@ class opt_SGD(Optimizer):
|
|||
|
||||
if self.messages == 2:
|
||||
status = "evaluating {feature: 5d}/{tot: 5d} \t f: {f: 2.3f} \t non-missing: {nm: 4d}\r".format(feature = count, tot = len(features), f = f, nm = Nj)
|
||||
|
||||
# TODO: remove this, it's only for debugging
|
||||
if self.model.__class__.__name__ == 'Bayesian_GPLVM':
|
||||
beta = np.exp(self.x_opt)[-7]
|
||||
status = "evaluating {feature: 5d}/{tot: 5d} \t f: {f: 2.3f} \t non-missing: {nm: 4d} \t inv_bbeta: {beta: 1.5f}\r".format(feature = count, tot = len(features), f = f, nm = Nj, beta = 1./beta)
|
||||
|
||||
sys.stdout.write(status)
|
||||
sys.stdout.flush()
|
||||
last_printed_count = count
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue