mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 20:42:39 +02:00
[stochastics] updated some stuff on the stochastics
This commit is contained in:
parent
af40ef8cfb
commit
f916717aad
1 changed files with 15 additions and 9 deletions
|
|
@ -21,6 +21,11 @@ class StochasticStorage(object):
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""
|
||||||
|
Reset the state of this stochastics generator.
|
||||||
|
"""
|
||||||
|
|
||||||
class SparseGPMissing(StochasticStorage):
|
class SparseGPMissing(StochasticStorage):
|
||||||
def __init__(self, model, batchsize=1):
|
def __init__(self, model, batchsize=1):
|
||||||
"""
|
"""
|
||||||
|
|
@ -36,18 +41,19 @@ class SparseGPStochastics(StochasticStorage):
|
||||||
and the indices corresponding to those
|
and the indices corresponding to those
|
||||||
"""
|
"""
|
||||||
def __init__(self, model, batchsize=1):
|
def __init__(self, model, batchsize=1):
|
||||||
import itertools
|
|
||||||
self.batchsize = batchsize
|
self.batchsize = batchsize
|
||||||
if self.batchsize == 1:
|
self.output_dim = model.Y.shape[1]
|
||||||
self.dimensions = itertools.cycle(range(model.Y_normalized.shape[1]))
|
self.reset()
|
||||||
else:
|
|
||||||
import numpy as np
|
|
||||||
self.dimensions = lambda: np.random.choice(model.Y_normalized.shape[1], size=batchsize, replace=False)
|
|
||||||
self.d = None
|
|
||||||
self.do_stochastics()
|
self.do_stochastics()
|
||||||
|
|
||||||
def do_stochastics(self):
|
def do_stochastics(self):
|
||||||
if self.batchsize == 1:
|
if self.batchsize == 1:
|
||||||
self.d = [self.dimensions.next()]
|
self.current_dim = (self.current_dim+1)%self.output_dim
|
||||||
|
self.d = [self.current_dim]
|
||||||
else:
|
else:
|
||||||
self.d = self.dimensions()
|
import numpy as np
|
||||||
|
self.d = np.random.choice(self.output_dim, size=self.batchsize, replace=False)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.current_dim = -1
|
||||||
|
self.d = None
|
||||||
Loading…
Add table
Add a link
Reference in a new issue