[inference] minibatch inference needed polishing

This commit is contained in:
mzwiessele 2015-06-29 10:19:36 +02:00
parent 61ef9912cb
commit f46b23d16f
2 changed files with 38 additions and 23 deletions

View file

@ -5,6 +5,10 @@ class StochasticStorage(object):
'''
This is a container for holding the stochastic parameters,
such as subset indices or step length and so on.
self.d has to be a list of lists:
[dimension indices, nan indices for those dimensions]
so that the minibatches can be used as efficiently as possible.10
'''
def __init__(self, model):
"""
@ -28,9 +32,23 @@ class SparseGPMissing(StochasticStorage):
"""
Here we want to loop over all dimensions everytime.
Thus, we can just make sure the loop goes over self.d every
time.
time. We will try to get batches which look the same together
which speeds up calculations significantly.
"""
self.d = range(model.Y_normalized.shape[1])
import numpy as np
self.Y = model.Y_normalized
bdict = {}
for d in range(self.Y.shape[1]):
inan = np.isnan(self.Y[:, d])
arr_str = np.array2string(inan,
np.inf, 0,
True, '',
formatter={'bool':lambda x: '1' if x else '0'})
try:
bdict[arr_str][0].append(d)
except:
bdict[arr_str] = [[d], ~inan]
self.d = bdict.values()
class SparseGPStochastics(StochasticStorage):
"""
@ -40,16 +58,29 @@ class SparseGPStochastics(StochasticStorage):
def __init__(self, model, batchsize=1):
self.batchsize = batchsize
self.output_dim = model.Y.shape[1]
self.Y = model.Y_normalized
self.reset()
self.do_stochastics()
def do_stochastics(self):
if self.batchsize == 1:
self.current_dim = (self.current_dim+1)%self.output_dim
self.d = [self.current_dim]
self.d = [[[self.current_dim], np.isnan(self.Y[:, self.d])]]
else:
import numpy as np
self.d = np.random.choice(self.output_dim, size=self.batchsize, replace=False)
bdict = {}
for d in self.d:
inan = np.isnan(self.Y[:, d])
arr_str = np.array2string(inan,
np.inf, 0,
True, '',
formatter={'bool':lambda x: '1' if x else '0'})
try:
bdict[arr_str][0].append(d)
except:
bdict[arr_str] = [[d], ~inan]
self.d = bdict.values()
def reset(self):
self.current_dim = -1