[inference] minibatch inference needed polishing

This commit is contained in:
mzwiessele 2015-06-29 10:19:36 +02:00
parent 61ef9912cb
commit f46b23d16f
2 changed files with 38 additions and 23 deletions

View file

@ -63,33 +63,18 @@ class SparseGPMiniBatch(SparseGP):
if stochastic and missing_data:
self.missing_data = True
self.ninan = ~np.isnan(Y)
self.stochastics = SparseGPStochastics(self, batchsize)
elif stochastic and not missing_data:
self.missing_data = False
self.stochastics = SparseGPStochastics(self, batchsize)
elif missing_data:
self.missing_data = True
self.ninan = ~np.isnan(Y)
self.stochastics = SparseGPMissing(self)
else:
self.stochastics = False
logger.info("Adding Z as parameter")
self.link_parameter(self.Z, index=0)
if self.missing_data:
self.Ylist = []
overall = self.Y_normalized.shape[1]
m_f = lambda i: "Precomputing Y for missing data: {: >7.2%}".format(float(i+1)/overall)
message = m_f(-1)
print(message, end=' ')
for d in range(overall):
self.Ylist.append(self.Y_normalized[self.ninan[:, d], d][:, None])
print(' '*(len(message)+1) + '\r', end=' ')
message = m_f(d)
print(message, end=' ')
print('')
self.posterior = None
def has_uncertain_inputs(self):
@ -245,8 +230,7 @@ class SparseGPMiniBatch(SparseGP):
message = m_f(-1)
print(message, end=' ')
for d in self.stochastics.d:
ninan = self.ninan[:, d]
for d, ninan in self.stochastics.d:
if not self.stochastics:
print(' '*(len(message)) + '\r', end=' ')
@ -257,7 +241,7 @@ class SparseGPMiniBatch(SparseGP):
grad_dict, current_values, value_indices = self._inner_parameters_changed(
self.kern, self.X[ninan],
self.Z, self.likelihood,
self.Ylist[d], self.Y_metadata,
self.Y_normalized[ninan][:, d], self.Y_metadata,
Lm, dL_dKmm,
subset_indices=dict(outputs=d, samples=ninan))
@ -266,8 +250,8 @@ class SparseGPMiniBatch(SparseGP):
Lm = posterior.K_chol
dL_dKmm = grad_dict['dL_dKmm']
woodbury_inv[:, :, d] = posterior.woodbury_inv
woodbury_vector[:, d:d+1] = posterior.woodbury_vector
woodbury_inv[:, :, d] = posterior.woodbury_inv[:,:,None]
woodbury_vector[:, d] = posterior.woodbury_vector
self._log_marginal_likelihood += log_marginal_likelihood
if not self.stochastics:
print('')