mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 04:22:38 +02:00
[stochastics] updated some stuff on the stochastics
This commit is contained in:
parent
af40ef8cfb
commit
f916717aad
1 changed files with 15 additions and 9 deletions
|
|
@ -21,6 +21,11 @@ class StochasticStorage(object):
|
|||
"""
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the state of this stochastics generator.
|
||||
"""
|
||||
|
||||
class SparseGPMissing(StochasticStorage):
|
||||
def __init__(self, model, batchsize=1):
|
||||
"""
|
||||
|
|
@ -36,18 +41,19 @@ class SparseGPStochastics(StochasticStorage):
|
|||
and the indices corresponding to those
|
||||
"""
|
||||
def __init__(self, model, batchsize=1):
|
||||
import itertools
|
||||
self.batchsize = batchsize
|
||||
if self.batchsize == 1:
|
||||
self.dimensions = itertools.cycle(range(model.Y_normalized.shape[1]))
|
||||
else:
|
||||
import numpy as np
|
||||
self.dimensions = lambda: np.random.choice(model.Y_normalized.shape[1], size=batchsize, replace=False)
|
||||
self.d = None
|
||||
self.output_dim = model.Y.shape[1]
|
||||
self.reset()
|
||||
self.do_stochastics()
|
||||
|
||||
def do_stochastics(self):
|
||||
if self.batchsize == 1:
|
||||
self.d = [self.dimensions.next()]
|
||||
self.current_dim = (self.current_dim+1)%self.output_dim
|
||||
self.d = [self.current_dim]
|
||||
else:
|
||||
self.d = self.dimensions()
|
||||
import numpy as np
|
||||
self.d = np.random.choice(self.output_dim, size=self.batchsize, replace=False)
|
||||
|
||||
def reset(self):
|
||||
self.current_dim = -1
|
||||
self.d = None
|
||||
Loading…
Add table
Add a link
Reference in a new issue