mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-21 14:05:14 +02:00
made BGPLVM oil flow demo work, added ARD weights plot
This commit is contained in:
parent
eeb965d136
commit
55ad96f38b
3 changed files with 44 additions and 27 deletions
|
|
@ -33,12 +33,12 @@ class opt_SGD(Optimizer):
|
|||
self.self_paced = self_paced
|
||||
self.center = center
|
||||
self.param_traces = [('noise',[])]
|
||||
if len([p for p in self.model.kern.parts if p.name == 'bias']) == 1:
|
||||
self.param_traces.append(('bias',[]))
|
||||
if len([p for p in self.model.kern.parts if p.name == 'linear']) == 1:
|
||||
self.param_traces.append(('linear',[]))
|
||||
if len([p for p in self.model.kern.parts if p.name == 'rbf']) == 1:
|
||||
self.param_traces.append(('rbf_var',[]))
|
||||
# if len([p for p in self.model.kern.parts if p.name == 'bias']) == 1:
|
||||
# self.param_traces.append(('bias',[]))
|
||||
# if len([p for p in self.model.kern.parts if p.name == 'linear']) == 1:
|
||||
# self.param_traces.append(('linear',[]))
|
||||
# if len([p for p in self.model.kern.parts if p.name == 'rbf']) == 1:
|
||||
# self.param_traces.append(('rbf_var',[]))
|
||||
|
||||
self.param_traces = dict(self.param_traces)
|
||||
self.fopt_trace = []
|
||||
|
|
@ -156,29 +156,23 @@ class opt_SGD(Optimizer):
|
|||
Y = self.model.likelihood.Y
|
||||
samples = self.non_null_samples(self.model.likelihood.Y)
|
||||
self.model.N = samples.sum()
|
||||
|
||||
if self.center:
|
||||
self.model.likelihood._mean = Y[samples].mean()
|
||||
self.model.likelihood._std = Y[samples].std()
|
||||
|
||||
self.model.likelihood.set_data(Y[samples])
|
||||
Y = Y[samples]
|
||||
else:
|
||||
samples = self.model.likelihood.Y.nonzero()[0]
|
||||
self.model.N = len(samples)
|
||||
Y = np.asarray(self.model.likelihood.Y[samples].todense(), dtype = np.float64)
|
||||
if self.center:
|
||||
self.model.likelihood._mean = Y.mean()
|
||||
self.model.likelihood._std = Y.std()
|
||||
|
||||
self.model.likelihood.set_data(Y)
|
||||
if self.model.N == 0 or Y.std() == 0.0:
|
||||
return 0, step, self.model.N
|
||||
|
||||
# FIXME: get rid of self.center, everything should be centered by default
|
||||
self.model.likelihood._mean = Y.mean()
|
||||
self.model.likelihood._std = Y.std()
|
||||
self.model.likelihood.set_data(Y)
|
||||
|
||||
# self.model.likelihood.N = self.model.N
|
||||
j = self.subset_parameter_vector(self.x_opt, samples, shapes)
|
||||
self.model.X = X[samples]
|
||||
|
||||
if self.model.N == 0 or self.model.likelihood.Y.std() == 0.0:
|
||||
return 0, step, self.model.N
|
||||
|
||||
# if self.center:
|
||||
# self.model.likelihood.Y -= self.model.likelihood.Y.mean()
|
||||
# self.model.likelihood.Y /= self.model.likelihood.Y.std()
|
||||
|
|
@ -186,7 +180,8 @@ class opt_SGD(Optimizer):
|
|||
model_name = self.model.__class__.__name__
|
||||
|
||||
if model_name == 'Bayesian_GPLVM':
|
||||
self.model.likelihood.trYYT = np.sum(np.square(self.model.likelihood.Y))
|
||||
self.model.likelihood.YYT = np.dot(self.model.likelihood.Y, self.model.likelihood.Y.T)
|
||||
self.model.likelihood.trYYT = np.trace(self.model.likelihood.YYT)
|
||||
|
||||
b, p = self.shift_constraints(j)
|
||||
f, fp = f_fp(self.x_opt[j])
|
||||
|
|
@ -196,6 +191,7 @@ class opt_SGD(Optimizer):
|
|||
|
||||
step[j] = self.momentum * step[j] + self.learning_rate[j] * fp
|
||||
self.x_opt[j] -= step[j]
|
||||
|
||||
self.restore_constraints(b, p)
|
||||
|
||||
return f, step, self.model.N
|
||||
|
|
@ -256,14 +252,13 @@ class opt_SGD(Optimizer):
|
|||
sys.stdout.write(status)
|
||||
sys.stdout.flush()
|
||||
last_printed_count = count
|
||||
|
||||
self.param_traces['noise'].append(noise)
|
||||
NLL.append(f)
|
||||
|
||||
self.fopt_trace.append(f)
|
||||
for k in self.param_traces.keys():
|
||||
self.param_traces[k].append(self.model.get(k)[0])
|
||||
|
||||
|
||||
# for k in self.param_traces.keys():
|
||||
# self.param_traces[k].append(self.model.get(k)[0])
|
||||
|
||||
# should really be a sum(), but earlier samples in the iteration will have a very crappy ll
|
||||
self.f_opt = np.mean(NLL)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue