From f916717aad0e0b97126327537e0eb397579cfb54 Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Wed, 29 Oct 2014 08:23:57 +0000 Subject: [PATCH] [stochastics] updated some stuff on the stochastics --- GPy/inference/optimization/stochastics.py | 24 ++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/GPy/inference/optimization/stochastics.py b/GPy/inference/optimization/stochastics.py index c9804e98..f19c3c2e 100644 --- a/GPy/inference/optimization/stochastics.py +++ b/GPy/inference/optimization/stochastics.py @@ -21,6 +21,11 @@ class StochasticStorage(object): """ pass + def reset(self): + """ + Reset the state of this stochastics generator. + """ + class SparseGPMissing(StochasticStorage): def __init__(self, model, batchsize=1): """ @@ -36,18 +41,19 @@ class SparseGPStochastics(StochasticStorage): and the indices corresponding to those """ def __init__(self, model, batchsize=1): - import itertools self.batchsize = batchsize - if self.batchsize == 1: - self.dimensions = itertools.cycle(range(model.Y_normalized.shape[1])) - else: - import numpy as np - self.dimensions = lambda: np.random.choice(model.Y_normalized.shape[1], size=batchsize, replace=False) - self.d = None + self.output_dim = model.Y.shape[1] + self.reset() self.do_stochastics() def do_stochastics(self): if self.batchsize == 1: - self.d = [self.dimensions.next()] + self.current_dim = (self.current_dim+1)%self.output_dim + self.d = [self.current_dim] else: - self.d = self.dimensions() \ No newline at end of file + import numpy as np + self.d = np.random.choice(self.output_dim, size=self.batchsize, replace=False) + + def reset(self): + self.current_dim = -1 + self.d = None \ No newline at end of file