[stochastics] updated some stuff on the stochastics

This commit is contained in:
Max Zwiessele 2014-10-29 08:23:57 +00:00
parent af40ef8cfb
commit f916717aad

View file

@ -21,6 +21,11 @@ class StochasticStorage(object):
"""
pass
def reset(self):
"""
Reset the state of this stochastics generator.
"""
class SparseGPMissing(StochasticStorage):
def __init__(self, model, batchsize=1):
"""
@ -36,18 +41,19 @@ class SparseGPStochastics(StochasticStorage):
and the indices corresponding to those
"""
def __init__(self, model, batchsize=1):
import itertools
self.batchsize = batchsize
if self.batchsize == 1:
self.dimensions = itertools.cycle(range(model.Y_normalized.shape[1]))
else:
import numpy as np
self.dimensions = lambda: np.random.choice(model.Y_normalized.shape[1], size=batchsize, replace=False)
self.d = None
self.output_dim = model.Y.shape[1]
self.reset()
self.do_stochastics()
def do_stochastics(self):
if self.batchsize == 1:
self.d = [self.dimensions.next()]
self.current_dim = (self.current_dim+1)%self.output_dim
self.d = [self.current_dim]
else:
self.d = self.dimensions()
import numpy as np
self.d = np.random.choice(self.output_dim, size=self.batchsize, replace=False)
def reset(self):
self.current_dim = -1
self.d = None